Merge remote-tracking branch 'origin/main' into multi-tenant

This commit is contained in:
TsMask
2024-12-24 17:24:24 +08:00
119 changed files with 3113 additions and 1328 deletions

View File

@@ -142,7 +142,7 @@ func RunTime() time.Time {
// Get 获取配置信息
//
// Get("framework.name")
// Get("server.port")
func Get(key string) any {
return viper.Get(key)
}

View File

@@ -1,8 +1,3 @@
# 项目信息
framework:
name: "OMC"
version: "2.2412.2"
# 应用服务配置
server:
# 服务端口
@@ -181,6 +176,10 @@ aes:
# 用户配置
user:
# 登录认证,默认打开
loginAuth: true
# 接口加密,默认打开
cryptoApi: true
# 密码
password:
# 密码最大错误次数

View File

@@ -9,6 +9,7 @@ import (
"be.ems/src/framework/config"
"be.ems/src/framework/logger"
"be.ems/src/framework/utils/parse"
"gorm.io/driver/mysql"
"gorm.io/gorm"
@@ -117,6 +118,9 @@ func DefaultDB() *gorm.DB {
// 获取数据源
func DB(source string) *gorm.DB {
if source == "" {
source = config.Get("gorm.defaultDataSourceName").(string)
}
return dbMap[source]
}
@@ -159,3 +163,19 @@ func ExecDB(source string, sql string, parameters []any) (int64, error) {
}
return res.RowsAffected, nil
}
// PageNumSize 分页页码记录数
func PageNumSize(pageNum, pageSize any) (int, int) {
// 记录起始索引
num := parse.Number(pageNum)
if num < 1 {
num = 1
}
// 显示记录数
size := parse.Number(pageSize)
if size < 0 {
size = 10
}
return int(num - 1), int(size)
}

View File

@@ -24,6 +24,16 @@ import (
// 请将中间件放在最前置,对请求优先处理
func CryptoApi(requestDecrypt, responseEncrypt bool) gin.HandlerFunc {
return func(c *gin.Context) {
// 登录认证,默认打开
enable := true
if v := config.Get("user.cryptoApi"); v != nil && enable {
enable = v.(bool)
}
if !enable {
c.Next()
return
}
// 请求解密时对请求data注入
if requestDecrypt {
method := c.Request.Method

View File

@@ -3,6 +3,7 @@ package middleware
import (
"strings"
"be.ems/src/framework/config"
AdminConstants "be.ems/src/framework/constants/admin"
commonConstants "be.ems/src/framework/constants/common"
"be.ems/src/framework/i18n"
@@ -36,6 +37,22 @@ var URL_WHITE_LIST = []string{
// 同时匹配其中权限 "matchPerms": {"xxx"},
func PreAuthorize(options map[string][]string) gin.HandlerFunc {
return func(c *gin.Context) {
// 登录认证,默认打开
enable := true
if v := config.Get("user.loginAuth"); v != nil {
enable = v.(bool)
}
if !enable {
loginUser, _ := ctxUtils.LoginUser(c)
loginUser.UserID = "2"
loginUser.User.UserID = "2"
loginUser.User.UserName = "admin"
loginUser.User.NickName = "admin"
c.Set(commonConstants.CTX_LOGIN_USER, loginUser)
c.Next()
return
}
language := ctxUtils.AcceptLanguage(c)
requestURI := c.Request.RequestURI

View File

@@ -59,3 +59,21 @@ func (c *ConnRedis) Close() {
c.Client.Close()
}
}
// RunCMD 执行单次命令 "GET key"
func (c *ConnRedis) RunCMD(cmd string) (any, error) {
if c.Client == nil {
return "", fmt.Errorf("redis client not connected")
}
// 写入命令
cmdArr := strings.Fields(cmd)
if len(cmdArr) == 0 {
return "", fmt.Errorf("redis command is empty")
}
conn := *c.Client
args := make([]any, 0)
for _, v := range cmdArr {
args = append(args, v)
}
return conn.Do(context.Background(), args...).Result()
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strings"
"sync"
"time"
"be.ems/src/framework/config"
@@ -179,31 +180,22 @@ func GetExpire(source string, key string) (float64, error) {
}
// 获得缓存数据的key列表
func GetKeys(source string, pattern string) ([]string, error) {
func GetKeys(source string, match string) ([]string, error) {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
// 初始化变量
var keys []string
var cursor uint64 = 0
keys := make([]string, 0)
ctx := context.Background()
// 循环遍历获取匹配的键
for {
// 使用 SCAN 命令获取匹配的键
batchKeys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 1000).Result()
if err != nil {
logger.Errorf("Failed to scan keys: %v", err)
return keys, err
}
cursor = nextCursor
keys = append(keys, batchKeys...)
// 当 cursor 为 0表示遍历完成
if cursor == 0 {
break
}
iter := rdb.Scan(ctx, 0, match, 1000).Iterator()
if err := iter.Err(); err != nil {
logger.Errorf("Failed to scan keys: %v", err)
return keys, err
}
for iter.Next(ctx) {
keys = append(keys, iter.Val())
}
return keys, nil
}
@@ -261,6 +253,89 @@ func GetHash(source, key string) (map[string]string, error) {
return value, nil
}
// 批量获得缓存数据 [key]result
func GetHashBatch(source string, keys []string) (map[string]map[string]string, error) {
result := make(map[string]map[string]string, 0)
if len(keys) == 0 {
return result, fmt.Errorf("not keys")
}
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
// 创建一个有限的并发控制信号通道
sem := make(chan struct{}, 10)
var wg sync.WaitGroup
var mt sync.Mutex
batchSize := 1000
total := len(keys)
if total < batchSize {
batchSize = total
}
for i := 0; i < total; i += batchSize {
wg.Add(1)
go func(start int) {
ctx := context.Background()
// 并发控制,限制同时执行的 Goroutine 数量
sem <- struct{}{}
defer func() {
<-sem
ctx.Done()
wg.Done()
}()
// 检查索引是否越界
end := start + batchSize
if end > total {
end = total
}
pipe := rdb.Pipeline()
for _, key := range keys[start:end] {
pipe.HGetAll(ctx, key)
}
cmds, err := pipe.Exec(ctx)
if err != nil {
logger.Errorf("Failed to get hash batch exec err: %v", err)
return
}
// 将结果添加到 result map 并发访问
mt.Lock()
defer mt.Unlock()
// 处理命令结果
for _, cmd := range cmds {
if cmd.Err() != nil {
logger.Errorf("Failed to get hash batch cmds err: %v", cmd.Err())
continue
}
// 将结果转换为 *redis.StringStringMapCmd 类型
rcmd, ok := cmd.(*redis.MapStringStringCmd)
if !ok {
logger.Errorf("Failed to get hash batch type err: %v", cmd.Err())
continue
}
key := "-"
args := rcmd.Args()
if len(args) > 0 {
key = fmt.Sprint(args[1])
}
result[key] = rcmd.Val()
}
}(i)
}
wg.Wait()
return result, nil
}
// 判断是否存在
func Has(source string, keys ...string) (bool, error) {
// 数据源

View File

@@ -16,6 +16,16 @@ import (
"github.com/gin-gonic/gin"
)
// QueryMapString 查询参数转换Map
func QueryMapString(c *gin.Context) map[string]string {
queryValues := c.Request.URL.Query()
queryParams := make(map[string]string)
for key, values := range queryValues {
queryParams[key] = values[0]
}
return queryParams
}
// QueryMap 查询参数转换Map
func QueryMap(c *gin.Context) map[string]any {
queryValues := c.Request.URL.Query()

View File

@@ -12,8 +12,13 @@ import (
"os"
"strings"
"time"
libGlobal "be.ems/lib/global"
)
// userAgent 自定义 User-Agent
var userAgent = fmt.Sprintf("OMC/%s", libGlobal.Version)
// Get 发送 GET 请求
// timeout 超时时间(毫秒)
func Get(url string, headers map[string]string, timeout int) ([]byte, error) {
@@ -29,6 +34,8 @@ func Get(url string, headers map[string]string, timeout int) ([]byte, error) {
return nil, err
}
req.Header.Set("User-Agent", userAgent)
req.Header.Set("Content-Type", "application/json;charset=UTF-8")
for key, value := range headers {
req.Header.Set(key, value)
}
@@ -60,8 +67,8 @@ func Post(url string, data url.Values, headers map[string]string) ([]byte, error
return nil, err
}
req.Header.Set("User-Agent", userAgent)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
for key, value := range headers {
req.Header.Set(key, value)
}
@@ -100,8 +107,8 @@ func PostJSON(url string, data any, headers map[string]string) ([]byte, error) {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
req.Header.Set("Content-Type", "application/json;charset=UTF-8")
for key, value := range headers {
req.Header.Set(key, value)
}
@@ -156,6 +163,7 @@ func PostUploadFile(url string, params map[string]string, file *os.File) ([]byte
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
}
req.Header.Set("User-Agent", userAgent)
req.Header.Set("Content-Type", writer.FormDataContentType())
client := &http.Client{}
@@ -193,6 +201,8 @@ func PutJSON(url string, data any, headers map[string]string) ([]byte, error) {
return nil, err
}
req.Header.Set("User-Agent", userAgent)
req.Header.Set("Content-Type", "application/json;charset=UTF-8")
for key, value := range headers {
req.Header.Set(key, value)
}
@@ -224,6 +234,8 @@ func Delete(url string, headers map[string]string) ([]byte, error) {
return nil, err
}
req.Header.Set("User-Agent", userAgent)
req.Header.Set("Content-Type", "application/json;charset=UTF-8")
for key, value := range headers {
req.Header.Set(key, value)
}