Skip to content

Instantly share code, notes, and snippets.

@hzj629206
Last active September 28, 2018 08:13
Show Gist options
  • Save hzj629206/e27590c6d27e269e4b587ae47e97f114 to your computer and use it in GitHub Desktop.
Save hzj629206/e27590c6d27e269e4b587ae47e97f114 to your computer and use it in GitHub Desktop.
Django Rate Limiter via Redis
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals, division, print_function

import functools
import hashlib
import itertools
import math
import time

from django.http import HttpResponse
from redis.exceptions import NoScriptError, RedisError
# Cache lookup shim. django.core.cache.caches exists since Django 1.7 and is the
# supported API; get_cache() is deprecated on 1.7/1.8 and was removed in 1.9, so
# it is only used as a fallback on older Django versions.
try:
    from django.core.cache import caches
except ImportError:  # Django < 1.7
    from django.core.cache import get_cache
else:
    def get_cache(backend):
        """Return the cache configured as CACHES[backend] in settings."""
        return caches[backend]
LUA_SCRIPT_SMOOTH = '''
--[[
Smooth token-bucket rate limiter, run atomically inside Redis.

KEYS: one bucket (a hash with fields lastRefillTime and tokensRemaining) per limit.
ARGV: six values per key, in the same order as KEYS:
  1. interval           length of the bucket window, in millis
  2. capacity           most tokens the bucket can hold
  3. nTokens            tokens this call wants to take
  4. timeNow            current timestamp, in millis
  5. expire             TTL for the key in seconds (ignored when <= 0)
  6. intervalPerPermit  millis it takes to refill one token

A new bucket, or one untouched for longer than interval, is given exactly nTokens
(not capacity), so it grants this call without allowing a burst; otherwise tokens
refill one per intervalPerPermit, up to capacity.

Returns {'', 0, 0, 0, 0} when every bucket had nTokens, after deducting them all.
Otherwise returns {key, interval, capacity, currentTokens, lastRefillTime} for the
first bucket that fell short; no bucket is charged, but the buckets checked so far
keep their refills.
]] --
local effects = {}
for idx, key in ipairs(KEYS) do
  local idxBase = (idx - 1) * 6
  local interval = tonumber(ARGV[idxBase + 1])
  local capacity = tonumber(ARGV[idxBase + 2])
  local nTokens = tonumber(ARGV[idxBase + 3])
  local timeNow = tonumber(ARGV[idxBase + 4])
  local expire = tonumber(ARGV[idxBase + 5])
  local intervalPerPermit = tonumber(ARGV[idxBase + 6])
  -- read fields by name so the result does not depend on hash field order
  local fields = redis.call('hmget', key, 'lastRefillTime', 'tokensRemaining')
  local storedRefillTime, tokensRemaining = tonumber(fields[1]), tonumber(fields[2])
  local burstTokens = nTokens
  local currentTokens = -1
  local lastRefillTime = timeNow
  if storedRefillTime == nil and tokensRemaining == nil then
    -- no bucket yet: create one holding just enough for this call
    currentTokens = burstTokens
    redis.call('hset', key, 'lastRefillTime', timeNow)
  elseif storedRefillTime ~= nil and tokensRemaining ~= nil then
    -- existing bucket: refill it for the time elapsed since lastRefillTime
    lastRefillTime = storedRefillTime
    if timeNow > lastRefillTime then
      local intervalSinceLast = timeNow - lastRefillTime
      if intervalSinceLast > interval then
        -- idle for more than a whole interval: reset the bucket
        currentTokens = burstTokens
        lastRefillTime = timeNow
        redis.call('hset', key, 'lastRefillTime', lastRefillTime)
      else
        local grantedTokens = math.floor(intervalSinceLast / intervalPerPermit)
        if grantedTokens > 0 then
          -- shift lastRefillTime back by the leftover millis so partial
          -- progress toward the next token is not lost
          local padMillis = math.fmod(intervalSinceLast, intervalPerPermit)
          lastRefillTime = timeNow - padMillis
          redis.call('hset', key, 'lastRefillTime', lastRefillTime)
        end
        currentTokens = math.min(grantedTokens + tokensRemaining, capacity)
      end
    else
      -- a caller with a later clock already refilled this bucket; nothing to add
      currentTokens = tokensRemaining
    end
  end
  -- fails (aborting the script) when the hash has only one of the two fields
  assert(currentTokens >= 0)
  if expire > 0 then
    redis.call('expire', key, expire)
  end
  if nTokens > currentTokens then
    -- denied: persist refilled (uncharged) counts for this and earlier buckets
    redis.call('hset', key, 'tokensRemaining', currentTokens)
    for _, effect in ipairs(effects) do
      redis.call('hset', effect[1], 'tokensRemaining', effect[2])
    end
    return {key, interval, capacity, currentTokens, lastRefillTime}
  end
  table.insert(effects, {key, currentTokens, nTokens})
end
-- every bucket can pay: charge them all
for _, effect in ipairs(effects) do
  redis.call('hset', effect[1], 'tokensRemaining', effect[2] - effect[3])
end
return {'', 0, 0, 0, 0}
'''
LUA_SCRIPT = '''
--[[
Fixed-window token-bucket rate limiter, run atomically inside Redis.

KEYS: one bucket (a hash with fields lastRefillTime and tokensRemaining) per limit.
ARGV: five values per key, in the same order as KEYS:
  1. interval   length of the window, in millis
  2. capacity   tokens available per window
  3. nTokens    tokens this call wants to take
  4. timeNow    current timestamp, in millis
  5. expire     TTL for the key in seconds (ignored when <= 0)

A bucket is refilled to capacity once more than interval millis have passed since
its lastRefillTime.

Returns {'', 0, 0, 0, 0} when every bucket had nTokens, after deducting them all.
Otherwise returns {key, interval, capacity, currentTokens, lastRefillTime} for the
first bucket that fell short; no bucket is charged.
]] --
local effects = {}
for idx, key in ipairs(KEYS) do
  local idxBase = (idx - 1) * 5
  local interval = tonumber(ARGV[idxBase + 1])
  local capacity = tonumber(ARGV[idxBase + 2])
  local nTokens = tonumber(ARGV[idxBase + 3])
  local timeNow = tonumber(ARGV[idxBase + 4])
  local expire = tonumber(ARGV[idxBase + 5])
  local currentTokens = -1
  local lastRefillTime = timeNow
  if redis.call('exists', key) == 0 then
    -- no bucket yet: start a window with a full bucket
    currentTokens = capacity
    redis.call('hset', key, 'lastRefillTime', timeNow)
  else
    lastRefillTime = tonumber(redis.call('hget', key, 'lastRefillTime'))
    if timeNow - lastRefillTime > interval then
      -- the window has elapsed: start a new one with a full bucket
      currentTokens = capacity
      lastRefillTime = timeNow
      redis.call('hset', key, 'lastRefillTime', lastRefillTime)
    else
      currentTokens = tonumber(redis.call('hget', key, 'tokensRemaining'))
      if currentTokens > capacity then
        currentTokens = capacity
      end
    end
  end
  assert(currentTokens >= 0)
  if expire > 0 then
    redis.call('expire', key, expire)
  end
  if nTokens > currentTokens then
    -- denied: persist the (uncharged) counts under the same field the reads use
    redis.call('hset', key, 'tokensRemaining', currentTokens)
    for _, effect in ipairs(effects) do
      redis.call('hset', effect[1], 'tokensRemaining', effect[2])
    end
    return {key, interval, capacity, currentTokens, lastRefillTime}
  end
  table.insert(effects, {key, currentTokens, nTokens})
end
-- every bucket can pay: charge them all
for _, effect in ipairs(effects) do
  redis.call('hset', effect[1], 'tokensRemaining', effect[2] - effect[3])
end
return {'', 0, 0, 0, 0}
'''
# SHA1 digests of the script sources, used with EVALSHA. hashlib needs bytes on
# Python 3 (the scripts are text under unicode_literals), so encode explicitly.
LUA_SCRIPT_SHA1 = hashlib.sha1(LUA_SCRIPT.encode('utf-8')).hexdigest()
LUA_SCRIPT_SMOOTH_SHA1 = hashlib.sha1(LUA_SCRIPT_SMOOTH.encode('utf-8')).hexdigest()
class RedisConsumeDenied(object):
    """Describes the bucket that refused a consume call.

    Built from the rate-limit scripts' denial reply:
    [redis key, interval in ms, capacity, tokens left, last refill time in ms].
    """

    def __init__(self, redis_rv):
        (self.redis_key, interval_ms, self.capacity,
         self.current_tokens, self.last_fill_at) = (redis_rv[i] for i in range(5))
        # the script works in millis; expose the interval in seconds
        self.interval = interval_ms / 1000

    def __repr__(self):
        template = '<RedisConsumeDenied([{}] interval={}, capacity={}, tokens={})>'
        return template.format(self.redis_key, self.interval, self.capacity, self.current_tokens)
class RateLimiter(object):
    """Token-bucket rate limiter whose buckets are Redis hashes.

    Each consume call runs one Lua script over all requested buckets, so either
    every bucket pays its tokens or none does.
    """

    redis_cache = get_cache('redis_cache')  # CACHES['redis_cache'] in settings.py

    def __init__(self, key_prefix=None):
        """
        :param str key_prefix: prefix for every bucket key (default b'rate_limiter');
            also passed to the cache backend's get_client().
        """
        self.key_prefix = key_prefix or b'rate_limiter'
        self.redis_cli = self.redis_cache.get_client(self.key_prefix, write=True)

    def make_key(self, key, interval):
        """Build the bucket key '<prefix>:<key>:<interval>'.

        Formatted as text (byte strings are decoded as UTF-8 first) so that
        non-ASCII keys such as request paths work, and so bytes.format, which
        does not exist on Python 3, is not needed.
        """
        prefix = self.key_prefix
        if isinstance(prefix, bytes):
            prefix = prefix.decode('utf-8')
        if isinstance(key, bytes):
            key = key.decode('utf-8')
        return '{}:{}:{}'.format(prefix, key, interval)

    def now_ms(self):
        return int(time.time() * 1000)

    @staticmethod
    def _expire_seconds(interval):
        """TTL for a bucket key: two intervals plus 15s, rounded up because EXPIRE rejects fractions."""
        return int(math.ceil(interval * 2 + 15))

    def _eval_script(self, script, script_sha1, script_keys, script_args):
        """Run a rate-limit script by SHA1 and translate its reply.

        Loads the script and retries (up to 3 attempts) when Redis answers
        NOSCRIPT. Any other Redis failure, or running out of attempts, fails
        open with (True, None).

        :rtype: (bool, RedisConsumeDenied)
        """
        argv = script_keys + script_args
        for _ in range(3):
            try:
                rv = self.redis_cli.evalsha(script_sha1, len(script_keys), *argv)
            except NoScriptError:
                # the server has no cached copy of the script: load it and retry
                try:
                    loaded = self.redis_cli.script_load(script)
                except RedisError:
                    break
                if isinstance(loaded, bytes):
                    loaded = loaded.decode('ascii')
                assert loaded == script_sha1
                continue
            except RedisError:
                # connection problems or a script runtime error; retrying at once won't help
                break
            # the scripts reply with an empty key (b'' or '') when every bucket paid
            if not rv[0]:
                return True, None
            return False, RedisConsumeDenied(rv)
        # fail open: a limiter failure should not reject traffic
        return True, None

    def consume_smooth(self, args):
        """Atomically take tokens from smoothly refilling buckets (LUA_SCRIPT_SMOOTH).

        :param list[(str, float|int, int, int)] args: (key, interval in seconds,
            capacity, n) per bucket; capacity must be non-zero (it divides the interval).
        :rtype: (bool, RedisConsumeDenied)
        """
        script_keys = []
        script_args = []
        the_now_ms = self.now_ms()
        for (key, interval, capacity, n) in args:
            interval_ms = interval * 1000
            interval_per_permit_ms = interval_ms / capacity  # type: float
            script_keys.append(self.make_key(key, interval))
            script_args.extend([
                interval_ms, capacity, n, the_now_ms,
                self._expire_seconds(interval), interval_per_permit_ms,
            ])
        return self._eval_script(LUA_SCRIPT_SMOOTH, LUA_SCRIPT_SMOOTH_SHA1, script_keys, script_args)

    def consume_multi(self, args):
        """Atomically take tokens from fixed-window buckets (LUA_SCRIPT).

        :param list[(str, float|int, int, int)] args: (key, interval in seconds, capacity, n) per bucket.
        :rtype: (bool, RedisConsumeDenied)
        """
        script_keys = []
        script_args = []
        the_now_ms = self.now_ms()
        for (key, interval, capacity, n) in args:
            script_keys.append(self.make_key(key, interval))
            script_args.extend([interval * 1000, capacity, n, the_now_ms, self._expire_seconds(interval)])
        return self._eval_script(LUA_SCRIPT, LUA_SCRIPT_SHA1, script_keys, script_args)

    def consume(self, key, interval, capacity, n=1, smooth=True):
        """Take ``n`` tokens from a single bucket.

        :param str key:
        :param float|int interval: window length in seconds
        :param int capacity:
        :param int n:
        :param bool smooth: use consume_smooth instead of consume_multi
        :rtype: (bool, RedisConsumeDenied)
        """
        args = [(key, interval, capacity, n)]
        if smooth:
            return self.consume_smooth(args)
        return self.consume_multi(args)

    def dump(self, key, interval):
        """Print the raw Redis hash of one bucket (debugging aid).

        :param str key:
        :param float|int interval:
        """
        print(self.redis_cli.hgetall(self.make_key(key, interval)))
class RatePolicy(object):
    """A global rate limit: every request to the same path shares one bucket.

    Override ``make_key`` to change how requests are grouped into buckets.
    """

    def __init__(self, interval, capacity):
        """
        :param float|int interval: window length in seconds
        :param int capacity: tokens available per window
        """
        self.interval, self.capacity = interval, capacity

    def make_key(self, request):
        """Bucket key for ``request``: its path.

        :rtype: str
        """
        path = request.path
        return path

    def groups(self, request):
        """Limiter arguments for ``request``: one bucket, consuming one token.

        :rtype: list[(str, float|int, int, int)]
        """
        bucket_key = self.make_key(request)
        return [(bucket_key, self.interval, self.capacity, 1)]
# Default limiter used by rate_limit() when no limiter is passed. It is created at
# import time, so RateLimiter.__init__ (and its get_client() call) runs on import.
_limiter = RateLimiter()
def rate_limit(policies, smooth=True, limiter=None):
    """Decorator that rate-limits a Django view.

    Usage: @rate_limit([RatePolicy(1, 1)])

    The groups of all policies are checked in a single limiter call. If that call
    refuses, the view is skipped and an empty 429 response is returned.

    :param list[RatePolicy] policies:
    :param bool smooth: use consume_smooth (True) or consume_multi (False)
    :param RateLimiter limiter: defaults to the module-level _limiter
    """
    def _decorator(func):
        @functools.wraps(func)
        def _func(request, *args, **kwargs):
            per_policy = [policy.groups(request) for policy in policies]
            limit_args = [group for groups in per_policy for group in groups]
            active = limiter or _limiter
            allowed = True
            if limit_args:
                consume = active.consume_smooth if smooth else active.consume_multi
                allowed, _ = consume(limit_args)
            if not allowed:
                return HttpResponse(status=429)
            return func(request, *args, **kwargs)
        return _func
    return _decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment