Last active
December 8, 2021 10:10
-
-
Save averykhoo/8f812b9579ad6344ac767f33792a4cc1 to your computer and use it in GitHub Desktop.
rate limiter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime
from functools import wraps
from threading import Condition
from threading import Lock
from typing import Any
from typing import Callable
from typing import Union
def debounce(timeout: Union[int, float, datetime.timedelta], *, default: Any = None) -> Callable:
    """
    decorator to debounce a function

    every call blocks for `timeout` seconds before calling the wrapped function;
    if a newer call to the same decorated function arrives during that wait,
    all earlier pending calls give up and return `default`, and the newest call
    starts its own full wait (i.e. the debounce timer resets)

    :param timeout: how long to wait, in seconds (int/float) or as a timedelta
    :param default: value returned by calls that were superseded (None by default)
    :return: a decorator that wraps a function with the debounce behaviour
    :raises TypeError: if `timeout` is not an int, float or timedelta
    """
    # sanity check the timeout
    if isinstance(timeout, datetime.timedelta):
        timeout = timeout.total_seconds()
    elif not isinstance(timeout, (int, float)):
        raise TypeError(timeout)

    def decorator(function: Callable) -> Callable:
        # state lives per decorated function, so two functions decorated with the
        # same `debounce(...)` object do not cancel each other's pending calls
        lock = Lock()
        # the condition must be built on the same lock we hold:
        # notify_all()/wait_for() are only legal while holding the condition's lock
        cond = Condition(lock)
        call_counter = 0

        @wraps(function)
        def debounced(*args, **kwargs):
            nonlocal call_counter
            with cond:
                call_counter += 1
                my_ticket = call_counter
                # wake every earlier waiter so it notices it has been superseded
                cond.notify_all()
                # wait_for re-checks the predicate after each wakeup, so a spurious
                # wakeup cannot end the wait early; it returns True if superseded
                superseded = cond.wait_for(lambda: call_counter != my_ticket, timeout=timeout)
            if superseded:
                return default
            # otherwise, call the function (outside the lock) and return the result
            return function(*args, **kwargs)

        return debounced

    return decorator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
from collections import defaultdict | |
from collections import deque | |
from threading import BoundedSemaphore | |
from threading import Condition | |
from threading import Timer | |
from typing import DefaultDict | |
from typing import Deque | |
from typing import Hashable | |
from typing import Optional | |
from typing import Type | |
from typing import TypeVar | |
class RateLimitExceeded(RuntimeError):
    """Raised by `RateLimiter.__enter__` when no slot is free to acquire without blocking."""
class RateLimiter(BoundedSemaphore):
    """
    a semaphore for rate-limiting function calls
    to be used either via `acquire()` and `release()` or via `with...` syntax

    allows at most `value` calls every `duration`
    if `duration` is unspecified, allows `value` concurrent calls

    with a `duration`, every successful acquire records a deadline (acquire time + duration);
    each release consumes the oldest recorded deadline and only returns its slot once that
    deadline has passed (via a daemon `Timer` if it is still in the future), so at most
    `value` acquisitions can start within any window of length `duration`

    NOTE: relies on the private attributes `_value`, `_initial_value` and `_cond`
    that CPython's `threading.Semaphore` / `BoundedSemaphore` set up
    """
    # number of currently free slots (set up by threading.Semaphore)
    _value: int
    # the `value` this limiter was created with (set up by threading.BoundedSemaphore)
    _initial_value: int
    # condition guarding `_value` (set up by threading.Semaphore)
    _cond: Condition
    # deadlines of acquisitions not yet released, oldest first
    _timestamps: Deque[datetime.datetime]
    # None means "limit concurrency only"
    _duration: Optional[datetime.timedelta]
    # only used to make the RateLimitExceeded message more specific
    _resource_name: Optional[str]

    def __init__(self,
                 value: int = 1,
                 duration: Optional[datetime.timedelta] = None,
                 resource_name: Optional[str] = None,
                 ):
        # sanity checks
        assert isinstance(value, int) and value > 0, value
        if duration is not None:
            assert isinstance(duration, datetime.timedelta), duration
            assert duration.total_seconds() > 0, duration
        if resource_name is not None:
            assert isinstance(resource_name, str), resource_name
            assert len(resource_name.strip()) > 0, resource_name

        super().__init__(value)
        self._duration = duration
        self._timestamps = deque(maxlen=value)
        self._resource_name = resource_name

    @staticmethod
    def _now() -> datetime.datetime:
        # timezone-aware UTC, so local DST transitions cannot shift deadlines by an hour
        return datetime.datetime.now(datetime.timezone.utc)

    def acquire(self, blocking: bool = True, timeout: Optional[float] = None) -> bool:
        """
        acquire a slot, exactly like `Semaphore.acquire`

        when a `duration` is set, a successful acquire also records its deadline,
        so direct `acquire()`/`release()` usage behaves the same as `with`

        :return: True if a slot was acquired, False otherwise
        """
        acquired = super().acquire(blocking=blocking, timeout=timeout)
        if acquired and self._duration is not None:
            with self._cond:
                self._timestamps.append(self._now() + self._duration)
        return acquired

    def __enter__(self, **kwargs):
        """
        acquire a slot without blocking

        :raises RateLimitExceeded: if no slot is currently free
        """
        successfully_acquired = self.acquire(blocking=False)

        # failed to acquire, raise error
        if not successfully_acquired:
            error_message_parts = [f'Exceeded the maximum limit of {self._initial_value}']
            plural_s = 's' if self._initial_value != 1 else ''
            if self._duration is not None:
                duration = f'{round(self._duration.total_seconds())} seconds'
                error_message_parts.append(f'request{plural_s} per {duration}')
            else:
                error_message_parts.append(f'concurrent request{plural_s}')
            if self._resource_name is not None:
                error_message_parts.append(f'for {self._resource_name}')
            raise RateLimitExceeded(' '.join(error_message_parts))

    def _release(self):
        """return one slot to the semaphore, refusing to exceed the initial value"""
        with self._cond:
            # same guarantee BoundedSemaphore.release gives, which this class overrides
            if self._value >= self._initial_value:
                raise ValueError('Semaphore released too many times')
            self._value += 1
            self._cond.notify()

    def release(self, n: int = 1) -> None:
        """
        release one slot

        without a `duration` the slot is returned immediately; with a `duration` the
        oldest recorded deadline is consumed and the slot is returned once it passes

        :raises ValueError: if released more times than acquired
        """
        assert n == 1  # n > 1 only supported from Python 3.9

        if self._duration is None:
            self._release()
            return

        # sleep until end of required duration, if applicable
        now = self._now()
        with self._cond:
            if not self._timestamps:
                raise ValueError('Semaphore released too many times')
            timestamp_end = self._timestamps.popleft()

        time_to_sleep = (timestamp_end - now).total_seconds()
        if time_to_sleep > 0:
            timer = Timer(time_to_sleep, self._release)
            timer.daemon = True  # Thread.setDaemon() is deprecated since Python 3.10
            timer.start()
        else:
            self._release()
RateLimiterGeneric = TypeVar('RateLimiterGeneric', bound=RateLimiter)


def create_resource_rate_limiter(value: int,
                                 duration: Optional[datetime.timedelta] = None,
                                 include_resource_name: bool = True,
                                 rate_limiter_class: Type[RateLimiterGeneric] = RateLimiter,
                                 ) -> DefaultDict[Hashable, RateLimiterGeneric]:
    """
    create a dict that lazily builds one rate limiter per key

    looking up a missing key builds `rate_limiter_class(value, duration)`, stores it and
    returns it; later lookups of the same key return that same limiter

    :param value: passed to each limiter as its `value`
    :param duration: passed to each limiter as its `duration`
    :param include_resource_name: if True, `str(key).strip()` (or None if empty) is passed
                                  as `resource_name`, so error messages name the key
    :param rate_limiter_class: class used to build each limiter
    :return: a defaultdict subclass mapping keys to limiters
    """
    # local import: the module header only imports the semaphore/condition/timer primitives
    from threading import Lock

    creation_lock = Lock()

    class ResourceRateLimiter(defaultdict):
        def __missing__(self, key: Hashable) -> RateLimiterGeneric:
            # serialize creation so concurrent first lookups of the same key cannot each
            # build (and hand out) a separate limiter, which would bypass the limit
            with creation_lock:
                existing = self.get(key)  # dict.get does not trigger __missing__
                if existing is not None:
                    return existing
                if include_resource_name:
                    key_str = str(key).strip() or None
                    limiter = rate_limiter_class(value, duration, resource_name=key_str)
                else:
                    limiter = rate_limiter_class(value, duration)
                self[key] = limiter
                return limiter

    return ResourceRateLimiter()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment