Skip to content

Instantly share code, notes, and snippets.

@thehesiod
Last active September 14, 2018 07:17
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thehesiod/407670e6d7b6883d3d416d1f649616de to your computer and use it in GitHub Desktop.
asyncio sliding window rate limiter
import asyncio
import logging
import time
import unittest
from collections import deque

import asynctest
class RateLimiter:
    """Sliding-window rate limiter: allows at most `max_rate` concurrent/recent
    hits per `period_s` seconds.

    Use as an async context manager around each rate-limited operation. A slot
    is consumed on `__aenter__` and becomes reusable `period_s` seconds after
    the corresponding `__aexit__`.
    """

    class Error(Exception):
        """Raised by `__aenter__` once the limiter is broken (see `is_broken`)."""
        pass

    def __init__(self, max_rate: int, period_s: float, logger: logging.Logger):
        """
        Allows `max_rate` per `period_s`.
        :param max_rate: number of hits allowed per `period_s`
        :param period_s: period in seconds
        :param logger: logger to use
        """
        assert isinstance(max_rate, int) and max_rate > 0
        assert period_s > 0

        self._max_rate = max_rate
        self._period_s = period_s
        self._loop = asyncio.get_event_loop()
        self._logger = logger
        self._release_task = None  # background task draining `_end_time_q`
        self._release_worker_exception = None
        self._broken_event = asyncio.Event()  # will get set if limiter is broken
        self._broken_evt_wait_fut = asyncio.ensure_future(self._broken_event.wait())
        self._waiters = 0

        # We'll initially allow `max_rate` to happen in parallel, and then release
        # the semaphores as new tasks can be started
        self._sema = asyncio.Semaphore(max_rate)

        # we'll push the task end-time to this queue during `__aexit__`
        self._end_time_q = deque()

    @property
    def is_broken(self):
        # True once `_release_worker` has failed; all subsequent acquires raise.
        return self._broken_event.is_set()

    async def join(self):
        """ Will wait until all waiters have finished """
        while self._waiters:
            await asyncio.sleep(1)

        if not self._broken_evt_wait_fut.done():
            # normal operation: let the release worker drain, then drop our
            # background wait on the broken event
            if self._release_task and not self._release_task.done():
                await self._release_task
            self._broken_evt_wait_fut.cancel()
        else:
            # exceptional operation (limiter broke).
            # BUGFIX: `_release_worker` clears `self._release_task` in its
            # `finally`, so it may already be None here; the original
            # unconditionally called `.done()` and could raise AttributeError.
            if self._release_task and not self._release_task.done():
                self._release_task.cancel()

    async def __aenter__(self):
        self._waiters += 1
        try:
            # Wait on which happens first: we acquire the semaphore or the rate-limiter breaks
            with CancellingTaskCtx(self._sema.acquire()) as acquire_fut:
                await asyncio.wait((self._broken_evt_wait_fut, acquire_fut), return_when=asyncio.FIRST_COMPLETED)
                if self._broken_evt_wait_fut.done():
                    raise self.Error("Error while acquiring semaphore") from self._release_worker_exception
        finally:
            self._waiters -= 1

    async def _release_worker(self, sleep_s):
        """Background task: releases one semaphore slot per queued end-time once
        that end-time is at least `period_s` old; exits when the queue drains.

        :param sleep_s: initial delay before the first release pass
        """
        try:
            # Yielding here is safe: `__aexit__` mutates `_end_time_q` without
            # awaiting, and will not start a second worker while
            # `_release_task` is still set.
            await asyncio.sleep(sleep_s)

            now = time.time()  # cache as this call is not cheap
            # Here we'll release each semaphore that expired its period from when it finished.
            # We loop (rather than keeping one timer per hit) as an optimization: the timer
            # may fire later than requested, so multiple entries may be releasable at once.
            while len(self._end_time_q):
                oldest_finished_ts = self._end_time_q[0]
                time_since_finished_ts = now - oldest_finished_ts
                if time_since_finished_ts >= self._period_s:
                    # if either of these fail, the ratelimiter will be marked as broken and all current and future acquires will raise
                    self._end_time_q.popleft()
                    self._sema.release()
                else:
                    # oldest entry not yet expired: sleep until it is
                    # (yielding here is ok for the same reason as above)
                    await asyncio.sleep(self._period_s - time_since_finished_ts)
                    now = time.time()  # we need to update time after we sleep
        except BaseException as e:
            self._logger.exception("Failed while attempting to release semaphores")
            self._release_worker_exception = e  # must set this before we set the event
            self._broken_event.set()
            # NOTE: theoretically we could try to "reset" the limiter after flushing the semas
            raise
        finally:
            self._release_task = None  # only clear this when we're actually exiting

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # NOTE: Even if there's a pending exception we have to assume the call counted.
        # It's important no awaits occur here because this and _release_worker both
        # modify self._end_time_q. We're relying on asyncio's behavior of only
        # running one task at a time between yield points as an implicit lock.
        try:
            # If this fails you'll permanently decrease your available max_rate by one, and if max_rate == 1 deadlock
            self._end_time_q.append(time.time())
            if not self._release_task:
                # If there's already a timer we don't need to register a new one because the existing
                # timer will iterate through all the pending events and re-register if they're not yet releasable.
                # If this fails you'll deadlock if your max_rate == 1
                self._release_task = asyncio.ensure_future(self._release_worker(self._period_s))
        except BaseException:  # was a bare `except:`; behavior identical, intent explicit
            self._logger.exception("Error registering rate limiter hit, potential for deadlock!!!")
            raise
