Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
HTTP token bucket rate limiting
package main
import (
"log"
"net/http"
"sync"
"time"
"github.com/chain-engineering/papi/Godeps/_workspace/src/golang.org/x/net/context"
"github.com/chain-engineering/papi/api"
)
// Token-bucket parameters shared by every limCounter.
const (
// bucketSize is the cap that try applies to a counter's token count, and
// the token count a newly created counter starts with.
bucketSize = 100
// reqPerSec is the number of tokens try adds per unit of elapsed time.
// Callers in this file pass Unix timestamps, so the unit is one second.
reqPerSec = 50
)
// limCounter is a mutex-guarded token counter. One counter is kept per
// rate-limit key in limCtrs.
type limCounter struct {
m sync.Mutex // guards n and t
n int64 // tokens currently available; try never lets it exceed bucketSize
t int64 // timestamp passed to the most recent try (Unix seconds for callers in this file)
}
// try refills the counter for the time elapsed since the previous call and
// then attempts to take one token. It reports whether a token was taken.
//
// t is the current timestamp in the same unit as lc.t; the counter remembers
// it for the next call. The refill is reqPerSec tokens per unit elapsed,
// capped at bucketSize.
func (lc *limCounter) try(t int64) (ok bool) {
	lc.m.Lock()
	defer lc.m.Unlock()

	elapsed := t - lc.t
	if elapsed < 0 {
		// Use the magnitude of the gap in case t is not monotonically
		// increasing between calls.
		elapsed = -elapsed
	}
	lc.t = t

	tokens := lc.n + reqPerSec*elapsed
	if tokens > bucketSize {
		tokens = bucketSize
	}

	if tokens >= 1 {
		lc.n = tokens - 1
		return true
	}
	lc.n = tokens
	return false
}
// limitKeyIP returns the rate-limit key for r based on forwarding headers:
// the Chain-Forwarded-For header when it is non-empty, otherwise the
// X-Forwarded-For header (which may itself be empty).
func limitKeyIP(r *http.Request) string {
	key := r.Header.Get("Chain-Forwarded-For")
	if key == "" {
		key = r.Header.Get("X-Forwarded-For")
	}
	return key
}
// limitKeyAuth returns the rate-limit key for r built from the two values
// getAuthToken yields for the request, joined with a colon.
func limitKeyAuth(r *http.Request) string {
	tokenID, tokenSecret := getAuthToken(r)
	key := tokenID + ":" + tokenSecret
	return key
}
// Package-level registry of rate-limit counters.
var (
limCtrsMu sync.Mutex // protects the following:
// limCtrs maps a rate-limit key (the string a limitHandler's f returns for
// a request) to the counter for that key.
limCtrs = map[string]*limCounter{}
)
// init starts a background goroutine that, once an hour, removes idle
// counters from limCtrs so the map does not grow without bound.
//
// A counter is removed when a try at the current time would find its bucket
// full again. Such a counter carries no more information than the fresh one
// ServeHTTPContext creates on a map miss, so dropping it loses no state.
//
// The refill must be projected here rather than read from n directly: n is
// only updated inside try, which caps it at bucketSize and then either
// decrements it or leaves it below 1. The stored n therefore never exceeds
// bucketSize, and a plain "n > bucketSize" check would never delete anything.
func init() {
	go func() {
		for range time.Tick(time.Hour) {
			now := time.Now().Unix()
			limCtrsMu.Lock()
			for k, ctr := range limCtrs {
				ctr.m.Lock()
				idle := now - ctr.t
				if idle < 0 {
					// Mirror try, which uses the magnitude of the gap when
					// timestamps go backwards.
					idle = -idle
				}
				full := ctr.n+reqPerSec*idle >= bucketSize
				ctr.m.Unlock()
				if full {
					delete(limCtrs, k)
				}
			}
			limCtrsMu.Unlock()
		}
	}()
}
// limitHandler wraps a handler with per-key rate limiting.
type limitHandler struct {
h handler // wrapped handler; invoked only when the request's counter grants a token
f func(r *http.Request) string // derives the rate-limit key from a request
}
// ServeHTTPContext applies the rate limit for r's key and, if a token is
// available, forwards the request to the wrapped handler.
//
// The key comes from l.f(r). The first request seen for a key creates a
// counter with a full bucket. When the counter refuses a token, the rejection
// is counted via nRateLimit, logged with the X-Forwarded-For header and the
// first value from getAuthToken, and answered with api.ErrRateLimit through
// httpError; the wrapped handler is not called.
func (l limitHandler) ServeHTTPContext(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	key := l.f(r)

	// Look up or create the counter under the map lock, but release it
	// before calling try so the per-counter work does not hold the map lock.
	limCtrsMu.Lock()
	counter, found := limCtrs[key]
	if !found {
		counter = &limCounter{n: bucketSize, t: time.Now().Unix()}
		limCtrs[key] = counter
	}
	limCtrsMu.Unlock()

	if counter.try(time.Now().Unix()) {
		l.h.ServeHTTPContext(ctx, w, r)
		return
	}

	nRateLimit.Add()
	forwardedFor := r.Header.Get("X-Forwarded-For")
	user, _ := getAuthToken(r)
	log.Println("rate-limit:", forwardedFor, user)
	httpError(ctx, w, api.ErrRateLimit)
}
package main
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/chain-engineering/papi/Godeps/_workspace/src/golang.org/x/net/context"
)
// TestLimitCounter checks limCounter.try against a table of starting states:
// for each case it calls try once with the given timestamp and verifies both
// the returned decision and the token count left in the counter.
func TestLimitCounter(t *testing.T) {
	start := time.Now().Unix()
	oneSecLater := start + 1
	oneMinLater := start + 60

	type tryCase struct {
		ctr  *limCounter
		t    int64 // timestamp handed to try
		n    int64 // expected ctr.n after try
		want bool  // expected result of try
	}
	cases := []tryCase{
		{ctr: &limCounter{n: bucketSize, t: start}, t: oneSecLater, n: 99, want: true},
		{ctr: &limCounter{n: 50, t: start}, t: start, n: 49, want: true},
		{ctr: &limCounter{n: 50, t: start}, t: oneSecLater, n: 99, want: true},
		{ctr: &limCounter{n: 0, t: start}, t: start, n: 0, want: false},
		{ctr: &limCounter{n: 0, t: start}, t: oneMinLater, n: 99, want: true},
	}

	for i := range cases {
		tc := cases[i]
		got := tc.ctr.try(tc.t)
		if got != tc.want {
			t.Errorf("%v.try(%v) = %v want %v", tc.ctr, tc.t, got, tc.want)
		}
		if tc.ctr.n != tc.n {
			t.Errorf("ctr.n = %v want %v", tc.ctr.n, tc.n)
		}
	}
}
func TestLimitHandler(t *testing.T) {
limCtrs = map[string]*limCounter{}
defer func() {
limCtrs = map[string]*limCounter{}
}()
h := limitHandler{
h: valHandler{"ok"},
f: limitKeyAuth,
}
ctx := context.Background()
failed := 0
for i := 0; i < bucketSize+1; i++ {
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/foo", nil)
r.Header.Set("Authorization", "1")
h.ServeHTTPContext(ctx, w, r)
if w.Code == 429 {
failed++
}
}
if failed != 1 {
t.Fatal("Expected request to fail")
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment