Created
October 12, 2011 01:09
-
-
Save dgouldin/1279934 to your computer and use it in GitHub Desktop.
Rolling window rate limit function decorator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division

import datetime
import functools
import hashlib
import inspect
import math
import time
import urllib

import pytz

from common.cache import incr, decr
def patch_timezone(dt, tzinfo):
    '''Return a copy of dt with tzinfo attached, leaving every other field
    (date, time, microsecond) unchanged.

    Uses datetime.replace() instead of reconstructing the datetime field by
    field, which is both clearer and immune to accidentally dropping a field.
    '''
    return dt.replace(tzinfo=tzinfo)
def utcnow_with_timezone():
    '''Return the current UTC time as a timezone-aware datetime.'''
    naive_utc = datetime.datetime.utcnow()
    return patch_timezone(naive_utc, pytz.utc)
def timedelta_to_seconds(delta):
    '''Return the whole seconds spanned by delta (microseconds are ignored).'''
    seconds_per_day = 24 * 60 * 60
    return delta.days * seconds_per_day + delta.seconds
class CacheTimebox(object):
    '''Event counter bucketed by fixed-width time slices stored in a cache.

    The trailing window [now - max_bucket_age, now] is split into integer
    bucket ids (epoch seconds divided by bucket_size); each bucket keeps its
    own counter under a derived cache key, so old buckets simply expire.
    '''

    def __init__(self, cache, cache_key, max_bucket_age, max_buckets=1000,
            max_bucket_size=None):
        self.cache = cache
        self.cache_key = cache_key
        self.max_bucket_age = max_bucket_age
        # Width chosen so the window divides into at most max_buckets slices.
        width = int(math.ceil(
            timedelta_to_seconds(max_bucket_age) / max_buckets))
        if max_bucket_size:
            cap = int(math.ceil(timedelta_to_seconds(max_bucket_size)))
            width = min(cap, width)
        # make sure buckets are at least 1s
        self.bucket_size = max(width, 1)

    def make_cache_key(self, bucket):
        '''Cache key for a single bucket id.'''
        return ','.join((self.cache_key, str(bucket)))

    def _get_bucket(self, dt):
        '''Integer bucket id containing dt.

        NOTE(review): time.mktime() interprets the tuple as *local* time while
        callers pass UTC datetimes; bucket ids stay mutually consistent, but
        confirm if absolute (wall-clock) alignment ever matters.
        '''
        epoch = time.mktime(dt.timetuple())
        return int(math.floor(epoch / self.bucket_size))

    def _get_buckets(self, since=None, until=None):
        '''Inclusive range of bucket ids covering [since, until].

        Defaults: until = now, since = until - max_bucket_age.
        '''
        if until is None:
            until = utcnow_with_timezone()
        if since is None:
            since = until - self.max_bucket_age
        first = self._get_bucket(since)
        last = self._get_bucket(until)
        return range(first, last + 1)

    def _incr_decr(self, at=None, action='incr'):
        '''Apply incr (or decr) to the bucket containing *at*, then return
        the current window total up to *at*.'''
        op = incr if action == 'incr' else decr
        if at is None:
            at = utcnow_with_timezone()
        key = self.make_cache_key(self._get_bucket(at))
        # Buckets expire once they fall out of the rolling window.
        op(self.cache, key,
            expires=timedelta_to_seconds(self.max_bucket_age))
        return self.get(until=at)

    def incr(self, at=None):
        '''Record one event at *at* (default: now); return the window total.'''
        return self._incr_decr(at=at)

    # TODO: does decr make sense? If so, how should it be implemented?
    #def decr(self, at=None):
    #    return self._incr_decr(at=at, action='decr')

    def get(self, since=None, until=None):
        '''Sum of all bucket counters in [since, until].'''
        bucket_ids = self._get_buckets(since=since, until=until)
        keys = [self.make_cache_key(b) for b in bucket_ids]
        return sum(self.cache.get_many(keys).values())

    def delete(self, since=None, until=None):
        '''Drop every bucket counter in [since, until].'''
        bucket_ids = self._get_buckets(since=since, until=until)
        keys = [self.make_cache_key(b) for b in bucket_ids]
        self.cache.delete_many(keys)
# TODO:
# 1. Support both absolute and rolling time windows
class RateLimitDummyCache(object):
    '''Minimal in-process cache used when no real backend is supplied.

    NOTE: ``store`` is a class attribute, so every instance shares one dict.
    '''
    store = {}

    def get(self, key):
        '''Return the cached value, or None when key is absent.'''
        return self.store.get(key)

    def get_many(self, keys):
        '''Return a dict containing only the keys that are present.

        BUG FIX: missing keys were previously mapped to None, which made
        consumers that aggregate the values (e.g. CacheTimebox.get's
        ``sum(...)``) raise TypeError. Omitting absent keys matches the
        Django cache API's get_many contract.
        '''
        many = {}
        for key in keys:
            if key in self.store:
                many[key] = self.store[key]
        return many

    def set(self, key, value):
        '''Store value under key unconditionally.'''
        self.store[key] = value

    def delete(self, key):
        '''Remove key if present; silently ignore missing keys.'''
        self.store.pop(key, None)

    def delete_many(self, keys):
        '''Remove every key in keys, ignoring those that are missing.'''
        for key in keys:
            self.delete(key)
class RateLimitException(Exception):
    '''Raised when a rate limit has been exceeded.'''

    def __init__(self, name=''):
        super(RateLimitException, self).__init__(
            'Rate limit has been reached for %s' % name)
class RateLimit(object):
    '''Declarative rolling-window rate limit bound to a decorated function.

    A limit allows ``num_calls`` calls within the trailing ``delta`` window.
    ``conditions`` maps argument names to predicates; the limit only applies
    when every predicate passes. ``group_by`` maps argument names to filter
    functions whose results partition the limit into independent counters.

    NOTE: uses Python 2 function attributes (func_name/func_code), matching
    the rest of this file.
    '''

    def __init__(self, num_calls, delta, conditions=None, group_by=None,
            name=None):
        self.num_calls = num_calls
        self.delta = delta
        self.conditions = conditions or {}
        self.group_by = group_by or {}
        self._name = name
        self.errors = []  # populated by validate()
        self.func = None  # assigned by the rate_limit decorator
        self.cache = RateLimitDummyCache()

    @property
    def name(self):
        '''Explicit name if given, else the decorated function's dotted path.'''
        func_name = None
        if self.func:
            func_name = '.'.join((self.func.__module__, self.func.func_name))
        return self._name or func_name

    def validate(self):
        '''Check that condition/group_by argument names exist on the function.

        Populates self.errors; returns True when there are none.
        '''
        arguments = inspect.getargs(self.func.func_code)
        arg_set = set(arguments.args)
        missing_conditions = set(self.conditions.keys()).difference(arg_set)
        missing_group_by = set(self.group_by.keys()).difference(arg_set)
        self.errors = []
        if missing_conditions:
            self.errors.append('Missing condition arguments: %s' % ', '.join(
                missing_conditions))
        if missing_group_by:
            self.errors.append('Missing group_by arguments: %s' % ', '.join(
                missing_group_by))
        return not self.errors

    def _get_value(self, argname, *args, **kwargs):
        '''Resolve argname from the positional or keyword call arguments.

        Raises ValueError when the argument cannot be found.
        '''
        arguments = inspect.getargs(self.func.func_code)
        try:
            val = args[arguments.args.index(argname)]
        except (ValueError, IndexError):
            try:
                val = kwargs[argname]
            except KeyError:
                raise ValueError("arg %s could not be found." % argname)
        return val

    def applies(self, *args, **kwargs):
        '''True when every configured condition predicate passes.'''
        for argname, condition in self.conditions.items():
            if not condition(self._get_value(argname, *args, **kwargs)):
                return False
        return True

    def make_cache_key(self, *args, **kwargs):
        '''Build the counter key: limit name plus a hash of group_by values.'''
        key_parts = ['RateLimit', self.name]
        key_args = {}
        for argname, group_filter in self.group_by.items():
            # BUG FIX: was ``get_value(argname)`` -- an undefined name that
            # raised NameError whenever group_by was used. The value must be
            # resolved from the actual call arguments.
            value = self._get_value(argname, *args, **kwargs)
            key_args[argname] = str(group_filter(value))
        key_parts.append(hashlib.md5(urllib.urlencode(key_args)).hexdigest())
        return ','.join(key_parts)

    def get_timebox(self, *args, **kwargs):
        '''CacheTimebox tracking this limit's counter for these arguments.'''
        cache_key = self.make_cache_key(*args, **kwargs)
        return CacheTimebox(self.cache, cache_key, self.delta)

    def consume(self, *args, **kwargs):
        '''Count one call against the limit.

        Returns the number of remaining calls, or None when the limit does
        not apply to these arguments. Raises RateLimitException when the
        limit has been reached.
        '''
        if not self.applies(*args, **kwargs):
            return
        timebox = self.get_timebox(*args, **kwargs)
        if timebox.get() < self.num_calls:
            return self.num_calls - timebox.incr()
        raise RateLimitException(self.name)

    def exhaust(self, *args, **kwargs):
        '''Use up all remaining calls so further consume() calls fail.'''
        if not self.applies(*args, **kwargs):
            return
        timebox = self.get_timebox(*args, **kwargs)
        consumed = timebox.get()
        while consumed < self.num_calls:
            consumed = timebox.incr()

    def reset(self, *args, **kwargs):
        '''Clear all counters for these call arguments.'''
        self.get_timebox(*args, **kwargs).delete()
def rate_limit(rate_limits=None, cache=None):
    '''Decorator enforcing one or more RateLimit instances on a function.

    Each call consumes one unit from every applicable limit; hitting any
    limit raises RateLimitException. If the wrapped function itself raises
    RateLimitException (e.g. a downstream limit), every applicable limit is
    exhausted before re-raising with the joined limit names.

    Raises ValueError at decoration time for non-RateLimit items or for
    limits whose condition/group_by argument names don't exist on the
    function. With no rate_limits, the function is returned undecorated.
    '''
    def decorator(func):
        if not rate_limits:
            return func
        for limit in rate_limits:
            # BUG FIX: validate the type *before* mutating the object --
            # previously func/cache were assigned ahead of the isinstance
            # check, mutating items that were about to be rejected.
            if not isinstance(limit, RateLimit):
                raise ValueError('All rate_limit items must be RateLimit instances.')
            limit.func = func
            if cache:
                limit.cache = cache
            if not limit.validate():
                raise ValueError('limit %s has errors: \r\n%s' % (
                    limit.name,
                    '\r\n'.join(limit.errors),
                ))

        @functools.wraps(func)  # preserve name/docstring on the wrapper
        def inner(*args, **kwargs):
            # Materialized as a list so it can be iterated again in the
            # except branch (filter() is lazy on Python 3).
            applicable_limits = [l for l in rate_limits
                                 if l.applies(*args, **kwargs)]
            for limit in applicable_limits:
                limit.consume(*args, **kwargs)
            try:
                return func(*args, **kwargs)
            except RateLimitException:
                for limit in applicable_limits:
                    # TODO: try to determine which limit(s) should be exhausted
                    limit.exhaust(*args, **kwargs)
                raise RateLimitException(
                    ','.join([l.name for l in applicable_limits]))
        return inner
    return decorator
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment