Skip to content

Instantly share code, notes, and snippets.

@averykhoo
Last active December 8, 2021 10:10
Show Gist options
  • Save averykhoo/8f812b9579ad6344ac767f33792a4cc1 to your computer and use it in GitHub Desktop.
Save averykhoo/8f812b9579ad6344ac767f33792a4cc1 to your computer and use it in GitHub Desktop.
rate limiter
import datetime
import time
from functools import wraps
from threading import Condition
from threading import Lock
from typing import Any
from typing import Callable
from typing import Union
def debounce(timeout: Union[int, float, datetime.timedelta], *, default: Any = None) -> Callable:
    """
    decorator to debounce a function

    each call blocks for `timeout` seconds before running the wrapped function
    if another call to the same decorated function arrives during that wait,
    the waiting call gives up and returns `default`, and the newer call starts its own full wait
    so out of a burst of calls spaced less than `timeout` apart, only the last one runs

    :param timeout: seconds to wait (int/float), or a `datetime.timedelta`
    :param default: value returned by calls that were superseded by a newer call (None by default)
    :raises TypeError: if `timeout` is not an int, float or timedelta
    :raises ValueError: if `timeout` is negative
    """
    # sanity check the timeout
    if isinstance(timeout, datetime.timedelta):
        timeout = timeout.total_seconds()
    elif not isinstance(timeout, (int, float)):
        raise TypeError(timeout)
    if timeout < 0:
        raise ValueError(timeout)

    def decorator(function: Callable) -> Callable:
        # state is per decorated function and shared by every thread calling it;
        # `cond` guards `latest_call`, which numbers the calls in arrival order
        cond = Condition()
        latest_call = 0

        @wraps(function)
        def debounced(*args, **kwargs):
            nonlocal latest_call
            # Condition.notify_all / wait must be called while holding the condition's own lock
            with cond:
                # claim a new call number and wake any earlier waiters so they give up
                latest_call = my_call = latest_call + 1
                cond.notify_all()
                # unlike a bare `wait`, `wait_for` re-checks after spurious wakeups,
                # so we only stop early if a newer call really arrived
                superseded = cond.wait_for(lambda: latest_call != my_call, timeout=timeout)
            if superseded:
                return default
            # run outside the lock so new calls are not blocked by a slow function
            return function(*args, **kwargs)

        return debounced

    return decorator
import datetime
from collections import defaultdict
from collections import deque
from threading import BoundedSemaphore
from threading import Condition
from threading import Timer
from typing import DefaultDict
from typing import Deque
from typing import Hashable
from typing import Optional
from typing import Type
from typing import TypeVar
class RateLimitExceeded(RuntimeError):
    """Raised by `RateLimiter.__enter__` when no slot can be acquired without blocking."""
class RateLimiter(BoundedSemaphore):
    """
    a semaphore for rate-limiting function calls
    to be used either via `acquire()` and `release()` or via `with...` syntax
    allows at most `value` calls every `duration`
    if `duration` is unspecified, allows `value` concurrent calls

    with a `duration`, every successful `acquire()` records a deadline of
    (acquisition time + duration) on a monotonic clock; the matching `release()`
    returns the slot to the pool only once that deadline has passed, using a
    background daemon `Timer` if it has not passed yet

    `with limiter:` never blocks: if no slot is free it raises `RateLimitExceeded`
    releasing more slots than were acquired raises `ValueError`, as `BoundedSemaphore` does
    """
    # set up by `threading.Semaphore` / `BoundedSemaphore.__init__`
    _value: int
    _initial_value: int
    _cond: Condition
    # set up by this class
    _timestamps: Deque[float]  # pending release deadlines (time.monotonic()), oldest first
    _duration: Optional[datetime.timedelta]
    _resource_name: Optional[str]

    def __init__(self,
                 value: int = 1,
                 duration: Optional[datetime.timedelta] = None,
                 resource_name: Optional[str] = None,
                 ):
        """
        :param value: maximum number of calls per `duration` (or concurrent calls if no duration)
        :param duration: length of the rate-limiting window; must be positive if given
        :param resource_name: optional non-blank name used in the `RateLimitExceeded` message
        :raises TypeError: if an argument has the wrong type
        :raises ValueError: if `value` is not positive, `duration` is not positive,
            or `resource_name` is blank
        """
        # sanity checks (explicit raises, so they still run under `python -O`)
        if not isinstance(value, int):
            raise TypeError(value)
        if value <= 0:
            raise ValueError(value)
        if duration is not None:
            if not isinstance(duration, datetime.timedelta):
                raise TypeError(duration)
            if duration.total_seconds() <= 0:
                raise ValueError(duration)
        if resource_name is not None:
            if not isinstance(resource_name, str):
                raise TypeError(resource_name)
            if not resource_name.strip():
                raise ValueError(resource_name)

        super().__init__(value)
        self._duration = duration
        self._timestamps = deque(maxlen=value)
        self._resource_name = resource_name

    def _limit_exceeded_message(self) -> str:
        """build the `RateLimitExceeded` message describing this limiter's limit"""
        plural_s = 's' if self._initial_value != 1 else ''
        parts = [f'Exceeded the maximum limit of {self._initial_value}']
        if self._duration is not None:
            seconds = self._duration.total_seconds()
            # millisecond precision without trailing zeros, e.g. 60 -> '60', 0.5 -> '0.5'
            # (plain `round()` would render sub-second durations as '0')
            seconds_text = f'{seconds:.3f}'.rstrip('0').rstrip('.')
            seconds_unit = 'second' if seconds_text == '1' else 'seconds'
            parts.append(f'request{plural_s} per {seconds_text} {seconds_unit}')
        else:
            parts.append(f'concurrent request{plural_s}')
        if self._resource_name is not None:
            parts.append(f'for {self._resource_name}')
        return ' '.join(parts)

    def acquire(self, blocking: bool = True, timeout: Optional[float] = None) -> bool:
        """
        acquire a slot, exactly like `BoundedSemaphore.acquire`
        on success with a `duration`, also records when that slot may be returned to the pool
        """
        acquired = super().acquire(blocking, timeout)
        if acquired and self._duration is not None:
            # recorded here (not only in `__enter__`) so plain acquire()/release() pairs work too
            with self._cond:
                self._timestamps.append(time.monotonic() + self._duration.total_seconds())
        return acquired

    def __enter__(self, **kwargs):
        # never block: either take a slot right now or fail loudly
        if not self.acquire(blocking=False):
            raise RateLimitExceeded(self._limit_exceeded_message())

    def _release(self) -> None:
        """return one slot to the pool and wake one waiter"""
        with self._cond:
            self._value += 1
            self._cond.notify()

    def release(self, n: int = 1) -> None:
        """
        release `n` previously acquired slots

        without a `duration`, the slots are returned immediately
        with a `duration`, each slot is returned once its recorded deadline has passed
        (immediately if it already has, otherwise from a background daemon `Timer`)

        :raises ValueError: if `n` < 1, or more slots are released than were acquired
        """
        if n < 1:
            raise ValueError('n must be one or more')
        with self._cond:
            if self._duration is None:
                if self._value + n > self._initial_value:
                    raise ValueError('Semaphore released too many times')
                self._value += n
                self._cond.notify(n)
                return
            # one recorded deadline per outstanding acquisition, so an empty deque means over-release
            if n > len(self._timestamps):
                raise ValueError('Semaphore released too many times')
            deadlines = [self._timestamps.popleft() for _ in range(n)]

        now = time.monotonic()
        for deadline in deadlines:
            delay = deadline - now
            if delay > 0:
                timer = Timer(delay, self._release)
                timer.daemon = True  # do not keep the interpreter alive just to return a slot
                timer.start()
            else:
                self._release()
# type variable for `create_resource_rate_limiter`: ties the value type of the returned mapping
# to whichever `RateLimiter` subclass is passed as `rate_limiter_class`
RateLimiterGeneric = TypeVar('RateLimiterGeneric', bound=RateLimiter)
def create_resource_rate_limiter(value: int,
                                 duration: Optional[datetime.timedelta] = None,
                                 include_resource_name: bool = True,
                                 rate_limiter_class: Type[RateLimiterGeneric] = RateLimiter,
                                 ) -> DefaultDict[Hashable, RateLimiterGeneric]:
    """
    create a mapping that lazily builds one independent rate limiter per resource key

    :param value: passed to `rate_limiter_class` as its `value`
    :param duration: passed to `rate_limiter_class` as its `duration`
    :param include_resource_name: if True, each limiter also gets
        `resource_name=str(key).strip()` (or None if that is empty) so error messages name the
        resource; if False, `resource_name` is not passed at all
    :param rate_limiter_class: the limiter type to instantiate (`RateLimiter` or a subclass)
    :return: a `defaultdict` subclass; `limiters[key]` creates that key's limiter on first access
    """

    class ResourceRateLimiter(defaultdict):
        def __missing__(self, key: Hashable) -> RateLimiterGeneric:
            if include_resource_name:
                limiter = rate_limiter_class(value, duration, resource_name=str(key).strip() or None)
            else:
                limiter = rate_limiter_class(value, duration)
            # if two threads miss the same key at once, each builds a limiter, but `setdefault`
            # keeps only the first one stored and returns it to both callers (a single dict
            # operation in CPython for builtin key types), so a key never gets two budgets
            return self.setdefault(key, limiter)

    return ResourceRateLimiter()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment