Skip to content

Instantly share code, notes, and snippets.

@dgouldin
Created October 12, 2011 01:09
Show Gist options
  • Save dgouldin/1279934 to your computer and use it in GitHub Desktop.
Save dgouldin/1279934 to your computer and use it in GitHub Desktop.
Rolling window rate limit function decorator
from __future__ import division

import calendar
import datetime
import functools
import hashlib
import inspect
import math
import time
import urllib

import pytz

from common.cache import incr, decr
def patch_timezone(dt, tzinfo):
    """Return a new datetime with *dt*'s wall-clock fields and *tzinfo* attached.

    No timezone conversion happens: year through microsecond are copied
    verbatim and *tzinfo* replaces whatever zone *dt* carried.
    """
    wall_clock = [getattr(dt, field) for field in
                  ('year', 'month', 'day',
                   'hour', 'minute', 'second', 'microsecond')]
    return datetime.datetime(*wall_clock, tzinfo=tzinfo)
def utcnow_with_timezone():
    """Return the current UTC time as an aware datetime whose tzinfo is pytz.utc."""
    return datetime.datetime.now(pytz.utc)
def timedelta_to_seconds(delta):
    """Return the whole seconds in *delta*, discarding microseconds.

    Built from the normalized ``days``/``seconds`` fields, so the result is
    an exact int; since ``seconds`` is never negative, negative deltas with
    a fractional part round toward negative infinity.
    """
    seconds_per_day = 24 * 60 * 60
    return delta.seconds + seconds_per_day * delta.days
class CacheTimebox(object):
    """Count events in a rolling time window using cache-backed buckets.

    The window length ``max_bucket_age`` is divided into fixed-size buckets
    of ``bucket_size`` whole seconds. Each bucket is stored under its own
    cache key (``"<cache_key>,<bucket index>"``) and the count for a window
    is the sum of every bucket the window touches.

    :param cache: cache object; ``get_many``/``delete_many`` are called on it
        directly and it is passed to ``common.cache.incr``/``decr``.
    :param cache_key: prefix shared by all bucket keys of this timebox.
    :param max_bucket_age: ``timedelta`` length of the rolling window.
    :param max_buckets: the bucket size is chosen so that this many buckets
        cover the whole window.
    :param max_bucket_size: optional ``timedelta`` upper bound on one
        bucket's length, truncated to whole seconds.
    """

    def __init__(self, cache, cache_key, max_bucket_age, max_buckets=1000,
                 max_bucket_size=None):
        self.cache = cache
        self.cache_key = cache_key
        self.max_bucket_age = max_bucket_age
        # total_seconds() keeps sub-second parts of the window in the
        # division; ceil rounds up so max_buckets buckets cover the window.
        self.bucket_size = int(math.ceil(
            max_bucket_age.total_seconds() / max_buckets))
        if max_bucket_size:
            # Bucket indexes are whole seconds, so round the cap down.
            self.bucket_size = min(int(max_bucket_size.total_seconds()),
                                   self.bucket_size)
        # make sure buckets are at least 1s
        self.bucket_size = max(self.bucket_size, 1)

    def make_cache_key(self, bucket):
        """Return the cache key holding the counter for bucket index *bucket*."""
        return ','.join((self.cache_key, str(bucket)))

    def _get_bucket(self, dt):
        """Return the index of the bucket containing *dt*.

        Aware datetimes are converted to UTC; naive ones are treated as UTC.
        """
        # calendar.timegm reads the tuple as UTC. time.mktime would read it
        # as local time, making bucket keys depend on the host's TZ setting.
        epoch = calendar.timegm(dt.utctimetuple())
        return int(epoch // self.bucket_size)

    def _get_buckets(self, since=None, until=None):
        """Return bucket indexes from *since* through *until*, inclusive.

        *until* defaults to now and *since* to ``until - max_bucket_age``.
        """
        until = until or utcnow_with_timezone()
        since = since or (until - self.max_bucket_age)
        return range(self._get_bucket(since),
                     self._get_bucket(until) + 1)

    def _incr_decr(self, at=None, action='incr'):
        """Apply ``incr`` (``decr`` for any other *action*) to *at*'s bucket.

        *at* defaults to now. Returns the window total ending at *at*.
        """
        action = incr if action == 'incr' else decr
        at = at or utcnow_with_timezone()
        cache_key = self.make_cache_key(self._get_bucket(at))
        # A bucket stays inside queried windows for up to max_bucket_age plus
        # one bucket_size after it starts, so keep the key alive that long.
        # NOTE(review): assumes common.cache.incr/decr measure `expires` from
        # this call - confirm.
        action(self.cache, cache_key,
               expires=timedelta_to_seconds(self.max_bucket_age) +
               self.bucket_size)
        return self.get(until=at)

    def incr(self, at=None):
        """Count one event at *at* (default now); return the new window total."""
        return self._incr_decr(at=at)

    # TODO: does decr make sense? If so, expose it as a public method that
    # calls self._incr_decr(at=at, action='decr').

    def get(self, since=None, until=None):
        """Return the summed count of all buckets between *since* and *until*.

        Missing buckets count as zero, whether the cache omits them from
        ``get_many`` or reports them as ``None``.
        """
        keys = [self.make_cache_key(bucket)
                for bucket in self._get_buckets(since=since, until=until)]
        values = self.cache.get_many(keys).values()
        return sum(value for value in values if value is not None)

    def delete(self, since=None, until=None):
        """Delete every bucket between *since* and *until* from the cache."""
        buckets = self._get_buckets(since=since, until=until)
        keys = [self.make_cache_key(bucket) for bucket in buckets]
        self.cache.delete_many(keys)
# TODO:
# 1. Support both absolute and rolling time windows
class RateLimitDummyCache(object):
    """Minimal in-process stand-in for a cache backend; entries never expire."""

    # NOTE(review): class-level dict, so every instance shares one store -
    # confirm this process-wide sharing is intended.
    store = {}

    def get(self, key):
        """Return the value stored under *key*, or None if absent."""
        return self.store.get(key)

    def get_many(self, keys):
        """Return ``{key: value}`` for those *keys* that are present.

        Absent keys are omitted (as Django's ``get_many`` does) rather than
        mapped to None, so callers can sum the values directly.
        """
        return {key: self.store[key] for key in keys if key in self.store}

    def set(self, key, value, timeout=None):
        """Store *value* under *key*.

        *timeout* is accepted for interface compatibility with cache backends
        whose ``set`` takes one, but it is ignored: entries never expire.
        """
        self.store[key] = value

    def delete(self, key):
        """Remove *key* if present; missing keys are ignored."""
        self.store.pop(key, None)

    def delete_many(self, keys):
        """Remove every key in *keys*; missing keys are ignored."""
        for key in keys:
            self.delete(key)
class RateLimitException(Exception):
    """Signals that a rate limit has no calls left.

    :param name: label of the exhausted limit(s), interpolated into the
        exception message.
    """

    def __init__(self, name=''):
        super(RateLimitException, self).__init__(
            'Rate limit has been reached for %s' % name)
class RateLimit(object):
    """Allow at most ``num_calls`` calls of a function per rolling ``delta``.

    :param num_calls: calls permitted inside any window of length *delta*.
    :param delta: ``timedelta`` window length, handed to ``CacheTimebox``.
    :param conditions: ``{argname: predicate}``; the limit applies to a call
        only when every predicate returns true for that argument's value.
    :param group_by: ``{argname: key_func}``; calls are counted separately
        for each distinct combination of ``str(key_func(value))``.
    :param name: label used in cache keys and error messages; defaults to
        ``"<module>.<function name>"`` of ``func``.

    ``func`` starts as None and must be set (the ``rate_limit`` decorator
    does so) before the argument-inspecting methods are used. ``cache``
    defaults to a ``RateLimitDummyCache``.
    """

    def __init__(self, num_calls, delta, conditions=None, group_by=None,
                 name=None):
        self.num_calls = num_calls
        self.delta = delta
        self.conditions = conditions or {}
        self.group_by = group_by or {}
        self._name = name
        self.errors = []
        self.func = None
        self.cache = RateLimitDummyCache()

    @property
    def name(self):
        """Explicit name if given, else ``module.function`` (None if unbound)."""
        if self._name:
            return self._name
        if self.func:
            return '.'.join((self.func.__module__, self.func.__name__))
        return None

    def _arg_names(self):
        """Return func's named parameters in declaration order."""
        return inspect.getargs(self.func.__code__).args

    def validate(self):
        """Check that every condition/group_by key names a parameter of func.

        Rebuilds ``self.errors`` (one message per kind of problem, names
        sorted) and returns True when it is empty.
        """
        arg_names = set(self._arg_names())
        missing_conditions = set(self.conditions) - arg_names
        missing_group_by = set(self.group_by) - arg_names
        self.errors = []
        if missing_conditions:
            self.errors.append('Missing condition arguments: %s' % ', '.join(
                sorted(missing_conditions)))
        if missing_group_by:
            self.errors.append('Missing group_by arguments: %s' % ', '.join(
                sorted(missing_group_by)))
        return not self.errors

    def _get_value(self, argname, *args, **kwargs):
        """Return the value parameter *argname* takes in the call (*args, **kwargs).

        Positional arguments are checked first, then keyword arguments, then
        the parameter's declared default.

        :raises ValueError: if *argname* has no value for this call.
        """
        arg_names = self._arg_names()
        positional_count = self.func.__code__.co_argcount
        index = arg_names.index(argname) if argname in arg_names else None
        # Only the first co_argcount names are positional; anything past
        # that in *args belongs to *varargs, not to a named parameter.
        if index is not None and index < min(positional_count, len(args)):
            return args[index]
        if argname in kwargs:
            return kwargs[argname]
        # __defaults__ lines up with the trailing positional parameters.
        defaults = self.func.__defaults__ or ()
        first_default = positional_count - len(defaults)
        if index is not None and first_default <= index < positional_count:
            return defaults[index - first_default]
        # Keyword-only defaults (absent on Python 2, hence getattr).
        kwdefaults = getattr(self.func, '__kwdefaults__', None) or {}
        if argname in kwdefaults:
            return kwdefaults[argname]
        raise ValueError("arg %s could not be found." % argname)

    def applies(self, *args, **kwargs):
        """Return True when every condition accepts its argument's value."""
        return all(condition(self._get_value(argname, *args, **kwargs))
                   for argname, condition in self.conditions.items())

    def make_cache_key(self, *args, **kwargs):
        """Build the cache-key prefix for this call's group.

        The key is ``"RateLimit,<name>,<md5 hex>"``. The digest covers the
        url-encoded ``group_by`` values, sorted by argument name so the same
        call yields the same key regardless of dict ordering.
        """
        try:
            from urllib import urlencode  # Python 2
        except ImportError:
            from urllib.parse import urlencode  # Python 3
        key_args = {}
        for argname, group_filter in self.group_by.items():
            value = self._get_value(argname, *args, **kwargs)
            key_args[argname] = str(group_filter(value))
        encoded = urlencode(sorted(key_args.items()))
        digest = hashlib.md5(encoded.encode('utf-8')).hexdigest()
        return ','.join(['RateLimit', self.name, digest])

    def get_timebox(self, *args, **kwargs):
        """Return the CacheTimebox counting calls for this call's group."""
        cache_key = self.make_cache_key(*args, **kwargs)
        return CacheTimebox(self.cache, cache_key, self.delta)

    def consume(self, *args, **kwargs):
        """Record one call against this limit.

        :returns: None if the limit does not apply to this call, otherwise
            ``num_calls`` minus the window count after recording the call.
        :raises RateLimitException: if the window is already full.
        """
        if not self.applies(*args, **kwargs):
            return None
        timebox = self.get_timebox(*args, **kwargs)
        if timebox.get() >= self.num_calls:
            raise RateLimitException(self.name)
        return self.num_calls - timebox.incr()

    def exhaust(self, *args, **kwargs):
        """Fill the current window so further calls in this group are refused."""
        if not self.applies(*args, **kwargs):
            return
        timebox = self.get_timebox(*args, **kwargs)
        consumed = timebox.get()
        # Bound the number of increments so a cache that does not persist
        # them cannot make this loop spin forever.
        for _ in range(max(self.num_calls - consumed, 0)):
            consumed = timebox.incr()
            if consumed >= self.num_calls:
                break

    def reset(self, *args, **kwargs):
        """Delete the current window's counters for this call's group."""
        self.get_timebox(*args, **kwargs).delete()
def rate_limit(rate_limits=None, cache=None):
    """Decorator factory enforcing every RateLimit in *rate_limits* on a function.

    :param rate_limits: sequence of RateLimit instances (iterated at
        decoration time and again on every call); if empty or None the
        function is returned undecorated.
    :param cache: optional cache object assigned to every limit.
    :raises ValueError: at decoration time if an item is not a RateLimit or
        fails ``validate()``.

    On each call, every applicable limit consumes one call before the
    function runs (a full limit raises RateLimitException). If the function
    itself raises RateLimitException, all applicable limits are exhausted
    and a RateLimitException naming them is raised.
    """
    def decorator(func):
        if not rate_limits:
            return func
        for limit in rate_limits:
            # Type-check first so a bad item fails with the documented
            # ValueError rather than an error from the assignments below.
            if not isinstance(limit, RateLimit):
                raise ValueError(
                    'All rate_limit items must be RateLimit instances.')
            limit.func = func
            if cache:
                limit.cache = cache
            if not limit.validate():
                raise ValueError('limit %s has errors: \r\n%s' % (
                    limit.name,
                    '\r\n'.join(limit.errors),
                ))

        @functools.wraps(func)
        def inner(*args, **kwargs):
            # A list, not a lazy filter object: it is iterated up to three
            # times below, and a Python 3 filter would be empty after one.
            applicable_limits = [limit for limit in rate_limits
                                 if limit.applies(*args, **kwargs)]
            for limit in applicable_limits:
                limit.consume(*args, **kwargs)
            try:
                return func(*args, **kwargs)
            except RateLimitException:
                for limit in applicable_limits:
                    # TODO: try to determine which limit(s) should be exhausted
                    limit.exhaust(*args, **kwargs)
                raise RateLimitException(
                    ','.join([limit.name for limit in applicable_limits]))
        return inner
    return decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment