Last active
July 21, 2022 02:13
-
-
Save KhanhhNe/bba8cc30c63e9b18fb09f8bd6911cc63 to your computer and use it in GitHub Desktop.
Python thread-safe throttling (rate limiting)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import threading | |
import time | |
from collections import deque | |
from datetime import datetime | |
from functools import wraps | |
from inspect import iscoroutinefunction | |
from typing import cast | |
def calculate_wait_time(last_call: datetime, min_wait: float) -> float:
    """
    Calculate minimum wait time until next function call is allowed

    The current time is read with the same tzinfo as `last_call`, so both
    naive (local time) and timezone-aware datetimes are accepted.

    :param last_call: Time of last function call
    :param min_wait: Minimum number of seconds required between two calls
    :return: Seconds still to wait; 0 when `min_wait` has already elapsed
    """
    # datetime.now(None) is the naive local time, which matches a naive
    # `last_call`. For an aware `last_call` it returns an aware "now", which
    # avoids the TypeError raised when mixing naive and aware datetimes.
    now = datetime.now(last_call.tzinfo)
    seconds_since_last = (now - last_call).total_seconds()
    return max(min_wait - seconds_since_last, 0)
def rate_limit(reset_after: float, maximum_calls: int = 1):
    """
    Apply rate limit to wrapped function: at most `maximum_calls` calls are
    started in any `reset_after` seconds; extra callers block (sync) or are
    suspended (async) until a slot frees up.

    Sync functions are serialized with a threading.Lock and may be called
    from any number of threads. Async functions are serialized with an
    asyncio.Lock that belongs to the event loop currently running them.

    :param reset_after: Waiting seconds before function calls are allowed
    :param maximum_calls: Maximum calls in `reset_after` duration
    :raises ValueError: If `maximum_calls` is smaller than 1
    :return: Decorator that accepts a sync or an async function
    """
    if maximum_calls < 1:
        # Without at least one slot the very first call would fail with an
        # opaque IndexError on the empty deque; fail early and clearly.
        raise ValueError('maximum_calls must be at least 1')

    def wrapper(func):
        # This deque holds the start time of the last `maximum_calls` calls
        # as time.monotonic() readings, ordered from latest (left) to oldest
        # (right). Every slot starts at -inf ("a long time ago") so the
        # first calls never wait. A monotonic clock is used because
        # wall-clock time can jump (DST, NTP) and distort the wait.
        call_times = deque([float('-inf')] * maximum_calls)

        lock = threading.Lock()
        # The asyncio lock is created lazily inside the running event loop
        # (see wait_async) instead of at decoration time.
        async_lock = None
        async_lock_loop = None

        def remaining_wait():
            # Seconds until the oldest recorded call leaves the window
            elapsed = time.monotonic() - call_times[-1]
            return max(reset_after - elapsed, 0)

        def record_call():
            # Replace the oldest timestamp with the current time
            call_times.pop()
            call_times.appendleft(time.monotonic())

        def wait():
            with lock:
                # The oldest timestamp is only peeked before sleeping and
                # replaced afterwards, so an interrupted sleep cannot
                # shrink the deque.
                time.sleep(remaining_wait())
                record_call()

        async def wait_async():
            nonlocal async_lock, async_lock_loop
            loop = asyncio.get_running_loop()
            if async_lock_loop is not loop:
                # An asyncio.Lock can only be used from one event loop, so
                # a fresh one is made whenever the running loop changes.
                async_lock = asyncio.Lock()
                async_lock_loop = loop
            async with async_lock:
                # Same peek-then-replace order as wait(): a task cancelled
                # during the sleep leaves call_times untouched.
                await asyncio.sleep(remaining_wait())
                record_call()

        @wraps(func)
        def wrapped(*args, **kwargs):
            wait()
            return func(*args, **kwargs)

        @wraps(func)
        async def wrapped_async(*args, **kwargs):
            await wait_async()
            return await func(*args, **kwargs)

        # Determine whether the input function is sync or async function and
        # return the respective wrapped function
        if iscoroutinefunction(func):
            return wrapped_async
        return wrapped

    return wrapper
if __name__ == '__main__':
    # Demo: each limited function is called 10 times concurrently, so the
    # printed lines appear at the pace the rate limit allows.

    # Once every 5 seconds
    @rate_limit(5)
    def say_hi():
        print('hi')

    # 3 times every 5 seconds
    @rate_limit(5, 3)
    def say_bye():
        print('bye')

    def run_in_threads(target, count=10):
        # Call `target` from `count` threads at once and wait for them all
        threads = [threading.Thread(target=target) for _ in range(count)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    async def run_foo(count=10):
        # And async version too! It is decorated inside the coroutine so
        # the decorator runs while this event loop is the current one.
        @rate_limit(5)
        async def say_foo():
            print('foo')

        await asyncio.gather(*(say_foo() for _ in range(count)))

    # Test them out
    run_in_threads(say_hi)
    run_in_threads(say_bye)
    # asyncio.run creates and closes its own loop; asyncio.get_event_loop()
    # without a running loop is deprecated and fails on recent Pythons.
    asyncio.run(run_foo())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment