Skip to content

Instantly share code, notes, and snippets.

@vardius
Created June 19, 2020 07:59
Show Gist options
  • Save vardius/d06c255982a0f95261956d06a5774eb8 to your computer and use it in GitHub Desktop.
Save vardius/d06c255982a0f95261956d06a5774eb8 to your computer and use it in GitHub Desktop.
How to rate limit HTTP requests
package main
import (
	"expvar"
	"fmt"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"golang.org/x/time/rate"
)
// m groups the package-wide expvar counters. The rl map is published under
// the expvar name "rateLimits" and is keyed by visitor.
var m = struct{ rl *expvar.Map }{rl: expvar.NewMap("rateLimits")}
// IpAddress returns the client IP address for the request.
//
// The address is taken from the X-Real-Ip header when present, otherwise from
// the first entry of the X-Forwarded-For header, otherwise from r.RemoteAddr.
// The selected value may be either a bare IP ("203.0.113.7", "2001:db8::1"),
// which is the form proxies normally put in these headers, or an "IP:port"
// pair ("203.0.113.7:4711", "[::1]:9000"), which is the form of RemoteAddr.
//
// Unless the server sits behind a trusted reverse proxy you should not rely
// on this function: a client can set both headers to any value it wants.
//
// An error is returned when the selected value is neither a valid IP address
// nor a valid "IP:port" pair.
func IpAddress(r *http.Request) (net.IP, error) {
	addr := r.RemoteAddr
	if xReal := r.Header.Get("X-Real-Ip"); xReal != "" {
		addr = xReal
	} else if xForwarded := r.Header.Get("X-Forwarded-For"); xForwarded != "" {
		// X-Forwarded-For may be a list ("client, proxy1, proxy2"); only the
		// first entry is used.
		if i := strings.IndexByte(xForwarded, ','); i >= 0 {
			xForwarded = xForwarded[:i]
		}
		addr = xForwarded
	}
	addr = strings.TrimSpace(addr)

	// Header values usually carry no port, so try the bare-IP form first;
	// SplitHostPort would reject it with "missing port in address".
	if ip := net.ParseIP(addr); ip != nil {
		return ip, nil
	}

	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, fmt.Errorf("addr: %q is not IP:port: %w", addr, err)
	}
	userIP := net.ParseIP(host)
	if userIP == nil {
		return nil, fmt.Errorf("ip: %q is not a valid IP address", host)
	}
	return userIP, nil
}
// RateLimit returns HTTP middleware that limits requests per visitor, where a
// visitor is the client IP resolved by IpAddress. Each visitor may send
// requests up to rate r, with bursts of at most b tokens.
//
// frequency is passed to the limiter's cleanup loop: it is both the interval
// between cleanup passes and the idle time after which a visitor is dropped.
//
// When r is rate.Inf the middleware is a no-op and returns next unchanged.
// Otherwise the wrapped handler responds with:
//   - 500 Internal Server Error when the client IP cannot be determined,
//   - 429 Too Many Requests when the visitor is over its limit,
//   - whatever next produces in all other cases.
//
// NOTE: every application of the returned middleware to a handler creates its
// own visitor table and starts a cleanup goroutine that is never stopped.
func RateLimit(r rate.Limit, b int, frequency time.Duration) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		if r == rate.Inf {
			return next
		}

		rl := &rateLimiter{
			rate:     r,
			burst:    b,
			visitors: make(map[string]*visitor),
		}
		go rl.cleanup(frequency)

		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			ip, err := IpAddress(req)
			if err != nil {
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}

			// ip.String() yields the textual address, so map and expvar keys
			// are readable rather than the raw bytes of the net.IP slice.
			if !rl.allow(ip.String()) {
				http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
				return
			}

			// Only allowed requests reach the wrapped handler; nothing is
			// written to w after it returns.
			next.ServeHTTP(w, req)
		})
	}
}
// visitor is the per-client state stored in rateLimiter.visitors.
type visitor struct {
// Embedded limiter for this client; its methods (such as Allow) are
// promoted onto visitor.
*rate.Limiter
// lastSeen is set by rateLimiter.allow on every call for this client and is
// compared against the cleanup frequency to drop idle entries.
lastSeen time.Time
}
// rateLimiter keeps one visitor (and therefore one limiter) per client key.
type rateLimiter struct {
// Embedded lock taken around accesses to the visitors map.
sync.RWMutex
// burst is the burst size passed to rate.NewLimiter for each new visitor.
burst int
// rate is the limit passed to rate.NewLimiter for each new visitor.
rate rate.Limit
// visitors maps the ip string given to allow to that client's state.
visitors map[string]*visitor
}
// allow reports whether the visitor identified by ip may proceed, i.e. has
// not exceeded its rate limit. A visitor seen for the first time gets a fresh
// limiter built from l.rate and l.burst.
//
// Every call refreshes the visitor's lastSeen timestamp and increments the
// visitor's counter in the "rateLimits" expvar map, whether or not the
// request ends up being allowed.
func (l *rateLimiter) allow(ip string) bool {
	// Lookup, insertion and the lastSeen update happen under one write lock:
	//   - two concurrent first requests from the same ip cannot both create a
	//     limiter and overwrite each other's entry;
	//   - the lastSeen write cannot race with the read in cleanup, which
	//     holds the same lock.
	l.Lock()
	v, exists := l.visitors[ip]
	if !exists {
		v = &visitor{
			Limiter: rate.NewLimiter(l.rate, l.burst),
		}
		l.visitors[ip] = v
	}
	v.lastSeen = time.Now()
	l.Unlock()

	m.rl.Add(ip, 1)
	return v.Allow()
}
// cleanup periodically evicts idle visitors. It loops forever and never
// returns, so it is meant to be run in its own goroutine.
//
// On each pass it sleeps for frequency, then removes every visitor whose
// lastSeen is more than frequency in the past, together with that visitor's
// counter in the "rateLimits" expvar map.
func (l *rateLimiter) cleanup(frequency time.Duration) {
	// evictStale performs a single pass over the table under the write lock.
	evictStale := func() {
		l.Lock()
		defer l.Unlock()

		for addr, entry := range l.visitors {
			if time.Since(entry.lastSeen) <= frequency {
				continue
			}
			delete(l.visitors, addr)
			m.rl.Delete(addr)
		}
	}

	for {
		time.Sleep(frequency)
		evictStale()
	}
}
// TestRateLimit drives one handler wrapped with a 1 req/s, burst 1 limiter
// through three requests:
//  1. no RemoteAddr set, so the client IP cannot be resolved: expect 500;
//  2. first request with a valid RemoteAddr: expect 200;
//  3. immediate second request from the same address: expect 429.
func TestRateLimit(t *testing.T) {
	limited := RateLimit(1, 1, time.Minute)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req, err := http.NewRequest("GET", "/x", nil)
	if err != nil {
		t.Fatal(err)
	}

	// expect serves req once against a fresh recorder and checks the status.
	expect := func(want int) {
		t.Helper()
		rec := httptest.NewRecorder()
		limited.ServeHTTP(rec, req)
		if rec.Code != want {
			t.Errorf("Request rate limit: %d, expected %d", rec.Code, want)
		}
	}

	expect(http.StatusInternalServerError)

	req.RemoteAddr = fmt.Sprintf("%s:%d", httptest.DefaultRemoteAddr, 8080)
	expect(http.StatusOK)
	expect(http.StatusTooManyRequests)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment