Skip to content

Instantly share code, notes, and snippets.

@trungdlp-wolffun
Created May 10, 2024 16:38
Show Gist options
  • Save trungdlp-wolffun/fed7b4d3ef1d67273193a02a34faaf10 to your computer and use it in GitHub Desktop.
Save trungdlp-wolffun/fed7b4d3ef1d67273193a02a34faaf10 to your computer and use it in GitHub Desktop.
package echo_limiter
import (
"errors"
"net/http"
"strconv"
"time"
"github.com/go-redis/redis_rate/v10"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/redis/go-redis/v9"
)
// Fallback values referenced by DefaultConfig.
const (
	// DefaultKeyPrefix is the value DefaultConfig assigns to Config.Prefix.
	DefaultKeyPrefix = "echo_limiter"

	// defaultStatusCode is the HTTP status DefaultConfig assigns to
	// Config.StatusCode (429 Too Many Requests).
	defaultStatusCode = http.StatusTooManyRequests

	// defaultMessage is the response body DefaultConfig assigns to
	// Config.Message.
	defaultMessage = "Too many requests, please try again later."
)
// DefaultConfig is the baseline configuration. New copies it and only fills
// in the Redis client; NewWithConfig reads individual fields from it to
// replace zero values in the caller's Config.
//
// Rediser, Algorithm, SkipOnError, Handler and ErrHandler are intentionally
// left at their zero values here.
var DefaultConfig = Config{
	Skipper:    middleware.DefaultSkipper,
	Prefix:     DefaultKeyPrefix,
	Period:     time.Minute,
	Max:        10,
	Burst:      10,
	StatusCode: defaultStatusCode,
	Message:    defaultMessage,
	// Requests are bucketed by the client IP as reported by Echo.
	Key: func(c echo.Context) string {
		return c.RealIP()
	},
}
// Config describes how the rate-limit middleware built by NewWithConfig
// behaves. Unless noted otherwise, a zero-valued field is replaced by the
// matching DefaultConfig field when the middleware is constructed.
type Config struct {
	// Skipper decides, per request, whether rate limiting is bypassed.
	// Default: middleware.DefaultSkipper.
	Skipper middleware.Skipper

	// Rediser is the Redis client handed to redis_rate.NewLimiter.
	// Required: NewWithConfig panics when it is nil.
	Rediser redis.UniversalClient

	// Max is used as the Rate of the redis_rate.Limit and is reported in the
	// X-RateLimit-Limit response header.
	// Default: 10.
	Max int

	// Burst is used as the Burst of the redis_rate.Limit.
	// Default: 10.
	Burst int

	// StatusCode is the HTTP status written by the default Handler.
	// Default: 429 Too Many Requests.
	StatusCode int

	// Message is the response body written by the default Handler.
	// Default: "Too many requests, please try again later."
	Message string

	// Algorithm selects the limiting algorithm.
	// NOTE(review): the original comment documented the default as
	// "sliding window"; confirm against redis_rate before relying on it.
	Algorithm uint

	// Prefix is the namespace for rate-limit keys.
	// Default: DefaultKeyPrefix.
	Prefix string

	// SkipOnError lets the request through to the next handler when the
	// limiter returns an error, instead of calling ErrHandler.
	// Default: false.
	SkipOnError bool

	// Period is used as the Period of the redis_rate.Limit.
	// Default: time.Minute.
	Period time.Duration

	// Key builds the rate-limit key for a request.
	// Default: func(ctx echo.Context) string {
	//   return ctx.RealIP()
	// }
	Key func(echo.Context) string

	// Handler is called when a request is rejected by the limiter.
	// Default: func(ctx echo.Context) error {
	//   return ctx.String(config.StatusCode, config.Message)
	// }
	Handler func(echo.Context) error

	// ErrHandler is called when the limiter itself returns an error and
	// SkipOnError is false.
	// Default: func(err error, ctx echo.Context) error {
	//   return echo.NewHTTPError(http.StatusInternalServerError, err)
	// }
	ErrHandler func(error, echo.Context) error
}
// New builds the rate-limit middleware from a copy of DefaultConfig, with
// rediser installed as the Redis client. It delegates to NewWithConfig, which
// panics if rediser is nil.
func New(rediser redis.UniversalClient) echo.MiddlewareFunc {
	cfg := DefaultConfig // struct copy; DefaultConfig itself is not modified
	cfg.Rediser = rediser

	return NewWithConfig(cfg)
}
// NewWithConfig returns an Echo middleware that rate-limits requests through
// redis_rate according to config.
//
// Zero-valued Skipper, Max, Burst, StatusCode, Message, Algorithm, Prefix,
// Period and Key fields are replaced by the matching DefaultConfig values.
// A nil Handler responds with config.StatusCode and config.Message; a nil
// ErrHandler returns a 500 echo.HTTPError wrapping the limiter error.
//
// It panics if config.Rediser is nil, since no limiter can be built without a
// Redis client.
func NewWithConfig(config Config) echo.MiddlewareFunc {
	if config.Rediser == nil {
		panic(errors.New("redis client is missing"))
	}
	if config.Skipper == nil {
		config.Skipper = DefaultConfig.Skipper
	}
	if config.Max == 0 {
		config.Max = DefaultConfig.Max
	}
	if config.Burst == 0 {
		config.Burst = DefaultConfig.Burst
	}
	if config.StatusCode == 0 {
		config.StatusCode = DefaultConfig.StatusCode
	}
	if config.Message == "" {
		config.Message = DefaultConfig.Message
	}
	// Algorithm is defaulted for symmetry with the other fields; nothing in
	// this function reads it afterwards.
	if config.Algorithm == 0 {
		config.Algorithm = DefaultConfig.Algorithm
	}
	if config.Prefix == "" {
		config.Prefix = DefaultConfig.Prefix
	}
	if config.Period == 0 {
		config.Period = DefaultConfig.Period
	}
	if config.Key == nil {
		config.Key = DefaultConfig.Key
	}
	if config.Handler == nil {
		config.Handler = func(ctx echo.Context) error {
			return ctx.String(config.StatusCode, config.Message)
		}
	}
	if config.ErrHandler == nil {
		config.ErrHandler = func(err error, ctx echo.Context) error {
			return echo.NewHTTPError(http.StatusInternalServerError, err)
		}
	}

	limiter := redis_rate.NewLimiter(config.Rediser)
	limit := redis_rate.Limit{
		Rate:   config.Max,
		Burst:  config.Burst,
		Period: config.Period,
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(ctx echo.Context) error {
			if config.Skipper(ctx) {
				return next(ctx)
			}

			// Namespace the key with the configured prefix so that several
			// limiters can share one Redis instance without colliding.
			key := config.Prefix + ":" + config.Key(ctx)

			result, err := limiter.Allow(ctx.Request().Context(), key, limit)
			if err != nil {
				ctx.Logger().Error(err)
				if config.SkipOnError {
					return next(ctx)
				}
				return config.ErrHandler(err, ctx)
			}

			res := ctx.Response()

			// The limiter granted no slot: reject the request.
			if result.Allowed == 0 {
				// Retry-After must be an HTTP-date or a delay in seconds
				// (RFC 7231, section 7.1.3). Send the delay, rounded up so a
				// client never retries before the limiter would allow it.
				retrySeconds := int64((result.RetryAfter + time.Second - 1) / time.Second)
				if retrySeconds < 0 {
					retrySeconds = 0
				}
				res.Header().Set("Retry-After", strconv.FormatInt(retrySeconds, 10))

				return config.Handler(ctx)
			}

			// The request may proceed; expose the current quota to the client.
			// X-RateLimit-Reset is the Unix time at which the quota resets.
			res.Header().Set("X-RateLimit-Limit", strconv.Itoa(config.Max))
			res.Header().Set("X-RateLimit-Remaining", strconv.FormatInt(int64(result.Remaining), 10))
			res.Header().Set("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(result.ResetAfter).Unix(), 10))

			return next(ctx)
		}
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment