package db import ( "fmt" "log" "os" "regexp" "strings" "time" "github.com/glebarez/sqlite" "gorm.io/driver/mysql" "gorm.io/gorm" gormLog "gorm.io/gorm/logger" "be.ems/src/framework/config" "be.ems/src/framework/logger" "be.ems/src/framework/utils/parse" ) // 数据库连接实例 var dbMap = make(map[string]*gorm.DB) type dialectInfo struct { dialectic gorm.Dialector logging bool } // 载入数据库连接 func loadDialect() map[string]dialectInfo { dialects := make(map[string]dialectInfo) // 读取数据源配置 datasource := config.Get("database.datasource").(map[string]any) for key, value := range datasource { item := value.(map[string]any) // 数据库类型对应的数据库连接 switch item["type"] { case "sqlite": pragmas := []string{ "cache=shared", // Enable shared cache "_pragma=journal_mode(WAL)", // Enable WAL mode "_pragma=busy_timeout(5000)", // Set busy timeout 抛出 SQLITE_BUSY 错误 默认0立即,设置等待 5 秒(5000 毫秒) "_pragma=synchronous(NORMAL)", // Set synchronous mode "_pragma=wal_autocheckpoint(5000)", // Set WAL auto-checkpoint threshold 设置较大的自动检查点阈值 "_pragma=mmap_size(67108864)", // Set mmap size 内存映射 I/O 最大字节数为 64MB } dsn := fmt.Sprintf("%s?%s", item["database"], strings.Join(pragmas, "&")) dialects[key] = dialectInfo{ dialectic: sqlite.Open(dsn), logging: parse.Boolean(item["logging"]), } case "mysql": dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", item["username"], item["password"], item["host"], item["port"], item["database"], ) dialects[key] = dialectInfo{ dialectic: mysql.Open(dsn), logging: parse.Boolean(item["logging"]), } default: logger.Warnf("%s: %v\n Not Load DB Config Type", key, item) } } return dialects } // 载入连接日志配置 func loadLogger() gormLog.Interface { newLogger := gormLog.New( log.New(os.Stdout, "[GORM] ", log.LstdFlags), // 将日志输出到控制台 gormLog.Config{ SlowThreshold: time.Second, // Slow SQL 阈值 LogLevel: gormLog.Info, // 日志级别 Silent不输出任何日志 ParameterizedQueries: false, // 参数化查询SQL 用实际值带入?的执行语句 Colorful: false, // 彩色日志输出 }, ) return newLogger } // Connect 连接数据库实例 func Connect() { // 遍历进行连接数据库实例 for key, info := range loadDialect() { opts := &gorm.Config{ Logger: gormLog.Discard, } // 是否需要日志输出 if info.logging { opts.Logger = loadLogger() } // 创建连接 db, err := gorm.Open(info.dialectic, opts) if err != nil { logger.Errorf("failed error db connect: %s", err) continue } // 获取底层 SQL 数据库连接 sqlDB, err := db.DB() if err != nil { logger.Fatalf("failed error underlying SQL database: %v", err) } // 测试数据库连接 err = sqlDB.Ping() if err != nil { logger.Fatalf("failed error ping database: %v", err) } // SetMaxIdleConns 用于设置连接池中空闲连接的最大数量。 sqlDB.SetMaxIdleConns(10) // SetMaxOpenConns 设置打开数据库连接的最大数量。 sqlDB.SetMaxOpenConns(100) // SetConnMaxLifetime 设置了连接可复用的最大时间。 sqlDB.SetConnMaxLifetime(time.Hour) logger.Infof("database %s connection is successful.", key) dbMap[key] = db } } // Close 关闭数据库实例 func Close() { for _, db := range dbMap { sqlDB, err := db.DB() if err != nil { continue } if err := sqlDB.Close(); err != nil { logger.Errorf("fatal error db close: %s", err) } } } // DB 获取数据源 // // source-数据源 func DB(source string) *gorm.DB { // 不指定时获取默认实例 if source == "" { source = config.Get("database.defaultDataSourceName").(string) } db := dbMap[source] if db == nil { logger.Errorf("not database source: %s", source) return nil } return db } // Names 获取数据源名称列表 func Names() []string { var names []string for key := range dbMap { names = append(names, key) } return names } // RawDB 原生语句查询 // // source-数据源 // sql-预编译的SQL语句 // parameters-预编译的SQL语句参数 func RawDB(source string, sql string, parameters []any) ([]map[string]any, error) { var rows []map[string]any // 数据源 db := DB(source) if db == nil { return rows, fmt.Errorf("not database source") } // 使用正则表达式替换连续的空白字符为单个空格 fmtSql := regexp.MustCompile(`\s+`).ReplaceAllString(sql, " ") // 查询结果 res := db.Raw(fmtSql, parameters...).Scan(&rows) if res.Error != nil { return nil, res.Error } return rows, nil } // ExecDB 原生语句执行 // // source-数据源 // sql-预编译的SQL语句 // parameters-预编译的SQL语句参数 func ExecDB(source string, sql string, parameters []any) (int64, error) { // 数据源 db := DB(source) if db == nil { return 0, fmt.Errorf("not database source") } // 使用正则表达式替换连续的空白字符为单个空格 fmtSql := regexp.MustCompile(`\s+`).ReplaceAllString(sql, " ") // 执行结果 res := db.Exec(fmtSql, parameters...) if res.Error != nil { return 0, res.Error } return res.RowsAffected, nil } // PageNumSize 分页页码记录数 // // pageNum-页码 // pageSize-记录数 func PageNumSize(pageNum, pageSize any) (int, int) { // 记录起始索引 num := parse.Number(pageNum) if num < 1 { num = 1 } // 显示记录数 size := parse.Number(pageSize) if size < 0 { size = 10 } return int(num - 1), int(size) }