How to rate limit HTTP requests
package main | |
import (
	"expvar"
	"fmt"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"golang.org/x/time/rate"
)
// m holds the application's global expvar counters.
//
// rl is published as the expvar map "rateLimits"; the rate limiter adds one
// to a visitor's entry on every request and removes the entry when the
// visitor is evicted.
var m = struct {
	rl *expvar.Map
}{rl: expvar.NewMap("rateLimits")}
// IpAddress returns the client IP address for r.
//
// It prefers the X-Real-IP header, then the left-most entry of the
// X-Forwarded-For list, and otherwise falls back to r.RemoteAddr. The chosen
// value may be either "host:port" or a bare IP (proxies normally send bare
// IPs); an IPv6 address that carries a port must be bracketed, e.g.
// "[2001:db8::1]:443".
//
// Unless you have a trusted reverse proxy, you shouldn't use this function:
// the client can set these headers to any arbitrary value it wants.
func IpAddress(r *http.Request) (net.IP, error) {
	addr := r.RemoteAddr
	if xReal := strings.TrimSpace(r.Header.Get("X-Real-Ip")); xReal != "" {
		addr = xReal
	} else if xForwarded := r.Header.Get("X-Forwarded-For"); xForwarded != "" {
		// The header is "client, proxy1, proxy2"; the left-most entry is the
		// originating client.
		addr = strings.TrimSpace(strings.SplitN(xForwarded, ",", 2)[0])
	}

	// RemoteAddr is "host:port" but forwarding headers usually hold a bare IP,
	// so strip a port only when one is present.
	host := addr
	if h, _, err := net.SplitHostPort(addr); err == nil {
		host = h
	}

	userIP := net.ParseIP(host)
	if userIP == nil {
		return nil, fmt.Errorf("ip: %q is not a valid IP address", host)
	}
	return userIP, nil
}
// RateLimit returns a new HTTP middleware that allows request per visitor (IP) | |
// up to rate r and permits bursts of at most b tokens. | |
func RateLimit(r rate.Limit, b int, frequency time.Duration) func(next http.Handler) http.Handler { | |
return func(next http.Handler) http.Handler { | |
if r == rate.Inf { | |
return next | |
} | |
rl := &rateLimiter{ | |
rate: r, | |
burst: b, | |
visitors: make(map[string]*visitor), | |
} | |
go rl.cleanup(frequency) | |
fn := func(w http.ResponseWriter, r *http.Request) { | |
ip, err := IpAddress(r) | |
if err != nil { | |
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) | |
return | |
} | |
if rl.allow(string(ip)) { | |
next.ServeHTTP(w, r) | |
} | |
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) | |
return | |
} | |
return http.HandlerFunc(fn) | |
} | |
} | |
type visitor struct { | |
*rate.Limiter | |
lastSeen time.Time | |
} | |
type rateLimiter struct { | |
sync.RWMutex | |
burst int | |
rate rate.Limit | |
visitors map[string]*visitor | |
} | |
// allow checks if given ip has not exceeded rate limit | |
func (l *rateLimiter) allow(ip string) bool { | |
l.RLock() | |
v, exists := l.visitors[ip] | |
l.RUnlock() | |
if !exists { | |
v = &visitor{ | |
Limiter: rate.NewLimiter(l.rate, l.burst), | |
} | |
l.Lock() | |
l.visitors[ip] = v | |
l.Unlock() | |
} | |
v.lastSeen = time.Now() | |
m.rl.Add(ip, 1) | |
return v.Allow() | |
} | |
// cleanup deletes old entries | |
func (l *rateLimiter) cleanup(frequency time.Duration) { | |
for { | |
time.Sleep(frequency) | |
l.Lock() | |
for ip, v := range l.visitors { | |
if time.Since(v.lastSeen) > frequency { | |
delete(l.visitors, ip) | |
m.rl.Delete(ip) | |
} | |
} | |
l.Unlock() | |
} | |
} | |
func TestRateLimit(t *testing.T) { | |
m := RateLimit(1, 1, time.Minute) | |
h := m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
w.WriteHeader(http.StatusOK) | |
})) | |
w := httptest.NewRecorder() | |
req, err := http.NewRequest("GET", "/x", nil) | |
if err != nil { | |
t.Fatal(err) | |
} | |
h.ServeHTTP(w, req) | |
if w.Code != http.StatusInternalServerError { | |
t.Errorf("Request rate limit: %d, expected %d", w.Code, http.StatusInternalServerError) | |
} | |
req.RemoteAddr = fmt.Sprintf("%s:%d", httptest.DefaultRemoteAddr, 8080) | |
w = httptest.NewRecorder() | |
h.ServeHTTP(w, req) | |
if w.Code != http.StatusOK { | |
t.Errorf("Request rate limit: %d, expected %d", w.Code, http.StatusOK) | |
} | |
w = httptest.NewRecorder() | |
h.ServeHTTP(w, req) | |
if w.Code != http.StatusTooManyRequests { | |
t.Errorf("Request rate limit: %d, expected %d", w.Code, http.StatusTooManyRequests) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment