Files
be.ems/src/framework/middleware/rate_limit.go
2025-06-07 16:32:04 +08:00

109 lines
3.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"fmt"
"time"
"github.com/gin-gonic/gin"
"be.ems/src/framework/constants"
"be.ems/src/framework/database/redis"
"be.ems/src/framework/ip2region"
"be.ems/src/framework/reqctx"
"be.ems/src/framework/resp"
"be.ems/src/framework/utils/crypto"
)
const (
// LIMIT_GLOBAL 默认策略全局限流
LIMIT_GLOBAL = 1
// LIMIT_IP 根据请求者IP进行限流
LIMIT_IP = 2
// LIMIT_USER 根据用户ID进行限流
LIMIT_USER = 3
)
// LimitOption 请求限流参数
type LimitOption struct {
Time int64 `json:"time"` // 限流时间,单位秒 5
Count int64 `json:"count"` // 限流次数,单位次 10
Type int64 `json:"type"` // 限流条件类型,默认LIMIT_GLOBAL
}
// RateLimit 请求限流
//
// 示例参数middleware.LimitOption{ Time: 5, Count: 10, Type: middleware.LIMIT_IP }
//
// 参数表示5秒内最多请求10次限制类型为 IP
//
// 使用 USER 时,请在用户身份授权认证校验后使用
// 以便获取登录用户信息,无用户信息时默认为 LIMIT_GLOBAL
func RateLimit(option LimitOption) gin.HandlerFunc {
return func(c *gin.Context) {
// 初始可选参数数据
if option.Time < 5 {
option.Time = 5
}
if option.Count < 10 {
option.Count = 10
}
if option.Type == 0 {
option.Type = LIMIT_GLOBAL
}
// 获取执行函数名称
funcName := c.HandlerName()
// 生成限流key
limitKey := constants.CACHE_RATE_LIMIT + ":" + crypto.MD5(funcName)
// 用户
if option.Type == LIMIT_USER {
userId := reqctx.LoginUserToUserID(c)
if userId <= 0 {
c.JSON(401, resp.CodeMsg(resp.CODE_AUTH_INVALID, "invalid login user information"))
c.Abort() // 停止执行后续的处理函数
return
}
funcMd5 := crypto.MD5(fmt.Sprintf("%d:%s", userId, funcName))
limitKey = constants.CACHE_RATE_LIMIT + ":" + funcMd5
}
// IP
if option.Type == LIMIT_IP {
clientIP := ip2region.ClientIP(c.ClientIP())
funcMd5 := crypto.MD5(fmt.Sprintf("%s:%s", clientIP, funcName))
limitKey = constants.CACHE_RATE_LIMIT + ":" + funcMd5
}
// 在Redis查询并记录请求次数
rateCount, err := redis.RateLimit("", limitKey, option.Time, option.Count)
if err != nil {
c.JSON(200, resp.CodeMsg(resp.CODE_RATELIMIT, resp.MSG_RATELIMIT))
c.Abort() // 停止执行后续的处理函数
return
}
rateTime, err := redis.GetExpire("", limitKey)
if err != nil {
c.JSON(200, resp.CodeMsg(resp.CODE_RATELIMIT, resp.MSG_RATELIMIT))
c.Abort() // 停止执行后续的处理函数
return
}
// 设置响应头中的限流声明字段
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", option.Count)) // 总请求数限制
c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", option.Count-rateCount)) // 剩余可用请求数
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Unix()+rateTime)) // 重置时间戳
if rateCount >= option.Count {
c.JSON(200, resp.CodeMsg(resp.CODE_RATELIMIT, resp.MSG_RATELIMIT))
c.Abort() // 停止执行后续的处理函数
return
}
// 调用下一个处理程序
c.Next()
}
}