358 lines
7.6 KiB
Go
358 lines
7.6 KiB
Go
package redis
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
|
||
"ems.agt/lib/core/conf"
|
||
"ems.agt/lib/log"
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
// Redis连接实例
|
||
var rdbMap = make(map[string]*redis.Client)
|
||
|
||
// rateLimitCommand is the Lua script used for rate limiting.
//
// KEYS[1] is the counter key, ARGV[1] the expiry passed to EXPIRE and
// ARGV[2] the maximum count. When the stored counter has already reached the
// maximum, the script returns it unchanged. Otherwise it increments the
// counter, sets the expiry only when the increment produced 1 (i.e. the key
// was just created), and returns the new value.
var rateLimitCommand = redis.NewScript(`
local key = KEYS[1]
local time = tonumber(ARGV[1])
local count = tonumber(ARGV[2])
local current = redis.call('get', key);
if current and tonumber(current) >= count then
return tonumber(current);
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
redis.call('expire', key, time)
end
return tonumber(current);`)
|
||
|
||
// 连接Redis实例
|
||
func Connect() {
|
||
ctx := context.Background()
|
||
// 读取数据源配置
|
||
datasource := conf.Get("redis.dataSource").(map[string]any)
|
||
for k, v := range datasource {
|
||
client := v.(map[string]any)
|
||
// 创建连接
|
||
address := fmt.Sprintf("%s:%d", client["host"], client["port"])
|
||
rdb := redis.NewClient(&redis.Options{
|
||
Addr: address,
|
||
Password: client["password"].(string),
|
||
DB: client["db"].(int),
|
||
})
|
||
// 测试数据库连接
|
||
pong, err := rdb.Ping(ctx).Result()
|
||
if err != nil {
|
||
log.Fatalf("failed error ping redis %s is %v", client["host"], err)
|
||
}
|
||
log.Infof("redis %s %s connection is successful.", client["host"], pong)
|
||
rdbMap[k] = rdb
|
||
}
|
||
}
|
||
|
||
// 关闭Redis实例
|
||
func Close() {
|
||
for _, rdb := range rdbMap {
|
||
if err := rdb.Close(); err != nil {
|
||
log.Errorf("fatal error db close: %s", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 获取默认实例
|
||
func DefaultRDB() *redis.Client {
|
||
source := conf.Get("redis.defaultDataSourceName").(string)
|
||
return rdbMap[source]
|
||
}
|
||
|
||
// 获取实例
|
||
func RDB(source string) *redis.Client {
|
||
return rdbMap[source]
|
||
}
|
||
|
||
// Info 获取redis服务信息
|
||
func Info(source string) map[string]map[string]string {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
info, err := rdb.Info(ctx).Result()
|
||
if err != nil {
|
||
return map[string]map[string]string{}
|
||
}
|
||
infoObj := make(map[string]map[string]string)
|
||
lines := strings.Split(info, "\r\n")
|
||
label := ""
|
||
for _, line := range lines {
|
||
if strings.Contains(line, "#") {
|
||
label = strings.Fields(line)[len(strings.Fields(line))-1]
|
||
label = strings.ToLower(label)
|
||
infoObj[label] = make(map[string]string)
|
||
continue
|
||
}
|
||
kvArr := strings.Split(line, ":")
|
||
if len(kvArr) >= 2 {
|
||
key := strings.TrimSpace(kvArr[0])
|
||
value := strings.TrimSpace(kvArr[len(kvArr)-1])
|
||
infoObj[label][key] = value
|
||
}
|
||
}
|
||
return infoObj
|
||
}
|
||
|
||
// KeySize 获取redis当前连接可用键Key总数信息
|
||
func KeySize(source string) int64 {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
size, err := rdb.DBSize(ctx).Result()
|
||
if err != nil {
|
||
return 0
|
||
}
|
||
return size
|
||
}
|
||
|
||
// CommandStats 获取redis命令状态信息
|
||
func CommandStats(source string) []map[string]string {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
commandstats, err := rdb.Info(ctx, "commandstats").Result()
|
||
if err != nil {
|
||
return []map[string]string{}
|
||
}
|
||
statsObjArr := make([]map[string]string, 0)
|
||
lines := strings.Split(commandstats, "\r\n")
|
||
for _, line := range lines {
|
||
if !strings.HasPrefix(line, "cmdstat_") {
|
||
continue
|
||
}
|
||
kvArr := strings.Split(line, ":")
|
||
key := kvArr[0]
|
||
valueStr := kvArr[len(kvArr)-1]
|
||
statsObj := make(map[string]string)
|
||
statsObj["name"] = key[8:]
|
||
statsObj["value"] = valueStr[6:strings.Index(valueStr, ",usec=")]
|
||
statsObjArr = append(statsObjArr, statsObj)
|
||
}
|
||
return statsObjArr
|
||
}
|
||
|
||
// 获取键的剩余有效时间(秒)
|
||
func GetExpire(source string, key string) (float64, error) {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
ttl, err := rdb.TTL(ctx, key).Result()
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return ttl.Seconds(), nil
|
||
}
|
||
|
||
// 获得缓存数据的key列表
|
||
func GetKeys(source string, pattern string) ([]string, error) {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
// 初始化变量
|
||
var keys []string
|
||
var cursor uint64 = 0
|
||
ctx := context.Background()
|
||
// 循环遍历获取匹配的键
|
||
for {
|
||
// 使用 SCAN 命令获取匹配的键
|
||
batchKeys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
|
||
if err != nil {
|
||
log.Errorf("Failed to scan keys: %v", err)
|
||
return keys, err
|
||
}
|
||
cursor = nextCursor
|
||
keys = append(keys, batchKeys...)
|
||
// 当 cursor 为 0,表示遍历完成
|
||
if cursor == 0 {
|
||
break
|
||
}
|
||
}
|
||
return keys, nil
|
||
}
|
||
|
||
// 批量获得缓存数据
|
||
func GetBatch(source string, keys []string) ([]any, error) {
|
||
if len(keys) == 0 {
|
||
return []any{}, fmt.Errorf("not keys")
|
||
}
|
||
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
// 获取缓存数据
|
||
result, err := rdb.MGet(context.Background(), keys...).Result()
|
||
if err != nil {
|
||
log.Errorf("Failed to get batch data: %v", err)
|
||
return []any{}, err
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// 获得缓存数据
|
||
func Get(source, key string) (string, error) {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
value, err := rdb.Get(ctx, key).Result()
|
||
if err == redis.Nil || err != nil {
|
||
return "", err
|
||
}
|
||
return value, nil
|
||
}
|
||
|
||
// 获得缓存数据Hash
|
||
func GetHash(source, key string) (map[string]string, error) {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
value, err := rdb.HGetAll(ctx, key).Result()
|
||
if err == redis.Nil || err != nil {
|
||
return map[string]string{}, err
|
||
}
|
||
return value, nil
|
||
}
|
||
|
||
// 判断是否存在
|
||
func Has(source string, keys ...string) (bool, error) {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
exists, err := rdb.Exists(ctx, keys...).Result()
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return exists >= 1, nil
|
||
}
|
||
|
||
// 设置缓存数据
|
||
func Set(source, key string, value any) (bool, error) {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
err := rdb.Set(ctx, key, value, 0).Err()
|
||
if err != nil {
|
||
log.Errorf("redis lua script err %v", err)
|
||
return false, err
|
||
}
|
||
return true, nil
|
||
}
|
||
|
||
// 设置缓存数据与过期时间
|
||
func SetByExpire(source, key string, value any, expiration time.Duration) (bool, error) {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
err := rdb.Set(ctx, key, value, expiration).Err()
|
||
if err != nil {
|
||
log.Errorf("redis lua script err %v", err)
|
||
return false, err
|
||
}
|
||
return true, nil
|
||
}
|
||
|
||
// 删除单个
|
||
func Del(source string, key string) (bool, error) {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
err := rdb.Del(ctx, key).Err()
|
||
if err != nil {
|
||
log.Errorf("redis lua script err %v", err)
|
||
return false, err
|
||
}
|
||
return true, nil
|
||
}
|
||
|
||
// 删除多个
|
||
func DelKeys(source string, keys []string) (bool, error) {
|
||
if len(keys) == 0 {
|
||
return false, fmt.Errorf("no keys")
|
||
}
|
||
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
err := rdb.Del(ctx, keys...).Err()
|
||
if err != nil {
|
||
log.Errorf("redis lua script err %v", err)
|
||
return false, err
|
||
}
|
||
return true, nil
|
||
}
|
||
|
||
// 限流查询并记录
|
||
func RateLimit(source, limitKey string, time, count int64) (int64, error) {
|
||
// 数据源
|
||
rdb := DefaultRDB()
|
||
if source != "" {
|
||
rdb = RDB(source)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
result, err := rateLimitCommand.Run(ctx, rdb, []string{limitKey}, time, count).Result()
|
||
if err != nil {
|
||
log.Errorf("redis lua script err %v", err)
|
||
return 0, err
|
||
}
|
||
return result.(int64), err
|
||
}
|