refactor: 删除冗余常量文件并整合常量定义
This commit is contained in:
206
src/framework/database/db/db.go
Normal file
206
src/framework/database/db/db.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
gormLog "gorm.io/gorm/logger"
|
||||
|
||||
"be.ems/src/framework/config"
|
||||
"be.ems/src/framework/logger"
|
||||
"be.ems/src/framework/utils/parse"
|
||||
)
|
||||
|
||||
// 数据库连接实例
|
||||
var dbMap = make(map[string]*gorm.DB)
|
||||
|
||||
type dialectInfo struct {
|
||||
dialectic gorm.Dialector
|
||||
logging bool
|
||||
}
|
||||
|
||||
// 载入数据库连接
|
||||
func loadDialect() map[string]dialectInfo {
|
||||
dialects := make(map[string]dialectInfo)
|
||||
|
||||
// 读取数据源配置
|
||||
datasource := config.Get("database.datasource").(map[string]any)
|
||||
for key, value := range datasource {
|
||||
item := value.(map[string]any)
|
||||
// 数据库类型对应的数据库连接
|
||||
switch item["type"] {
|
||||
case "sqlite":
|
||||
dsn := fmt.Sprint(item["database"])
|
||||
dialects[key] = dialectInfo{
|
||||
dialectic: sqlite.Open(dsn),
|
||||
logging: item["logging"].(bool),
|
||||
}
|
||||
case "mysql":
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
item["username"],
|
||||
item["password"],
|
||||
item["host"],
|
||||
item["port"],
|
||||
item["database"],
|
||||
)
|
||||
dialects[key] = dialectInfo{
|
||||
dialectic: mysql.Open(dsn),
|
||||
logging: item["logging"].(bool),
|
||||
}
|
||||
default:
|
||||
logger.Warnf("%s: %v\n Not Load DB Config Type", key, item)
|
||||
}
|
||||
}
|
||||
|
||||
return dialects
|
||||
}
|
||||
|
||||
// 载入连接日志配置
|
||||
func loadLogger() gormLog.Interface {
|
||||
newLogger := gormLog.New(
|
||||
log.New(os.Stdout, "[GORM] ", log.LstdFlags), // 将日志输出到控制台
|
||||
gormLog.Config{
|
||||
SlowThreshold: time.Second, // Slow SQL 阈值
|
||||
LogLevel: gormLog.Info, // 日志级别 Silent不输出任何日志
|
||||
ParameterizedQueries: false, // 参数化查询SQL 用实际值带入?的执行语句
|
||||
Colorful: false, // 彩色日志输出
|
||||
},
|
||||
)
|
||||
return newLogger
|
||||
}
|
||||
|
||||
// Connect 连接数据库实例
|
||||
func Connect() {
|
||||
// 遍历进行连接数据库实例
|
||||
for key, info := range loadDialect() {
|
||||
opts := &gorm.Config{}
|
||||
// 是否需要日志输出
|
||||
if info.logging {
|
||||
opts.Logger = loadLogger()
|
||||
}
|
||||
// 创建连接
|
||||
db, err := gorm.Open(info.dialectic, opts)
|
||||
if err != nil {
|
||||
logger.Fatalf("failed error db connect: %s", err)
|
||||
}
|
||||
// 获取底层 SQL 数据库连接
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
logger.Fatalf("failed error underlying SQL database: %v", err)
|
||||
}
|
||||
// 测试数据库连接
|
||||
err = sqlDB.Ping()
|
||||
if err != nil {
|
||||
logger.Fatalf("failed error ping database: %v", err)
|
||||
}
|
||||
// SetMaxIdleConns 用于设置连接池中空闲连接的最大数量。
|
||||
sqlDB.SetMaxIdleConns(10)
|
||||
// SetMaxOpenConns 设置打开数据库连接的最大数量。
|
||||
sqlDB.SetMaxOpenConns(100)
|
||||
// SetConnMaxLifetime 设置了连接可复用的最大时间。
|
||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||
logger.Infof("database %s connection is successful.", key)
|
||||
dbMap[key] = db
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭数据库实例
|
||||
func Close() {
|
||||
for _, db := range dbMap {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
logger.Errorf("fatal error db close: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DB 获取数据源
|
||||
//
|
||||
// source-数据源
|
||||
func DB(source string) *gorm.DB {
|
||||
// 不指定时获取默认实例
|
||||
if source == "" {
|
||||
source = config.Get("gorm.defaultDataSourceName").(string)
|
||||
}
|
||||
return dbMap[source]
|
||||
}
|
||||
|
||||
// Names 获取数据源名称列表
|
||||
func Names() []string {
|
||||
var names []string
|
||||
for key := range dbMap {
|
||||
names = append(names, key)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// RawDB 原生语句查询
|
||||
//
|
||||
// source-数据源
|
||||
// sql-预编译的SQL语句
|
||||
// parameters-预编译的SQL语句参数
|
||||
func RawDB(source string, sql string, parameters []any) ([]map[string]any, error) {
|
||||
var rows []map[string]any
|
||||
// 数据源
|
||||
db := DB(source)
|
||||
if db == nil {
|
||||
return rows, fmt.Errorf("not database source")
|
||||
}
|
||||
// 使用正则表达式替换连续的空白字符为单个空格
|
||||
fmtSql := regexp.MustCompile(`\s+`).ReplaceAllString(sql, " ")
|
||||
// 查询结果
|
||||
res := db.Raw(fmtSql, parameters...).Scan(&rows)
|
||||
if res.Error != nil {
|
||||
return nil, res.Error
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// ExecDB 原生语句执行
|
||||
//
|
||||
// source-数据源
|
||||
// sql-预编译的SQL语句
|
||||
// parameters-预编译的SQL语句参数
|
||||
func ExecDB(source string, sql string, parameters []any) (int64, error) {
|
||||
// 数据源
|
||||
db := DB(source)
|
||||
if db == nil {
|
||||
return 0, fmt.Errorf("not database source")
|
||||
}
|
||||
// 使用正则表达式替换连续的空白字符为单个空格
|
||||
fmtSql := regexp.MustCompile(`\s+`).ReplaceAllString(sql, " ")
|
||||
// 执行结果
|
||||
res := db.Exec(fmtSql, parameters...)
|
||||
if res.Error != nil {
|
||||
return 0, res.Error
|
||||
}
|
||||
return res.RowsAffected, nil
|
||||
}
|
||||
|
||||
// PageNumSize 分页页码记录数
|
||||
//
|
||||
// pageNum-页码
|
||||
// pageSize-记录数
|
||||
func PageNumSize(pageNum, pageSize any) (int, int) {
|
||||
// 记录起始索引
|
||||
num := parse.Number(pageNum)
|
||||
if num < 1 {
|
||||
num = 1
|
||||
}
|
||||
|
||||
// 显示记录数
|
||||
size := parse.Number(pageSize)
|
||||
if size < 0 {
|
||||
size = 10
|
||||
}
|
||||
return int(num - 1), int(size)
|
||||
}
|
||||
79
src/framework/database/redis/conn.go
Normal file
79
src/framework/database/redis/conn.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ConnRedis 连接redis对象
|
||||
type ConnRedis struct {
|
||||
Addr string `json:"addr"` // 地址
|
||||
Port int64 `json:"port"` // 端口
|
||||
User string `json:"user"` // 用户名
|
||||
Password string `json:"password"` // 认证密码
|
||||
Database int `json:"database"` // 数据库名称
|
||||
|
||||
DialTimeOut time.Duration `json:"dialTimeOut"` // 连接超时断开
|
||||
|
||||
Client *redis.Client `json:"client"`
|
||||
}
|
||||
|
||||
// NewClient 创建Redis客户端
|
||||
func (c *ConnRedis) NewClient() (*ConnRedis, error) {
|
||||
// IPV6地址协议
|
||||
if strings.Contains(c.Addr, ":") {
|
||||
c.Addr = fmt.Sprintf("[%s]", c.Addr)
|
||||
}
|
||||
addr := fmt.Sprintf("%s:%d", c.Addr, c.Port)
|
||||
|
||||
// 默认等待5s
|
||||
if c.DialTimeOut == 0 {
|
||||
c.DialTimeOut = 5 * time.Second
|
||||
}
|
||||
|
||||
// 连接
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: addr,
|
||||
// Username: c.User,
|
||||
Password: c.Password,
|
||||
DB: c.Database,
|
||||
DialTimeout: c.DialTimeOut,
|
||||
})
|
||||
|
||||
// 测试数据库连接
|
||||
if _, err := rdb.Ping(context.Background()).Result(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.Client = rdb
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Close 关闭当前Redis客户端
|
||||
func (c *ConnRedis) Close() {
|
||||
if c.Client != nil {
|
||||
c.Client.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// RunCMD 执行单次命令 "GET key"
|
||||
func (c *ConnRedis) RunCMD(cmd string) (any, error) {
|
||||
if c.Client == nil {
|
||||
return "", fmt.Errorf("redis client not connected")
|
||||
}
|
||||
// 写入命令
|
||||
cmdArr := strings.Fields(cmd)
|
||||
if len(cmdArr) == 0 {
|
||||
return "", fmt.Errorf("redis command is empty")
|
||||
}
|
||||
conn := *c.Client
|
||||
args := make([]any, 0)
|
||||
for _, v := range cmdArr {
|
||||
args = append(args, v)
|
||||
}
|
||||
return conn.Do(context.Background(), args...).Result()
|
||||
}
|
||||
139
src/framework/database/redis/expand.go
Normal file
139
src/framework/database/redis/expand.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"be.ems/src/framework/logger"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 连接Redis实例
|
||||
func ConnectPush(source string, rdb *redis.Client) {
|
||||
if rdb == nil {
|
||||
delete(rdbMap, source)
|
||||
return
|
||||
}
|
||||
rdbMap[source] = rdb
|
||||
}
|
||||
|
||||
// 批量获得缓存数据 [key]result
|
||||
func GetHashBatch(source string, keys []string) (map[string]map[string]string, error) {
|
||||
result := make(map[string]map[string]string, 0)
|
||||
if len(keys) == 0 {
|
||||
return result, fmt.Errorf("not keys")
|
||||
}
|
||||
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return result, fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
// 创建一个有限的并发控制信号通道
|
||||
sem := make(chan struct{}, 10)
|
||||
var wg sync.WaitGroup
|
||||
var mt sync.Mutex
|
||||
batchSize := 1000
|
||||
total := len(keys)
|
||||
if total < batchSize {
|
||||
batchSize = total
|
||||
}
|
||||
|
||||
for i := 0; i < total; i += batchSize {
|
||||
wg.Add(1)
|
||||
go func(start int) {
|
||||
ctx := context.Background()
|
||||
// 并发控制,限制同时执行的 Goroutine 数量
|
||||
sem <- struct{}{}
|
||||
defer func() {
|
||||
<-sem
|
||||
ctx.Done()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// 检查索引是否越界
|
||||
end := start + batchSize
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
pipe := rdb.Pipeline()
|
||||
for _, key := range keys[start:end] {
|
||||
pipe.HGetAll(ctx, key)
|
||||
}
|
||||
|
||||
cmds, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to get hash batch exec err: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 将结果添加到 result map 并发访问
|
||||
mt.Lock()
|
||||
defer mt.Unlock()
|
||||
|
||||
// 处理命令结果
|
||||
for _, cmd := range cmds {
|
||||
if cmd.Err() != nil {
|
||||
logger.Errorf("Failed to get hash batch cmds err: %v", cmd.Err())
|
||||
continue
|
||||
}
|
||||
// 将结果转换为 *redis.StringStringMapCmd 类型
|
||||
rcmd, ok := cmd.(*redis.MapStringStringCmd)
|
||||
if !ok {
|
||||
logger.Errorf("Failed to get hash batch type err: %v", cmd.Err())
|
||||
continue
|
||||
}
|
||||
|
||||
key := "-"
|
||||
args := rcmd.Args()
|
||||
if len(args) > 0 {
|
||||
key = fmt.Sprint(args[1])
|
||||
}
|
||||
|
||||
result[key] = rcmd.Val()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetHash 获得缓存数据
|
||||
func GetHash(source, key, field string) (string, error) {
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return "", fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
v, err := rdb.HGet(ctx, key, field).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", fmt.Errorf("no key field")
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// SetHash 设置缓存数据
|
||||
func SetHash(source, key string, value map[string]any) error {
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := rdb.HSet(ctx, key, value).Err()
|
||||
if err != nil {
|
||||
logger.Errorf("redis HSet err %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
346
src/framework/database/redis/redis.go
Normal file
346
src/framework/database/redis/redis.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"be.ems/src/framework/config"
|
||||
"be.ems/src/framework/logger"
|
||||
)
|
||||
|
||||
// Redis连接实例
|
||||
var rdbMap = make(map[string]*redis.Client)
|
||||
|
||||
// Connect 连接Redis实例
|
||||
func Connect() {
|
||||
ctx := context.Background()
|
||||
// 读取数据源配置
|
||||
datasource := config.Get("redis.dataSource").(map[string]any)
|
||||
for k, v := range datasource {
|
||||
client := v.(map[string]any)
|
||||
// 创建连接
|
||||
address := fmt.Sprintf("%s:%d", client["host"], client["port"])
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: address,
|
||||
Password: client["password"].(string),
|
||||
DB: client["db"].(int),
|
||||
})
|
||||
// 测试数据库连接
|
||||
pong, err := rdb.Ping(ctx).Result()
|
||||
if err != nil {
|
||||
logger.Fatalf("Ping redis %s is %v", k, err)
|
||||
}
|
||||
logger.Infof("redis %s %d %s connection is successful.", k, client["db"].(int), pong)
|
||||
rdbMap[k] = rdb
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭Redis实例
|
||||
func Close() {
|
||||
for _, rdb := range rdbMap {
|
||||
if err := rdb.Close(); err != nil {
|
||||
logger.Errorf("redis db close: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RDB 获取实例
|
||||
func RDB(source string) *redis.Client {
|
||||
// 不指定时获取默认实例
|
||||
if source == "" {
|
||||
source = config.Get("redis.defaultDataSourceName").(string)
|
||||
}
|
||||
return rdbMap[source]
|
||||
}
|
||||
|
||||
// Info 获取redis服务信息
|
||||
func Info(source string) map[string]map[string]string {
|
||||
infoObj := make(map[string]map[string]string)
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return infoObj
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
info, err := rdb.Info(ctx).Result()
|
||||
if err != nil {
|
||||
return infoObj
|
||||
}
|
||||
|
||||
lines := strings.Split(info, "\r\n")
|
||||
label := ""
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, "#") {
|
||||
label = strings.Fields(line)[len(strings.Fields(line))-1]
|
||||
label = strings.ToLower(label)
|
||||
infoObj[label] = make(map[string]string)
|
||||
continue
|
||||
}
|
||||
kvArr := strings.Split(line, ":")
|
||||
if len(kvArr) >= 2 {
|
||||
key := strings.TrimSpace(kvArr[0])
|
||||
value := strings.TrimSpace(kvArr[len(kvArr)-1])
|
||||
infoObj[label][key] = value
|
||||
}
|
||||
}
|
||||
return infoObj
|
||||
}
|
||||
|
||||
// KeySize 获取redis当前连接可用键Key总数信息
|
||||
func KeySize(source string) int64 {
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
size, err := rdb.DBSize(ctx).Result()
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// CommandStats 获取redis命令状态信息
|
||||
func CommandStats(source string) []map[string]string {
|
||||
statsObjArr := make([]map[string]string, 0)
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return statsObjArr
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
commandstats, err := rdb.Info(ctx, "commandstats").Result()
|
||||
if err != nil {
|
||||
return statsObjArr
|
||||
}
|
||||
|
||||
lines := strings.Split(commandstats, "\r\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "cmdstat_") {
|
||||
continue
|
||||
}
|
||||
kvArr := strings.Split(line, ":")
|
||||
key := kvArr[0]
|
||||
valueStr := kvArr[len(kvArr)-1]
|
||||
statsObj := make(map[string]string)
|
||||
statsObj["name"] = key[8:]
|
||||
statsObj["value"] = valueStr[6:strings.Index(valueStr, ",usec=")]
|
||||
statsObjArr = append(statsObjArr, statsObj)
|
||||
}
|
||||
return statsObjArr
|
||||
}
|
||||
|
||||
// GetExpire 获取键的剩余有效时间(秒)
|
||||
func GetExpire(source string, key string) (int64, error) {
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return 0, fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ttl, err := rdb.TTL(ctx, key).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(ttl.Seconds()), nil
|
||||
}
|
||||
|
||||
// GetKeys 获得缓存数据的key列表
|
||||
func GetKeys(source string, pattern string) ([]string, error) {
|
||||
keys := make([]string, 0)
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return keys, fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
// 游标
|
||||
var cursor uint64 = 0
|
||||
var count int64 = 100
|
||||
ctx := context.Background()
|
||||
// 循环遍历获取匹配的键
|
||||
for {
|
||||
// 使用 SCAN 命令获取匹配的键
|
||||
batchKeys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, count).Result()
|
||||
if err != nil {
|
||||
logger.Errorf("failed to scan keys: %v", err)
|
||||
return keys, err
|
||||
}
|
||||
cursor = nextCursor
|
||||
keys = append(keys, batchKeys...)
|
||||
// 当 cursor 为 0,表示遍历完成
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetBatch 批量获得缓存数据
|
||||
func GetBatch(source string, keys []string) ([]any, error) {
|
||||
result := make([]any, 0)
|
||||
if len(keys) == 0 {
|
||||
return result, fmt.Errorf("not keys")
|
||||
}
|
||||
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return result, fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
// 获取缓存数据
|
||||
v, err := rdb.MGet(context.Background(), keys...).Result()
|
||||
if err != nil || errors.Is(err, redis.Nil) {
|
||||
logger.Errorf("failed to get batch data: %v", err)
|
||||
return result, err
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Get 获得缓存数据
|
||||
func Get(source, key string) (string, error) {
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return "", fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
v, err := rdb.Get(ctx, key).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", fmt.Errorf("no keys")
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Has 判断是否存在
|
||||
func Has(source string, keys ...string) (int64, error) {
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return 0, fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
exists, err := rdb.Exists(ctx, keys...).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// Set 设置缓存数据
|
||||
func Set(source, key string, value any) error {
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := rdb.Set(ctx, key, value, 0).Err()
|
||||
if err != nil {
|
||||
logger.Errorf("redis Set err %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetByExpire 设置缓存数据与过期时间
|
||||
func SetByExpire(source, key string, value any, expiration time.Duration) error {
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := rdb.Set(ctx, key, value, expiration).Err()
|
||||
if err != nil {
|
||||
logger.Errorf("redis SetByExpire err %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Del 删除单个
|
||||
func Del(source string, key string) error {
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := rdb.Del(ctx, key).Err(); err != nil {
|
||||
logger.Errorf("redis Del err %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DelKeys 删除多个
|
||||
func DelKeys(source string, keys []string) error {
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("no keys")
|
||||
}
|
||||
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
logger.Errorf("redis DelKeys err %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RateLimit 限流查询并记录
|
||||
func RateLimit(source, limitKey string, time, count int64) (int64, error) {
|
||||
// 数据源
|
||||
rdb := RDB(source)
|
||||
if rdb == nil {
|
||||
return 0, fmt.Errorf("redis not client")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := rateLimitCommand.Run(ctx, rdb, []string{limitKey}, time, count).Result()
|
||||
if err != nil {
|
||||
logger.Errorf("redis lua script err %v", err)
|
||||
return 0, err
|
||||
}
|
||||
return result.(int64), err
|
||||
}
|
||||
|
||||
// 声明定义限流脚本命令
|
||||
var rateLimitCommand = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local time = tonumber(ARGV[1])
|
||||
local count = tonumber(ARGV[2])
|
||||
local current = redis.call('get', key);
|
||||
if current and tonumber(current) >= count then
|
||||
return tonumber(current);
|
||||
end
|
||||
current = redis.call('incr', key)
|
||||
if tonumber(current) == 1 then
|
||||
redis.call('expire', key, time)
|
||||
end
|
||||
return tonumber(current);`)
|
||||
Reference in New Issue
Block a user