Skip to content

Instantly share code, notes, and snippets.

@nekomeowww
Last active April 21, 2022 09:55
Show Gist options
  • Save nekomeowww/e8244e624b60719ddd2d25297ca5813f to your computer and use it in GitHub Desktop.
Save nekomeowww/e8244e624b60719ddd2d25297ca5813f to your computer and use it in GitHub Desktop.
限流中间件
package middleware
type responseBuffer struct {
gin.ResponseWriter // the actual ResponseWriter to flush to
Body *bytes.Buffer // the response content body
Flushed bool
}
func (w *responseBuffer) Write(b []byte) (int, error) {
return w.Body.Write(b)
}
func couldCountRate(method, endpoint string, rate int, perDuration time.Duration, clientIP string) (int, bool) {
// 设定 Redis 存储的键名
keys := keys.RateLimitLock1.Format(method, endpoint, clientIP)
var err error
var countedRate int
// 获取已记录的请求频率
countedRate, err = database.Redis.Get(keys).Int()
// 如果遇到错误
if err != nil {
// 则初始化请求频率为 0
countedRate = 0
// 如果不是缓存不存在的错误,记录日志
if !database.IsRedisNil(err) {
logger.Error(err)
}
}
// 如果频率超过限制
if countedRate >= rate {
// 返回当前的请求频率和 false,表示限流
return countedRate, false
}
// 否则,增加请求频率计数
countedRate++
// 将请求频率计数写入 Redis
err = database.Redis.Set(keys, countedRate, perDuration).Err()
if err != nil {
// 如果遇到错误,记录日志
logger.Error(err)
}
return countedRate, true
}
// LimitRateFor 限流中间件
func LimitRateFor(method, endpoint string, rate int, perDuration time.Duration, shoulfCountFunc func(userID int64, isAborted bool) bool) func(c *gin.Context) {
return func(c *gin.Context) {
// 创建 hanlder 包定义的 Context 实例
ctx := &handler.Context{Context: c}
// 获取请求的 URL 和请求方法,并于限流的 URL 和请求方法进行比较
if ctx.Request.Method != method || (ctx.Request.URL != nil && ctx.Request.URL.String() != endpoint) {
return
}
// 定义 responseBuffer 实例
var bodyWriteBuffer *responseBuffer
// 获取原始的 ResponseWriter
originalWriter, ok := ctx.Writer.(gin.ResponseWriter)
if ok {
// 创建 responseBuffer 实例
bodyWriteBuffer = &responseBuffer{ResponseWriter: originalWriter, Body: &bytes.Buffer{}}
// 覆盖原始的 ResponseWriter
c.Writer = bodyWriteBuffer
// 等待后续中间件进行处理
ctx.Next()
} else {
// 如果不是 ResponseWriter 实例,则直接跳过,退出限流中间件
ctx.Next()
return
}
// 执行回调函数来判断是否需要限流
if !shoulfCountFunc(ctx.User().UserID, ctx.IsAborted()) {
bodyWriteBuffer.ResponseWriter.Write(bodyWriteBuffer.Body.Bytes())
return
}
// 是否能够继续进行计数,即是否允许进行请求
currentRate, ok := couldCountRate(method, endpoint, rate, perDuration, c.ClientIP())
// 可以的话直接进行请求写入
if ok {
bodyWriteBuffer.ResponseWriter.Write(bodyWriteBuffer.Body.Bytes())
return
}
// 否则返回限流错误信息
// 清空原有的 Body buffer
bodyWriteBuffer.Body.Reset()
// HTTP 报文码设定为 StatusTooManyRequests (429)
bodyWriteBuffer.WriteHeader(http.StatusTooManyRequests)
// 序列化限流错误信息
jsonData, _ := json.Marshal(handler.FinalResponse{
Code: apierror.CodeErrRequestRateLimitReached, // 错误码
Data: nil,
Message: apierror.ErrRequestRateLimitReached.FormatMessage(ctx.Language()), // 错误消息
})
// 写入限流错误信息到响应 Body
bodyWriteBuffer.ResponseWriter.Write(jsonData)
// 记录日志
logger.WithFields(logger.Fields{"endpoint": endpoint, "user_id": ctx.User().UserID, "client_ip": c.ClientIP(), "current_rate": currentRate, "rate": rate}).Warn("达到请求频率上限")
return
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment