HTTP token bucket rate limiting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"log" | |
"net/http" | |
"sync" | |
"time" | |
"github.com/chain-engineering/papi/Godeps/_workspace/src/golang.org/x/net/context" | |
"github.com/chain-engineering/papi/api" | |
) | |
// bucketSize is the token-bucket capacity: the largest burst one rate-limit
// key may send before being throttled. Counters created for a new key start
// with this many tokens.
const bucketSize = 100

// reqPerSec is the refill rate: tokens added per elapsed second.
const reqPerSec = 50
// limCounter holds the token-bucket state for a single rate-limit key.
// n and t are guarded by m, so a limCounter must not be copied once in use.
//
//   - n: tokens currently available.
//   - t: Unix time (seconds) at which n was last brought up to date by try.
type limCounter struct {
	m    sync.Mutex
	n, t int64
}
func (lc *limCounter) try(t int64) (ok bool) { | |
lc.m.Lock() | |
defer lc.m.Unlock() | |
d := t - lc.t | |
if d < 0 { | |
// Taking the abs value of d | |
// is motivated by paranoia that | |
// t might not always be monotonically increasing. | |
d = -d | |
} | |
lc.n += reqPerSec * d | |
if lc.n > bucketSize { | |
lc.n = bucketSize | |
} | |
lc.t = t | |
if lc.n < 1 { | |
return false | |
} | |
lc.n-- | |
return true | |
} | |
// limitKeyIP keys rate limiting by client address. It returns the
// Chain-Forwarded-For header if set, otherwise X-Forwarded-For.
// If neither header is present the key is "", so all such requests share
// one bucket.
//
// NOTE(review): both headers can be set by the client unless a trusted proxy
// overwrites them; confirm the edge strips or rewrites them.
func limitKeyIP(r *http.Request) string {
	for _, name := range [...]string{"Chain-Forwarded-For", "X-Forwarded-For"} {
		if v := r.Header.Get(name); v != "" {
			return v
		}
	}
	return ""
}
func limitKeyAuth(r *http.Request) string { | |
id, secret := getAuthToken(r) | |
return id + ":" + secret | |
} | |
var ( | |
limCtrsMu sync.Mutex //protects the following: | |
limCtrs = map[string]*limCounter{} | |
) | |
func init() { | |
go func() { | |
for range time.Tick(time.Hour) { | |
limCtrsMu.Lock() | |
for k, ctr := range limCtrs { | |
ctr.m.Lock() | |
n := ctr.n | |
ctr.m.Unlock() | |
if n > bucketSize { | |
delete(limCtrs, k) | |
} | |
} | |
limCtrsMu.Unlock() | |
} | |
}() | |
} | |
type limitHandler struct { | |
h handler | |
f func(r *http.Request) string | |
} | |
func (l limitHandler) ServeHTTPContext(ctx context.Context, w http.ResponseWriter, r *http.Request) { | |
k := l.f(r) | |
limCtrsMu.Lock() | |
ctr, ok := limCtrs[k] | |
if !ok { | |
ctr = &limCounter{n: bucketSize, t: time.Now().Unix()} | |
limCtrs[k] = ctr | |
} | |
limCtrsMu.Unlock() | |
ok = ctr.try(time.Now().Unix()) | |
if !ok { | |
nRateLimit.Add() | |
ip := r.Header.Get("X-Forwarded-For") | |
user, _ := getAuthToken(r) | |
log.Println("rate-limit:", ip, user) | |
httpError(ctx, w, api.ErrRateLimit) | |
return | |
} | |
l.h.ServeHTTPContext(ctx, w, r) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"net/http" | |
"net/http/httptest" | |
"testing" | |
"time" | |
"github.com/chain-engineering/papi/Godeps/_workspace/src/golang.org/x/net/context" | |
) | |
func TestLimitCounter(t *testing.T) { | |
t0 := time.Now().Unix() | |
t1 := t0 + 1 | |
t2 := t0 + 60 | |
cases := []struct { | |
ctr *limCounter | |
t int64 | |
n int64 | |
want bool | |
}{ | |
{&limCounter{n: bucketSize, t: t0}, t1, 99, true}, | |
{&limCounter{n: 50, t: t0}, t0, 49, true}, | |
{&limCounter{n: 50, t: t0}, t1, 99, true}, | |
{&limCounter{n: 0, t: t0}, t0, 0, false}, | |
{&limCounter{n: 0, t: t0}, t2, 99, true}, | |
} | |
for _, test := range cases { | |
if g := test.ctr.try(test.t); g != test.want { | |
t.Errorf("%v.try(%v) = %v want %v", test.ctr, test.t, g, test.want) | |
} | |
if test.n != test.ctr.n { | |
t.Errorf("ctr.n = %v want %v", test.ctr.n, test.n) | |
} | |
} | |
} | |
func TestLimitHandler(t *testing.T) { | |
limCtrs = map[string]*limCounter{} | |
defer func() { | |
limCtrs = map[string]*limCounter{} | |
}() | |
h := limitHandler{ | |
h: valHandler{"ok"}, | |
f: limitKeyAuth, | |
} | |
ctx := context.Background() | |
failed := 0 | |
for i := 0; i < bucketSize+1; i++ { | |
w := httptest.NewRecorder() | |
r, _ := http.NewRequest("GET", "/foo", nil) | |
r.Header.Set("Authorization", "1") | |
h.ServeHTTPContext(ctx, w, r) | |
if w.Code == 429 { | |
failed++ | |
} | |
} | |
if failed != 1 { | |
t.Fatal("Expected request to fail") | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment