Last active
September 14, 2018 07:17
-
-
Save thehesiod/407670e6d7b6883d3d416d1f649616de to your computer and use it in GitHub Desktop.
asyncio sliding window rate limiter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import logging
import time
import unittest
from collections import deque

import asynctest
class RateLimiter:
    """Sliding-window asyncio rate limiter: allows at most `max_rate` entries
    per rolling `period_s`-second window.

    Usage::

        async with limiter:
            ...  # rate-limited section

    Each `__aexit__` records a completion timestamp; a lazily-started
    background task (`_release_worker`) releases a semaphore slot once
    `period_s` has elapsed since each recorded timestamp.
    """

    class Error(Exception):
        """Raised from `__aenter__` when the limiter is broken (release worker died)."""
        pass

    def __init__(self, max_rate: int, period_s: float, logger: logging.Logger):
        """
        Allows `max_rate` per `period_s`.

        :param max_rate: number of hits allowed per `period_s`
        :param period_s: period in seconds (int or float accepted)
        :param logger: logger to use
        """
        assert isinstance(max_rate, int) and max_rate > 0
        assert period_s > 0
        self._max_rate = max_rate
        self._period_s = period_s
        self._loop = asyncio.get_event_loop()  # NOTE(review): not used by the visible code
        self._logger = logger
        self._release_task = None  # created lazily on first __aexit__
        self._release_worker_exception = None  # failure cause, chained into `Error`
        self._broken_event = asyncio.Event()  # will get set if limiter is broken
        self._broken_evt_wait_fut = asyncio.ensure_future(self._broken_event.wait())
        self._waiters = 0  # tasks currently blocked inside __aenter__
        # We'll initially allow `max_rate` to happen in parallel, and then release
        # the semaphores as new tasks can be started
        self._sema = asyncio.Semaphore(max_rate)
        # we'll push the task end-time to this queue during `__aexit__`
        self._end_time_q = deque()

    @property
    def is_broken(self):
        # True once the release worker has failed; all current/future acquires raise
        return self._broken_event.is_set()

    async def join(self):
        """ Will wait until all waiters have finished """
        while self._waiters:
            await asyncio.sleep(1)

        if not self._broken_evt_wait_fut.done():
            # normal operation: drain the release worker, then drop the broken-event waiter
            if self._release_task and not self._release_task.done():
                await self._release_task
            self._broken_evt_wait_fut.cancel()
        else:
            # exceptional operation: the worker's `finally` may already have set
            # self._release_task to None, so guard against None here (the
            # original dereferenced it unconditionally and could raise
            # AttributeError after the limiter broke)
            if self._release_task and not self._release_task.done():
                self._release_task.cancel()

    async def __aenter__(self):
        self._waiters += 1
        try:
            # Wait on which happens first: we acquire the semaphore or the rate-limiter breaks
            with CancellingTaskCtx(self._sema.acquire()) as acquire_fut:
                await asyncio.wait((self._broken_evt_wait_fut, acquire_fut), return_when=asyncio.FIRST_COMPLETED)
                if self._broken_evt_wait_fut.done():
                    raise self.Error("Error while acquiring semaphore") from self._release_worker_exception
        finally:
            self._waiters -= 1

    async def _release_worker(self, sleep_s):
        """Background task: sleep `sleep_s`, then release each semaphore slot
        whose period has expired, sleeping again between batches as needed.

        On any failure it marks the limiter broken and re-raises.
        """
        try:
            # swapping back/forth at this point is ok because __aexit__ will not swap as it does not yield and will detect we're already running
            await asyncio.sleep(sleep_s)
            now = time.time()  # cache as this call is not cheap

            # Here we'll release each semaphore that expired its period from when it finished.
            # We have a loop as an optimization against having multiple timers since the timer may be called later
            # than when wanted, and thus we may have multiple semaphores that we can release.
            while len(self._end_time_q):
                oldest_finished_ts = self._end_time_q[0]
                time_since_finished_ts = now - oldest_finished_ts
                if time_since_finished_ts >= self._period_s:
                    # if either of these fail, the ratelimiter will be marked as broken and all current and future acquires will raise
                    self._end_time_q.popleft()
                    self._sema.release()
                else:
                    # swapping here is ok for same reason as above
                    await asyncio.sleep(self._period_s - time_since_finished_ts)
                    now = time.time()  # we need to update time after we sleep
        except BaseException as e:
            self._logger.exception("Failed while attempting to release semaphores")
            self._release_worker_exception = e  # must set this before we set the event
            self._broken_event.set()
            # NOTE: theoretically we could try to "reset" the limiter after flushing the semas
            raise
        finally:
            self._release_task = None  # only clear this when we're actually exiting

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # NOTE: Even if there's a pending exception we have to assume the call counted.
        # It's important no yields occur because this and _release_worker modify self._end_time_q. We're
        # relying on asyncio behavior of only allowing one task to run at a time as a lock.
        try:
            # If this fails you'll permanently decrease your available max_rate by one, and if max_rate == 1 deadlock
            self._end_time_q.append(time.time())

            if not self._release_task:
                # If there's already a timer we don't need to register a new one because the existing
                # timer will iterate through all the pending events and re-register if they're not yet releasable.
                # If this fails you'll deadlock if your max_rate == 1
                self._release_task = asyncio.ensure_future(self._release_worker(self._period_s))
        except BaseException:
            # narrowed from a bare `except:` (identical semantics, explicit intent)
            self._logger.exception("Error registering rate limiter hit, potential for deadlock!!!")
            raise
def _assertRecursiveAlmostEqual(self, first, second, places=None, msg=None, delta=None):
    """Fail if the two objects are unequal as determined by their
    difference rounded to the given number of decimal places
    (default 7) and comparing to zero, or by comparing that the
    difference between the two objects is more than the given delta.

    Recurses into dicts, lists and tuples, applying `assertAlmostEqual`
    to numeric leaves and `assertEqual` to None/str leaves.  Failures
    deep inside a structure are re-raised wrapped with the offending
    key or index so the mismatch can be located.

    Note that decimal places (from zero) are usually not the same
    as significant digits (measured from the most significant digit).
    If the two objects compare equal then they will automatically
    compare almost equal.
    """
    # Mismatched types fail outright, except any two numbers may be compared.
    if type(first) != type(second) and not (isinstance(first, (float, int, complex)) and isinstance(second, (float, int, complex))):
        return self.assertEqual(first, second)  # will raise mis-matched types

    if isinstance(first, (type(None), str)):  # BUGFIX: `_none_type` was undefined (NameError)
        self.assertEqual(first, second)
    elif isinstance(first, (float, int, complex)):
        self.assertAlmostEqual(first, second, places, msg, delta)
    elif isinstance(first, dict):
        self.assertEqual(set(first.keys()), set(second.keys()))  # will raise keys don't match
        for f_k, f_v in first.items():
            try:
                self.assertRecursiveAlmostEqual(f_v, second[f_k], places, msg, delta)
            except Exception as e:
                raise Exception("Error with key: {}".format(f_k)) from e
    elif isinstance(first, (list, tuple)):
        if len(first) != len(second):
            self.assertEqual(first, second)  # will raise lists don't have same length
        for idx in range(len(first)):
            try:
                self.assertRecursiveAlmostEqual(first[idx], second[idx], places, msg, delta)
            except Exception as e:
                raise Exception("Error with index: {}".format(idx)) from e
    else:
        # BUGFIX: was `assert False`, which is stripped under `python -O`;
        # raise explicitly so unsupported types always fail loudly
        raise TypeError("unsupported type for recursive comparison: {}".format(type(first)))


# Monkeypatch in method
unittest.TestCase.assertRecursiveAlmostEqual = _assertRecursiveAlmostEqual
class TestRateLimiter(asynctest.TestCase):
    """Timing-based tests for RateLimiter.

    Most tests use a limiter of 3 hits per 2-second window, asserting that
    within-budget acquires return immediately and over-budget acquires block
    for roughly the expected time (0.1 s tolerance).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        logging.basicConfig(level=logging.INFO)
        self._logger = logging.getLogger(self.__class__.__name__)
        # limiter under test; assigned by each test method, drained in tearDown
        self._rl = None

    async def validate_elapsed(self, coro, num_seconds: float):
        """Await `coro` and assert it took approximately `num_seconds`.

        NOTE(review): the wait_for timeout equals the expected duration, so a
        pass relies on the 0.1 s delta of the assertAlmostEqual below.
        """
        start_s = time.time()
        try:
            await asyncio.wait_for(coro, num_seconds)
        except:
            print("Error elapsed: {}".format(time.time() - start_s))
            raise

        elapsed_s = time.time() - start_s
        self.assertAlmostEqual(num_seconds, elapsed_s, delta=0.1)

    async def tearDown(self):
        # drain the limiter's background tasks so they don't leak between tests
        await self._rl.join()

    @staticmethod
    async def acquire(rl2, sleep_s=0):
        """Enter limiter `rl2`, hold it for `sleep_s`, return how long the acquire blocked."""
        start = time.time()
        async with rl2:
            wait_s = time.time() - start
            await asyncio.sleep(sleep_s)
            return wait_s

    async def test_rate_limiter1(self):
        # test sequential
        self._rl = rl = RateLimiter(3, 2, self._logger)
        await asyncio.wait_for(self.acquire(rl), 0.01)  # 1
        await asyncio.wait_for(self.acquire(rl), 0.01)  # 2
        await asyncio.wait_for(self.acquire(rl), 0.01)  # 3

        fut = asyncio.Task(self.acquire(rl))  # 4
        with self.assertRaises(asyncio.TimeoutError):
            # shield so the timeout doesn't cancel the acquire itself
            await asyncio.wait_for(asyncio.shield(fut), 0.1)  # 4 (fail)

        await self.validate_elapsed(fut, 1.95)  # 4 (complete)
        await asyncio.wait_for(self.acquire(rl), 0.01)  # 5
        await asyncio.wait_for(self.acquire(rl), 0.01)  # 6

        fut = asyncio.Task(self.acquire(rl))  # 7
        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(fut), 0.01)  # 7 (fail)

        await self.validate_elapsed(fut, 2)  # 7 (complete)

    async def test_rate_limiter2(self):
        # test parallel
        self._rl = rl = RateLimiter(3, 2, self._logger)
        await asyncio.gather(*[asyncio.wait_for(self.acquire(rl), 0.1) for _ in range(3)])  # 1, 2, 3

        fut = asyncio.Task(self.acquire(rl))  # 4

        # this one should timeout but count
        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(fut), 0.1)  # 4 (fail)

        # these should take 2 seconds and pass
        await asyncio.gather(self.validate_elapsed(fut, 1.95), *[self.validate_elapsed(self.acquire(rl), 1.95) for _ in range(2)])

    async def test_rate_limiter3(self):
        # a fully-elapsed window (1 hit / 1 s) frees capacity for an immediate acquire
        self._rl = rl = RateLimiter(1, 1, self._logger)
        await asyncio.wait_for(self.acquire(rl), 0.01)
        await asyncio.sleep(1.1)
        await asyncio.wait_for(self.acquire(rl), 0.01)

    async def test_rate_limiter4(self):
        # window is measured from __aexit__ (hold times of 1, 2, 3 s stagger the releases)
        self._rl = rl = RateLimiter(3, 2, self._logger)
        # these won't register the call as having happened 1, 2, 3 seconds after
        await asyncio.gather(*[self.validate_elapsed(self.acquire(rl, i), i + 0.1) for i in range(1, 4)])

        # we've waited 3s, so 1s one is released, 2nd has 1s wait, 3rd has 2s wait
        times = sorted(await asyncio.gather(*[asyncio.wait_for(self.acquire(rl), 3) for _ in range(3)]))
        self.assertRecursiveAlmostEqual(times, [0, 1, 2], delta=0.1)

    async def test_rate_limiter5(self):
        # like test 4 but the staggered holders run in background tasks
        self._rl = rl = RateLimiter(3, 2, self._logger)
        for i in range(1, 4):
            asyncio.ensure_future(self.acquire(rl, i))

        # in this scenario we'll have to wait 1 + 2, 2 + 2, 3 + 2 seconds
        times = sorted(await asyncio.gather(*[asyncio.wait_for(self.acquire(rl), 5.1) for _ in range(3)]))
        self.assertRecursiveAlmostEqual(times, [3, 4, 5], delta=0.1)

    # TODO: add test where we break _release_worker
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment