Last active
July 2, 2021 09:43
-
-
Save DannyMor/99c680c129a29b0ec315fdcaee01b6ab to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import math
import time
from contextlib import asynccontextmanager
class RateLimiter:
    """Async limiter that bounds both request rate and concurrency.

    Each ``throttle()`` entry takes a permit from a semaphore (capacity
    ``concurrency_limit``) and puts one token into a bounded queue (capacity
    ``rate_limit``). A background task removes queued tokens at roughly
    ``rate_limit`` tokens per second. Once the queue is full, new entries
    wait until the consumer frees space.

    The constructor calls ``asyncio.create_task``, so instances must be
    created while an event loop is running. Call ``close()`` (or use the
    instance with ``async with``) to stop the background task.
    """

    def __init__(self,
                 rate_limit: int,
                 concurrency_limit: int) -> None:
        """Validate the limits and start the background token consumer.

        Args:
            rate_limit: Tokens consumed per second; also the token queue's
                capacity.
            concurrency_limit: Maximum number of ``throttle()`` blocks that
                may be active at the same time.

        Raises:
            ValueError: If either limit is falsy or smaller than 1.
        """
        if not rate_limit or rate_limit < 1:
            raise ValueError('rate limit must be non zero positive number')
        if not concurrency_limit or concurrency_limit < 1:
            raise ValueError('concurrent limit must be non zero positive number')
        self.rate_limit = rate_limit
        self.tokens_queue = asyncio.Queue(rate_limit)
        self.tokens_consumer_task = asyncio.create_task(self.consume_tokens())
        self.semaphore = asyncio.Semaphore(concurrency_limit)

    async def add_token(self) -> None:
        """Put one token into the queue, waiting while the queue is full."""
        await self.tokens_queue.put(1)

    async def consume_tokens(self) -> None:
        """Remove queued tokens at roughly ``rate_limit`` tokens per second.

        Runs until cancelled. Each pass removes as many tokens as the time
        elapsed since the last removal allows (capped at the number queued),
        then sleeps for ``1 / rate_limit`` seconds.
        """
        consumption_rate = 1 / self.rate_limit
        # Starting at 0 makes the first non-empty pass drain every queued token.
        last_consumption_time = 0
        try:
            while True:
                if self.tokens_queue.empty():
                    await asyncio.sleep(consumption_rate)
                    continue
                current_consumption_time = time.monotonic()
                tokens_to_consume = self.get_tokens_amount_to_consume(
                    consumption_rate,
                    current_consumption_time,
                    last_consumption_time,
                    self.tokens_queue.qsize(),
                )
                for _ in range(tokens_to_consume):
                    self.tokens_queue.get_nowait()
                # Only move the reference time forward when tokens were
                # actually removed. A sleep can return marginally short of
                # consumption_rate (timer granularity, float rounding), which
                # makes the floor() above yield 0. Resetting the timestamp in
                # that case would discard the elapsed time and stall for
                # another full period.
                if tokens_to_consume:
                    last_consumption_time = current_consumption_time
                await asyncio.sleep(consumption_rate)
        except asyncio.CancelledError:
            # Cancellation is the normal shutdown path (close()); propagate it.
            raise
        except Exception:
            # Hook for logging unexpected failures. The error is re-raised so
            # it surfaces when close() awaits this task.
            raise

    @staticmethod
    def get_tokens_amount_to_consume(consumption_rate, current_consumption_time,
                                     last_consumption_time, total_tokens) -> int:
        """Return how many tokens may be removed now.

        Args:
            consumption_rate: Seconds per token (``1 / rate_limit``).
            current_consumption_time: Current ``time.monotonic()`` reading.
            last_consumption_time: Reading at the previous removal (0 if none).
            total_tokens: Number of tokens currently queued.

        Returns:
            ``floor(elapsed / consumption_rate)``, capped at ``total_tokens``.
        """
        time_from_last_consumption = current_consumption_time - last_consumption_time
        calculated_tokens_to_consume = math.floor(time_from_last_consumption / consumption_rate)
        return min(total_tokens, calculated_tokens_to_consume)

    @asynccontextmanager
    async def throttle(self):
        """Hold a concurrency permit and register one rate-limit token.

        Waits for a semaphore permit, then for space in the token queue.
        The permit is released when the block exits. This includes the case
        where the wait for queue space is cancelled.
        """
        await self.semaphore.acquire()
        try:
            # Inside the try so that a cancellation while waiting for queue
            # space (e.g. asyncio.wait_for timing out) still releases the
            # permit instead of leaking it.
            await self.add_token()
            yield
        finally:
            self.semaphore.release()

    async def __aenter__(self) -> "RateLimiter":
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the limiter when leaving an ``async with`` block."""
        if exc_type:
            # Hook: log the exception raised inside the ``async with`` body.
            pass
        await self.close()

    async def close(self) -> None:
        """Cancel the token consumer task and wait for it to finish.

        A task that is already cancelled is skipped, so further calls after
        a successful close do nothing. ``CancelledError`` from the consumer
        is suppressed; any other exception it ended with is re-raised.
        """
        if self.tokens_consumer_task and not self.tokens_consumer_task.cancelled():
            try:
                self.tokens_consumer_task.cancel()
                await self.tokens_consumer_task
            except asyncio.CancelledError:
                # Expected outcome of cancel(). NOTE(review): this also
                # swallows a cancellation aimed at the task calling close();
                # consider re-raising when the caller itself is being cancelled.
                pass
            except Exception:
                # Hook for logging the consumer's failure before re-raising.
                raise
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment