feat: 优化UPF流量统计逻辑

This commit is contained in:
TsMask
2025-03-04 14:51:52 +08:00
parent 32630fbb4a
commit 986624c48f
7 changed files with 282 additions and 143 deletions

View File

@@ -0,0 +1,175 @@
package redis
import (
"context"
"errors"
"fmt"
"sync"
"time"
"be.ems/src/framework/logger"
"github.com/redis/go-redis/v9"
)
// 连接Redis实例
func ConnectPush(source string, rdb *redis.Client) {
if rdb == nil {
delete(rdbMap, source)
return
}
rdbMap[source] = rdb
}
// 批量获得缓存数据 [key]result
func GetHashBatch(source string, keys []string) (map[string]map[string]string, error) {
result := make(map[string]map[string]string, 0)
if len(keys) == 0 {
return result, fmt.Errorf("not keys")
}
// 数据源
rdb := RDB(source)
if rdb == nil {
return result, fmt.Errorf("redis not client")
}
// 创建一个有限的并发控制信号通道
sem := make(chan struct{}, 10)
var wg sync.WaitGroup
var mt sync.Mutex
batchSize := 1000
total := len(keys)
if total < batchSize {
batchSize = total
}
for i := 0; i < total; i += batchSize {
wg.Add(1)
go func(start int) {
ctx := context.Background()
// 并发控制,限制同时执行的 Goroutine 数量
sem <- struct{}{}
defer func() {
<-sem
ctx.Done()
wg.Done()
}()
// 检查索引是否越界
end := start + batchSize
if end > total {
end = total
}
pipe := rdb.Pipeline()
for _, key := range keys[start:end] {
pipe.HGetAll(ctx, key)
}
cmds, err := pipe.Exec(ctx)
if err != nil {
logger.Errorf("Failed to get hash batch exec err: %v", err)
return
}
// 将结果添加到 result map 并发访问
mt.Lock()
defer mt.Unlock()
// 处理命令结果
for _, cmd := range cmds {
if cmd.Err() != nil {
logger.Errorf("Failed to get hash batch cmds err: %v", cmd.Err())
continue
}
// 将结果转换为 *redis.StringStringMapCmd 类型
rcmd, ok := cmd.(*redis.MapStringStringCmd)
if !ok {
logger.Errorf("Failed to get hash batch type err: %v", cmd.Err())
continue
}
key := "-"
args := rcmd.Args()
if len(args) > 0 {
key = fmt.Sprint(args[1])
}
result[key] = rcmd.Val()
}
}(i)
}
wg.Wait()
return result, nil
}
// GetHash 获得缓存数据
func GetHash(source, key, field string) (string, error) {
// 数据源
rdb := RDB(source)
if rdb == nil {
return "", fmt.Errorf("redis not client")
}
ctx := context.Background()
v, err := rdb.HGet(ctx, key, field).Result()
if errors.Is(err, redis.Nil) {
return "", fmt.Errorf("no key field")
}
if err != nil {
return "", err
}
return v, nil
}
// SetHash 设置缓存数据
func SetHash(source, key string, value map[string]any) error {
// 数据源
rdb := RDB(source)
if rdb == nil {
return fmt.Errorf("redis not client")
}
ctx := context.Background()
err := rdb.HSet(ctx, key, value).Err()
if err != nil {
logger.Errorf("redis HSet err %v", err)
return err
}
return nil
}
// IncrBy 累加统计数据
func IncrBy(source, key, field string, value int64) error {
// 数据源
rdb := RDB(source)
if rdb == nil {
return fmt.Errorf("redis not client")
}
// 使用HINCRBY命令进行累加统计
ctx := context.Background()
err := rdb.HIncrBy(ctx, key, field, value).Err()
if err != nil {
logger.Errorf("redis HIncrBy err %v", err)
return err
}
return nil
}
// Expire 过期时间设置
func Expire(source, key string, expiration time.Duration) error {
// 数据源
rdb := RDB(source)
if rdb == nil {
return fmt.Errorf("redis not client")
}
// 过期时间设置
ctx := context.Background()
err := rdb.Expire(ctx, key, expiration).Err()
if err != nil {
logger.Errorf("redis HIncrBy err %v", err)
return err
}
return nil
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"strings"
"sync"
"time"
"be.ems/src/framework/config"
@@ -31,15 +30,6 @@ if tonumber(current) == 1 then
end
return tonumber(current);`)
// 连接Redis实例
func ConnectPush(source string, rdb *redis.Client) {
if rdb == nil {
delete(rdbMap, source)
return
}
rdbMap[source] = rdb
}
// 连接Redis实例
func Connect() {
ctx := context.Background()
@@ -237,105 +227,6 @@ func Get(source, key string) (string, error) {
return value, nil
}
// 获得缓存数据Hash
func GetHash(source, key string) (map[string]string, error) {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
value, err := rdb.HGetAll(ctx, key).Result()
if err == redis.Nil || err != nil {
return map[string]string{}, err
}
return value, nil
}
// 批量获得缓存数据 [key]result
func GetHashBatch(source string, keys []string) (map[string]map[string]string, error) {
result := make(map[string]map[string]string, 0)
if len(keys) == 0 {
return result, fmt.Errorf("not keys")
}
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
// 创建一个有限的并发控制信号通道
sem := make(chan struct{}, 10)
var wg sync.WaitGroup
var mt sync.Mutex
batchSize := 1000
total := len(keys)
if total < batchSize {
batchSize = total
}
for i := 0; i < total; i += batchSize {
wg.Add(1)
go func(start int) {
ctx := context.Background()
// 并发控制,限制同时执行的 Goroutine 数量
sem <- struct{}{}
defer func() {
<-sem
ctx.Done()
wg.Done()
}()
// 检查索引是否越界
end := start + batchSize
if end > total {
end = total
}
pipe := rdb.Pipeline()
for _, key := range keys[start:end] {
pipe.HGetAll(ctx, key)
}
cmds, err := pipe.Exec(ctx)
if err != nil {
logger.Errorf("Failed to get hash batch exec err: %v", err)
return
}
// 将结果添加到 result map 并发访问
mt.Lock()
defer mt.Unlock()
// 处理命令结果
for _, cmd := range cmds {
if cmd.Err() != nil {
logger.Errorf("Failed to get hash batch cmds err: %v", cmd.Err())
continue
}
// 将结果转换为 *redis.StringStringMapCmd 类型
rcmd, ok := cmd.(*redis.MapStringStringCmd)
if !ok {
logger.Errorf("Failed to get hash batch type err: %v", cmd.Err())
continue
}
key := "-"
args := rcmd.Args()
if len(args) > 0 {
key = fmt.Sprint(args[1])
}
result[key] = rcmd.Val()
}
}(i)
}
wg.Wait()
return result, nil
}
// 判断是否存在
func Has(source string, keys ...string) (bool, error) {
// 数据源