def _assertRecursiveAlmostEqual(self, first, second, places=None, msg=None, delta=None):
    """Fail if the two objects are unequal as determined by their
    difference rounded to the given number of decimal places
    (default 7) and comparing to zero, or by comparing that the
    difference between the two objects is more than the given delta.

    Containers (dict/list/tuple) are compared recursively; numbers are
    compared with `assertAlmostEqual`, everything else with `assertEqual`.

    Note that decimal places (from zero) are usually not the same
    as significant digits (measured from the most significant digit).
    If the two objects compare equal then they will automatically
    compare almost equal.
    """
    if type(first) != type(second) and not (isinstance(first, (float, int, complex)) and isinstance(second, (float, int, complex))):
        return self.assertEqual(first, second)  # will raise mis-matched types
    # BUGFIX: original referenced `_none_type`, which is undefined in this file
    # (NameError at runtime); `type(None)` is the intended type.
    if isinstance(first, (type(None), str)):
        self.assertEqual(first, second)
    elif isinstance(first, (float, int, complex)):
        self.assertAlmostEqual(first, second, places, msg, delta)
    elif isinstance(first, dict):
        self.assertEqual(set(first.keys()), set(second.keys()))  # will raise keys don't match
        for f_k, f_v in first.items():
            try:
                self.assertRecursiveAlmostEqual(f_v, second[f_k], places, msg, delta)
            except Exception as e:
                # chain so the failure message pinpoints the offending key
                raise Exception("Error with key: {}".format(f_k)) from e
    elif isinstance(first, (list, tuple)):
        if len(first) != len(second):
            self.assertEqual(first, second)  # will raise lists don't have same length
        for idx in range(len(first)):
            try:
                self.assertRecursiveAlmostEqual(first[idx], second[idx], places, msg, delta)
            except Exception as e:
                # chain so the failure message pinpoints the offending index
                raise Exception("Error with index: {}".format(idx)) from e
    else:
        assert False  # unsupported type


# Monkeypatch in method
unittest.TestCase.assertRecursiveAlmostEqual = _assertRecursiveAlmostEqual
class TestRateLimiter(asynctest.TestCase):
    """Wall-clock timing tests for RateLimiter.

    NOTE(review): assertions are timing-based with a 0.1s tolerance, so they
    are sensitive to machine load.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        logging.basicConfig(level=logging.INFO)
        self._logger = logging.getLogger(self.__class__.__name__)
        self._rl = None  # set by each test; joined in tearDown

    async def validate_elapsed(self, coro, num_seconds: float):
        # Awaits `coro` with a timeout of `num_seconds` and asserts it took
        # approximately that long (within 0.1s) to complete.
        start_s = time.time()
        try:
            await asyncio.wait_for(coro, num_seconds)
        except:
            print("Error elapsed: {}".format(time.time() - start_s))
            raise
        elapsed_s = time.time() - start_s
        self.assertAlmostEqual(num_seconds, elapsed_s, delta=0.1)

    async def tearDown(self):
        # ensure the limiter's background worker has fully drained
        await self._rl.join()

    @staticmethod
    async def acquire(rl2, sleep_s=0):
        # Enters the limiter, holds the slot for `sleep_s` seconds, and
        # returns how long the acquisition itself took.
        start = time.time()
        async with rl2:
            wait_s = time.time() - start
            await asyncio.sleep(sleep_s)
        return wait_s

    async def test_rate_limiter1(self):
        # test sequential: 3 immediate acquires succeed, the 4th must wait
        # out the 2s window
        self._rl = rl = RateLimiter(3, 2, self._logger)
        await asyncio.wait_for(self.acquire(rl), 0.01)  # 1
        await asyncio.wait_for(self.acquire(rl), 0.01)  # 2
        await asyncio.wait_for(self.acquire(rl), 0.01)  # 3

        fut = asyncio.Task(self.acquire(rl))  # 4
        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(fut), 0.1)  # 4 (fail)
        await self.validate_elapsed(fut, 1.95)  # 4 (complete)

        await asyncio.wait_for(self.acquire(rl), 0.01)  # 5
        await asyncio.wait_for(self.acquire(rl), 0.01)  # 6
        fut = asyncio.Task(self.acquire(rl))  # 7
        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(fut), 0.01)  # 7 (fail)
        await self.validate_elapsed(fut, 2)  # 7 (complete)

    async def test_rate_limiter2(self):
        # test parallel: 3 simultaneous acquires succeed, the 4th times out
        # but still counts toward the window
        self._rl = rl = RateLimiter(3, 2, self._logger)
        await asyncio.gather(*[asyncio.wait_for(self.acquire(rl), 0.1) for _ in range(3)])  # 1, 2, 3
        fut = asyncio.Task(self.acquire(rl))  # 4

        # this one should timeout but count
        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(fut), 0.1)  # 4 (fail)

        # these should take 2 seconds and pass
        await asyncio.gather(self.validate_elapsed(fut, 1.95), *[self.validate_elapsed(self.acquire(rl), 1.95) for _ in range(2)])

    async def test_rate_limiter3(self):
        # after waiting out a full period, a slot is available again immediately
        self._rl = rl = RateLimiter(1, 1, self._logger)
        await asyncio.wait_for(self.acquire(rl), 0.01)
        await asyncio.sleep(1.1)
        await asyncio.wait_for(self.acquire(rl), 0.01)

    async def test_rate_limiter4(self):
        # slots free up relative to each holder's __aexit__ time, not entry time
        self._rl = rl = RateLimiter(3, 2, self._logger)
        # these won't register the call as having happened 1, 2, 3 seconds after
        await asyncio.gather(*[self.validate_elapsed(self.acquire(rl, i), i + 0.1) for i in range(1, 4)])
        # we've waited 3s, so 1s one is released, 2nd has 1s wait, 3rd has 2s wait
        times = sorted(await asyncio.gather(*[asyncio.wait_for(self.acquire(rl), 3) for _ in range(3)]))
        self.assertRecursiveAlmostEqual(times, [0, 1, 2], delta=0.1)

    async def test_rate_limiter5(self):
        # holders launched as background tasks still gate later acquirers
        self._rl = rl = RateLimiter(3, 2, self._logger)
        for i in range(1, 4):
            asyncio.ensure_future(self.acquire(rl, i))
        # in this scenario we'll have to wait 1 + 2, 2 + 2, 3 + 2 seconds
        times = sorted(await asyncio.gather(*[asyncio.wait_for(self.acquire(rl), 5.1) for _ in range(3)]))
        self.assertRecursiveAlmostEqual(times, [3, 4, 5], delta=0.1)

    # TODO: add test where we break _release_worker
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment