Created June 19, 2020 07:59
Save vardius/d06c255982a0f95261956d06a5774eb8 to your computer and use it in GitHub Desktop.
How to rate limit HTTP requests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"expvar"
	"fmt"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"golang.org/x/time/rate"
)
// m groups the application's global counters, which are published through
// the expvar package.
var m = struct {
	rl *expvar.Map // per-IP request counts, published under the name "rateLimits"
}{rl: expvar.NewMap("rateLimits")}
// IpAddress returns the client IP address of r.
//
// It prefers the X-Real-IP header, then the first (originating-client) entry
// of X-Forwarded-For, and finally falls back to r.RemoteAddr. Each source may
// be either a bare IP ("203.0.113.7", "2001:db8::1", "[2001:db8::1]") or a
// host:port pair ("203.0.113.7:8080", "[2001:db8::1]:443").
//
// It returns an error when the chosen value is not a valid IP address
// (including when it is empty).
//
// Unless you have a trusted reverse proxy in front of the server, you shouldn't
// use this function: the client can set these headers to any arbitrary value.
func IpAddress(r *http.Request) (net.IP, error) {
	addr := r.RemoteAddr
	if xReal := strings.TrimSpace(r.Header.Get("X-Real-Ip")); xReal != "" {
		addr = xReal
	} else if xForwarded := r.Header.Get("X-Forwarded-For"); xForwarded != "" {
		// The header is a comma-separated chain "client, proxy1, proxy2";
		// the originating client is the first entry.
		if i := strings.IndexByte(xForwarded, ','); i >= 0 {
			xForwarded = xForwarded[:i]
		}
		addr = strings.TrimSpace(xForwarded)
	}

	// RemoteAddr is host:port, but proxy headers usually carry a bare IP, so
	// only strip a port when one is actually present.
	host := addr
	if h, _, err := net.SplitHostPort(addr); err == nil {
		host = h
	} else {
		// Accept a bracketed IPv6 literal without a port, e.g. "[2001:db8::1]".
		host = strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]")
	}

	userIP := net.ParseIP(host)
	if userIP == nil {
		return nil, fmt.Errorf("ip: %q is not a valid IP address", host)
	}
	return userIP, nil
}
// RateLimit returns a new HTTP middleware that allows request per visitor (IP) | |
// up to rate r and permits bursts of at most b tokens. | |
func RateLimit(r rate.Limit, b int, frequency time.Duration) func(next http.Handler) http.Handler { | |
return func(next http.Handler) http.Handler { | |
if r == rate.Inf { | |
return next | |
} | |
rl := &rateLimiter{ | |
rate: r, | |
burst: b, | |
visitors: make(map[string]*visitor), | |
} | |
go rl.cleanup(frequency) | |
fn := func(w http.ResponseWriter, r *http.Request) { | |
ip, err := IpAddress(r) | |
if err != nil { | |
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) | |
return | |
} | |
if rl.allow(string(ip)) { | |
next.ServeHTTP(w, r) | |
} | |
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) | |
return | |
} | |
return http.HandlerFunc(fn) | |
} | |
} | |
type visitor struct { | |
*rate.Limiter | |
lastSeen time.Time | |
} | |
type rateLimiter struct { | |
sync.RWMutex | |
burst int | |
rate rate.Limit | |
visitors map[string]*visitor | |
} | |
// allow checks if given ip has not exceeded rate limit | |
func (l *rateLimiter) allow(ip string) bool { | |
l.RLock() | |
v, exists := l.visitors[ip] | |
l.RUnlock() | |
if !exists { | |
v = &visitor{ | |
Limiter: rate.NewLimiter(l.rate, l.burst), | |
} | |
l.Lock() | |
l.visitors[ip] = v | |
l.Unlock() | |
} | |
v.lastSeen = time.Now() | |
m.rl.Add(ip, 1) | |
return v.Allow() | |
} | |
// cleanup deletes old entries | |
func (l *rateLimiter) cleanup(frequency time.Duration) { | |
for { | |
time.Sleep(frequency) | |
l.Lock() | |
for ip, v := range l.visitors { | |
if time.Since(v.lastSeen) > frequency { | |
delete(l.visitors, ip) | |
m.rl.Delete(ip) | |
} | |
} | |
l.Unlock() | |
} | |
} | |
func TestRateLimit(t *testing.T) { | |
m := RateLimit(1, 1, time.Minute) | |
h := m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
w.WriteHeader(http.StatusOK) | |
})) | |
w := httptest.NewRecorder() | |
req, err := http.NewRequest("GET", "/x", nil) | |
if err != nil { | |
t.Fatal(err) | |
} | |
h.ServeHTTP(w, req) | |
if w.Code != http.StatusInternalServerError { | |
t.Errorf("Request rate limit: %d, expected %d", w.Code, http.StatusInternalServerError) | |
} | |
req.RemoteAddr = fmt.Sprintf("%s:%d", httptest.DefaultRemoteAddr, 8080) | |
w = httptest.NewRecorder() | |
h.ServeHTTP(w, req) | |
if w.Code != http.StatusOK { | |
t.Errorf("Request rate limit: %d, expected %d", w.Code, http.StatusOK) | |
} | |
w = httptest.NewRecorder() | |
h.ServeHTTP(w, req) | |
if w.Code != http.StatusTooManyRequests { | |
t.Errorf("Request rate limit: %d, expected %d", w.Code, http.StatusTooManyRequests) | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment