refactor: 删除冗余常量文件并整合常量定义

This commit is contained in:
TsMask
2025-02-20 09:50:29 +08:00
parent 1b435074cb
commit a1296b6fe6
63 changed files with 1823 additions and 1748 deletions

View File

@@ -0,0 +1,206 @@
package db
import (
"fmt"
"log"
"os"
"regexp"
"time"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/gorm"
gormLog "gorm.io/gorm/logger"
"be.ems/src/framework/config"
"be.ems/src/framework/logger"
"be.ems/src/framework/utils/parse"
)
// 数据库连接实例
var dbMap = make(map[string]*gorm.DB)
type dialectInfo struct {
dialectic gorm.Dialector
logging bool
}
// 载入数据库连接
func loadDialect() map[string]dialectInfo {
dialects := make(map[string]dialectInfo)
// 读取数据源配置
datasource := config.Get("database.datasource").(map[string]any)
for key, value := range datasource {
item := value.(map[string]any)
// 数据库类型对应的数据库连接
switch item["type"] {
case "sqlite":
dsn := fmt.Sprint(item["database"])
dialects[key] = dialectInfo{
dialectic: sqlite.Open(dsn),
logging: item["logging"].(bool),
}
case "mysql":
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
item["username"],
item["password"],
item["host"],
item["port"],
item["database"],
)
dialects[key] = dialectInfo{
dialectic: mysql.Open(dsn),
logging: item["logging"].(bool),
}
default:
logger.Warnf("%s: %v\n Not Load DB Config Type", key, item)
}
}
return dialects
}
// 载入连接日志配置
func loadLogger() gormLog.Interface {
newLogger := gormLog.New(
log.New(os.Stdout, "[GORM] ", log.LstdFlags), // 将日志输出到控制台
gormLog.Config{
SlowThreshold: time.Second, // Slow SQL 阈值
LogLevel: gormLog.Info, // 日志级别 Silent不输出任何日志
ParameterizedQueries: false, // 参数化查询SQL 用实际值带入?的执行语句
Colorful: false, // 彩色日志输出
},
)
return newLogger
}
// Connect 连接数据库实例
func Connect() {
// 遍历进行连接数据库实例
for key, info := range loadDialect() {
opts := &gorm.Config{}
// 是否需要日志输出
if info.logging {
opts.Logger = loadLogger()
}
// 创建连接
db, err := gorm.Open(info.dialectic, opts)
if err != nil {
logger.Fatalf("failed error db connect: %s", err)
}
// 获取底层 SQL 数据库连接
sqlDB, err := db.DB()
if err != nil {
logger.Fatalf("failed error underlying SQL database: %v", err)
}
// 测试数据库连接
err = sqlDB.Ping()
if err != nil {
logger.Fatalf("failed error ping database: %v", err)
}
// SetMaxIdleConns 用于设置连接池中空闲连接的最大数量。
sqlDB.SetMaxIdleConns(10)
// SetMaxOpenConns 设置打开数据库连接的最大数量。
sqlDB.SetMaxOpenConns(100)
// SetConnMaxLifetime 设置了连接可复用的最大时间。
sqlDB.SetConnMaxLifetime(time.Hour)
logger.Infof("database %s connection is successful.", key)
dbMap[key] = db
}
}
// Close 关闭数据库实例
func Close() {
for _, db := range dbMap {
sqlDB, err := db.DB()
if err != nil {
continue
}
if err := sqlDB.Close(); err != nil {
logger.Errorf("fatal error db close: %s", err)
}
}
}
// DB 获取数据源
//
// source-数据源
func DB(source string) *gorm.DB {
// 不指定时获取默认实例
if source == "" {
source = config.Get("gorm.defaultDataSourceName").(string)
}
return dbMap[source]
}
// Names 获取数据源名称列表
func Names() []string {
var names []string
for key := range dbMap {
names = append(names, key)
}
return names
}
// RawDB 原生语句查询
//
// source-数据源
// sql-预编译的SQL语句
// parameters-预编译的SQL语句参数
func RawDB(source string, sql string, parameters []any) ([]map[string]any, error) {
var rows []map[string]any
// 数据源
db := DB(source)
if db == nil {
return rows, fmt.Errorf("not database source")
}
// 使用正则表达式替换连续的空白字符为单个空格
fmtSql := regexp.MustCompile(`\s+`).ReplaceAllString(sql, " ")
// 查询结果
res := db.Raw(fmtSql, parameters...).Scan(&rows)
if res.Error != nil {
return nil, res.Error
}
return rows, nil
}
// ExecDB 原生语句执行
//
// source-数据源
// sql-预编译的SQL语句
// parameters-预编译的SQL语句参数
func ExecDB(source string, sql string, parameters []any) (int64, error) {
// 数据源
db := DB(source)
if db == nil {
return 0, fmt.Errorf("not database source")
}
// 使用正则表达式替换连续的空白字符为单个空格
fmtSql := regexp.MustCompile(`\s+`).ReplaceAllString(sql, " ")
// 执行结果
res := db.Exec(fmtSql, parameters...)
if res.Error != nil {
return 0, res.Error
}
return res.RowsAffected, nil
}
// PageNumSize 分页页码记录数
//
// pageNum-页码
// pageSize-记录数
func PageNumSize(pageNum, pageSize any) (int, int) {
// 记录起始索引
num := parse.Number(pageNum)
if num < 1 {
num = 1
}
// 显示记录数
size := parse.Number(pageSize)
if size < 0 {
size = 10
}
return int(num - 1), int(size)
}

View File

@@ -0,0 +1,79 @@
package redis
import (
"context"
"fmt"
"strings"
"time"
"github.com/redis/go-redis/v9"
)
// ConnRedis 连接redis对象
type ConnRedis struct {
Addr string `json:"addr"` // 地址
Port int64 `json:"port"` // 端口
User string `json:"user"` // 用户名
Password string `json:"password"` // 认证密码
Database int `json:"database"` // 数据库名称
DialTimeOut time.Duration `json:"dialTimeOut"` // 连接超时断开
Client *redis.Client `json:"client"`
}
// NewClient 创建Redis客户端
func (c *ConnRedis) NewClient() (*ConnRedis, error) {
// IPV6地址协议
if strings.Contains(c.Addr, ":") {
c.Addr = fmt.Sprintf("[%s]", c.Addr)
}
addr := fmt.Sprintf("%s:%d", c.Addr, c.Port)
// 默认等待5s
if c.DialTimeOut == 0 {
c.DialTimeOut = 5 * time.Second
}
// 连接
rdb := redis.NewClient(&redis.Options{
Addr: addr,
// Username: c.User,
Password: c.Password,
DB: c.Database,
DialTimeout: c.DialTimeOut,
})
// 测试数据库连接
if _, err := rdb.Ping(context.Background()).Result(); err != nil {
return nil, err
}
c.Client = rdb
return c, nil
}
// Close 关闭当前Redis客户端
func (c *ConnRedis) Close() {
if c.Client != nil {
c.Client.Close()
}
}
// RunCMD 执行单次命令 "GET key"
func (c *ConnRedis) RunCMD(cmd string) (any, error) {
if c.Client == nil {
return "", fmt.Errorf("redis client not connected")
}
// 写入命令
cmdArr := strings.Fields(cmd)
if len(cmdArr) == 0 {
return "", fmt.Errorf("redis command is empty")
}
conn := *c.Client
args := make([]any, 0)
for _, v := range cmdArr {
args = append(args, v)
}
return conn.Do(context.Background(), args...).Result()
}

View File

@@ -0,0 +1,139 @@
package redis
import (
"context"
"errors"
"fmt"
"sync"
"be.ems/src/framework/logger"
"github.com/redis/go-redis/v9"
)
// 连接Redis实例
func ConnectPush(source string, rdb *redis.Client) {
if rdb == nil {
delete(rdbMap, source)
return
}
rdbMap[source] = rdb
}
// 批量获得缓存数据 [key]result
func GetHashBatch(source string, keys []string) (map[string]map[string]string, error) {
result := make(map[string]map[string]string, 0)
if len(keys) == 0 {
return result, fmt.Errorf("not keys")
}
// 数据源
rdb := RDB(source)
if rdb == nil {
return result, fmt.Errorf("redis not client")
}
// 创建一个有限的并发控制信号通道
sem := make(chan struct{}, 10)
var wg sync.WaitGroup
var mt sync.Mutex
batchSize := 1000
total := len(keys)
if total < batchSize {
batchSize = total
}
for i := 0; i < total; i += batchSize {
wg.Add(1)
go func(start int) {
ctx := context.Background()
// 并发控制,限制同时执行的 Goroutine 数量
sem <- struct{}{}
defer func() {
<-sem
ctx.Done()
wg.Done()
}()
// 检查索引是否越界
end := start + batchSize
if end > total {
end = total
}
pipe := rdb.Pipeline()
for _, key := range keys[start:end] {
pipe.HGetAll(ctx, key)
}
cmds, err := pipe.Exec(ctx)
if err != nil {
logger.Errorf("Failed to get hash batch exec err: %v", err)
return
}
// 将结果添加到 result map 并发访问
mt.Lock()
defer mt.Unlock()
// 处理命令结果
for _, cmd := range cmds {
if cmd.Err() != nil {
logger.Errorf("Failed to get hash batch cmds err: %v", cmd.Err())
continue
}
// 将结果转换为 *redis.StringStringMapCmd 类型
rcmd, ok := cmd.(*redis.MapStringStringCmd)
if !ok {
logger.Errorf("Failed to get hash batch type err: %v", cmd.Err())
continue
}
key := "-"
args := rcmd.Args()
if len(args) > 0 {
key = fmt.Sprint(args[1])
}
result[key] = rcmd.Val()
}
}(i)
}
wg.Wait()
return result, nil
}
// GetHash 获得缓存数据
func GetHash(source, key, field string) (string, error) {
// 数据源
rdb := RDB(source)
if rdb == nil {
return "", fmt.Errorf("redis not client")
}
ctx := context.Background()
v, err := rdb.HGet(ctx, key, field).Result()
if errors.Is(err, redis.Nil) {
return "", fmt.Errorf("no key field")
}
if err != nil {
return "", err
}
return v, nil
}
// SetHash 设置缓存数据
func SetHash(source, key string, value map[string]any) error {
// 数据源
rdb := RDB(source)
if rdb == nil {
return fmt.Errorf("redis not client")
}
ctx := context.Background()
err := rdb.HSet(ctx, key, value).Err()
if err != nil {
logger.Errorf("redis HSet err %v", err)
return err
}
return nil
}

View File

@@ -0,0 +1,346 @@
package redis
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/redis/go-redis/v9"
"be.ems/src/framework/config"
"be.ems/src/framework/logger"
)
// Redis连接实例
var rdbMap = make(map[string]*redis.Client)
// Connect 连接Redis实例
func Connect() {
ctx := context.Background()
// 读取数据源配置
datasource := config.Get("redis.dataSource").(map[string]any)
for k, v := range datasource {
client := v.(map[string]any)
// 创建连接
address := fmt.Sprintf("%s:%d", client["host"], client["port"])
rdb := redis.NewClient(&redis.Options{
Addr: address,
Password: client["password"].(string),
DB: client["db"].(int),
})
// 测试数据库连接
pong, err := rdb.Ping(ctx).Result()
if err != nil {
logger.Fatalf("Ping redis %s is %v", k, err)
}
logger.Infof("redis %s %d %s connection is successful.", k, client["db"].(int), pong)
rdbMap[k] = rdb
}
}
// Close 关闭Redis实例
func Close() {
for _, rdb := range rdbMap {
if err := rdb.Close(); err != nil {
logger.Errorf("redis db close: %s", err)
}
}
}
// RDB 获取实例
func RDB(source string) *redis.Client {
// 不指定时获取默认实例
if source == "" {
source = config.Get("redis.defaultDataSourceName").(string)
}
return rdbMap[source]
}
// Info 获取redis服务信息
func Info(source string) map[string]map[string]string {
infoObj := make(map[string]map[string]string)
// 数据源
rdb := RDB(source)
if rdb == nil {
return infoObj
}
ctx := context.Background()
info, err := rdb.Info(ctx).Result()
if err != nil {
return infoObj
}
lines := strings.Split(info, "\r\n")
label := ""
for _, line := range lines {
if strings.Contains(line, "#") {
label = strings.Fields(line)[len(strings.Fields(line))-1]
label = strings.ToLower(label)
infoObj[label] = make(map[string]string)
continue
}
kvArr := strings.Split(line, ":")
if len(kvArr) >= 2 {
key := strings.TrimSpace(kvArr[0])
value := strings.TrimSpace(kvArr[len(kvArr)-1])
infoObj[label][key] = value
}
}
return infoObj
}
// KeySize 获取redis当前连接可用键Key总数信息
func KeySize(source string) int64 {
// 数据源
rdb := RDB(source)
if rdb == nil {
return 0
}
ctx := context.Background()
size, err := rdb.DBSize(ctx).Result()
if err != nil {
return 0
}
return size
}
// CommandStats 获取redis命令状态信息
func CommandStats(source string) []map[string]string {
statsObjArr := make([]map[string]string, 0)
// 数据源
rdb := RDB(source)
if rdb == nil {
return statsObjArr
}
ctx := context.Background()
commandstats, err := rdb.Info(ctx, "commandstats").Result()
if err != nil {
return statsObjArr
}
lines := strings.Split(commandstats, "\r\n")
for _, line := range lines {
if !strings.HasPrefix(line, "cmdstat_") {
continue
}
kvArr := strings.Split(line, ":")
key := kvArr[0]
valueStr := kvArr[len(kvArr)-1]
statsObj := make(map[string]string)
statsObj["name"] = key[8:]
statsObj["value"] = valueStr[6:strings.Index(valueStr, ",usec=")]
statsObjArr = append(statsObjArr, statsObj)
}
return statsObjArr
}
// GetExpire 获取键的剩余有效时间(秒)
func GetExpire(source string, key string) (int64, error) {
// 数据源
rdb := RDB(source)
if rdb == nil {
return 0, fmt.Errorf("redis not client")
}
ctx := context.Background()
ttl, err := rdb.TTL(ctx, key).Result()
if err != nil {
return 0, err
}
return int64(ttl.Seconds()), nil
}
// GetKeys 获得缓存数据的key列表
func GetKeys(source string, pattern string) ([]string, error) {
keys := make([]string, 0)
// 数据源
rdb := RDB(source)
if rdb == nil {
return keys, fmt.Errorf("redis not client")
}
// 游标
var cursor uint64 = 0
var count int64 = 100
ctx := context.Background()
// 循环遍历获取匹配的键
for {
// 使用 SCAN 命令获取匹配的键
batchKeys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, count).Result()
if err != nil {
logger.Errorf("failed to scan keys: %v", err)
return keys, err
}
cursor = nextCursor
keys = append(keys, batchKeys...)
// 当 cursor 为 0表示遍历完成
if cursor == 0 {
break
}
}
return keys, nil
}
// GetBatch 批量获得缓存数据
func GetBatch(source string, keys []string) ([]any, error) {
result := make([]any, 0)
if len(keys) == 0 {
return result, fmt.Errorf("not keys")
}
// 数据源
rdb := RDB(source)
if rdb == nil {
return result, fmt.Errorf("redis not client")
}
// 获取缓存数据
v, err := rdb.MGet(context.Background(), keys...).Result()
if err != nil || errors.Is(err, redis.Nil) {
logger.Errorf("failed to get batch data: %v", err)
return result, err
}
return v, nil
}
// Get 获得缓存数据
func Get(source, key string) (string, error) {
// 数据源
rdb := RDB(source)
if rdb == nil {
return "", fmt.Errorf("redis not client")
}
ctx := context.Background()
v, err := rdb.Get(ctx, key).Result()
if errors.Is(err, redis.Nil) {
return "", fmt.Errorf("no keys")
}
if err != nil {
return "", err
}
return v, nil
}
// Has 判断是否存在
func Has(source string, keys ...string) (int64, error) {
// 数据源
rdb := RDB(source)
if rdb == nil {
return 0, fmt.Errorf("redis not client")
}
ctx := context.Background()
exists, err := rdb.Exists(ctx, keys...).Result()
if err != nil {
return 0, err
}
return exists, nil
}
// Set 设置缓存数据
func Set(source, key string, value any) error {
// 数据源
rdb := RDB(source)
if rdb == nil {
return fmt.Errorf("redis not client")
}
ctx := context.Background()
err := rdb.Set(ctx, key, value, 0).Err()
if err != nil {
logger.Errorf("redis Set err %v", err)
return err
}
return nil
}
// SetByExpire 设置缓存数据与过期时间
func SetByExpire(source, key string, value any, expiration time.Duration) error {
// 数据源
rdb := RDB(source)
if rdb == nil {
return fmt.Errorf("redis not client")
}
ctx := context.Background()
err := rdb.Set(ctx, key, value, expiration).Err()
if err != nil {
logger.Errorf("redis SetByExpire err %v", err)
return err
}
return nil
}
// Del 删除单个
func Del(source string, key string) error {
// 数据源
rdb := RDB(source)
if rdb == nil {
return fmt.Errorf("redis not client")
}
ctx := context.Background()
if err := rdb.Del(ctx, key).Err(); err != nil {
logger.Errorf("redis Del err %v", err)
return err
}
return nil
}
// DelKeys 删除多个
func DelKeys(source string, keys []string) error {
if len(keys) == 0 {
return fmt.Errorf("no keys")
}
// 数据源
rdb := RDB(source)
if rdb == nil {
return fmt.Errorf("redis not client")
}
ctx := context.Background()
if err := rdb.Del(ctx, keys...).Err(); err != nil {
logger.Errorf("redis DelKeys err %v", err)
return err
}
return nil
}
// RateLimit 限流查询并记录
func RateLimit(source, limitKey string, time, count int64) (int64, error) {
// 数据源
rdb := RDB(source)
if rdb == nil {
return 0, fmt.Errorf("redis not client")
}
ctx := context.Background()
result, err := rateLimitCommand.Run(ctx, rdb, []string{limitKey}, time, count).Result()
if err != nil {
logger.Errorf("redis lua script err %v", err)
return 0, err
}
return result.(int64), err
}
// 声明定义限流脚本命令
var rateLimitCommand = redis.NewScript(`
local key = KEYS[1]
local time = tonumber(ARGV[1])
local count = tonumber(ARGV[2])
local current = redis.call('get', key);
if current and tonumber(current) >= count then
return tonumber(current);
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
redis.call('expire', key, time)
end
return tonumber(current);`)