package dborm import ( "database/sql" "fmt" "log" "os" "regexp" "time" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" ) // 数据库连接实例 var dbgEngine *gorm.DB // 载入连接日志配置 func loadLogger() logger.Interface { newLogger := logger.New( log.New(os.Stdout, "[GORM] ", log.LstdFlags), // 将日志输出到控制台 logger.Config{ SlowThreshold: time.Second, // Slow SQL 阈值 LogLevel: logger.Info, // 日志级别 Silent不输出任何日志 ParameterizedQueries: false, // 参数化查询SQL 用实际值带入?的执行语句 Colorful: false, // 彩色日志输出 }, ) return newLogger } // 连接数据库实例 func InitGormConnect(dbType, dbUser, dbPassword, dbHost, dbPort, dbName, dbParam, dbLogging any) error { var dialector gorm.Dialector switch dbType { case "mysql": dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?%s", dbUser, dbPassword, dbHost, dbPort, dbName, dbParam, ) dialector = mysql.Open(dsn) default: err := fmt.Errorf("invalid type: %s", dbType) return err } opts := &gorm.Config{} // 是否需要日志输出 if dbLogging.(bool) { opts.Logger = loadLogger() } // 创建连接 db, err := gorm.Open(dialector, opts) if err != nil { log.Fatalf("failed to open: %s", err) return err } // 获取底层 SQL 数据库连接 sqlDB, err := db.DB() if err != nil { log.Fatalf("failed to connect DB pool: %v", err) return err } // 测试数据库连接 err = sqlDB.Ping() if err != nil { log.Fatalf("failed to ping database: %v", err) return err } dbgEngine = db return nil } // 关闭数据库实例 func Close() { sqlDB, err := dbgEngine.DB() if err != nil { log.Fatalf("failed to connect pool: %s", err) } if err := sqlDB.Close(); err != nil { log.Fatalf("failed to close: %s", err) } } // default gorm DB func DefaultDB() *gorm.DB { return dbgEngine } // get sql DB func GCoreDB() (*sql.DB, error) { return dbgEngine.DB() } // RawSQL 原生查询语句 func RawSQL(sql string, parameters []any) ([]map[string]any, error) { // 数据源 db := DefaultDB() // 使用正则表达式替换连续的空白字符为单个空格 fmtSql := regexp.MustCompile(`\s+`).ReplaceAllString(sql, " ") // logger.Infof("sql=> %v", fmtSql) // logger.Infof("parameters=> %v", parameters) // 查询结果 var rows []map[string]any res := db.Raw(fmtSql, parameters...).Scan(&rows) if res.Error != nil { return nil, res.Error } return rows, nil } // ExecSQL 原生执行语句 func ExecSQL(sql string, parameters []any) (int64, error) { // 数据源 db := DefaultDB() // 使用正则表达式替换连续的空白字符为单个空格 fmtSql := regexp.MustCompile(`\s+`).ReplaceAllString(sql, " ") // 执行结果 res := db.Exec(fmtSql, parameters...) if res.Error != nil { return 0, res.Error } return res.RowsAffected, nil } func CloneTable(srcTable, dstTable string) error { // 获取表 A 的结构信息 var columns []gorm.ColumnType dbMigrator := dbgEngine.Migrator() columns, err := dbMigrator.ColumnTypes(srcTable) if err != nil { return fmt.Errorf("failed to ColumnTypes, %v", err) } // 创建表 destination table err = dbMigrator.CreateTable(dstTable) if err != nil { return fmt.Errorf("failed to CreateTable, %v", err) } // 复制表 src 的字段到表 dst for _, column := range columns { err = dbMigrator.AddColumn(dstTable, column.Name()) if err != nil { return fmt.Errorf("failed to AddColumn, %v", err) } } // 复制表 src 的主键和索引到表 dst err = dbMigrator.CreateConstraint(dstTable, "PRIMARY") if err != nil { return fmt.Errorf("failed to AddColumn, %v", err) } err = dbMigrator.CreateConstraint(dstTable, "INDEX") if err != nil { return fmt.Errorf("failed to AddColumn, %v", err) } return nil }