Created
September 14, 2020 22:02
-
-
Save imposeren/9a419b9e4321b372e16e7753ed288f65 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: window_rate_limit=500, time_window_seconds=3600, cooldown_steps=4
# (the rate is cooled down 4 times during time_window_seconds).
#
# Stored per key:
# * key=api_key -- self-explanatory.
# * last_timestamp -- timestamp of the last request (self-explanatory).
# * last_cooldown_timestamp -- last time the average rate was cooled down.
# * current_rate -- current average request rate.
# * cooldown_amounts -- tracks how much cooling down is still needed.
#   Defaults to an empty list: `[]`.
# * rate_limit_exceeded -- whether current_rate is above window_rate_limit.
#
# On receiving a request: see `Backend.update_rate`.
import datetime | |
import unittest | |
from collections import deque | |
from copy import deepcopy | |
from functools import partial | |
class Dummy:
    """Minimal attribute bag used as a stand-in for real storage records.

    Instances accept arbitrary extra attributes assigned after construction,
    and may optionally carry a ``name`` and a ``value`` that show up in
    ``repr()``.
    """

    # NOTE: for `x+=1` operations it's better to use an atomic storage
    # operation, so a real class for data representation should probably
    # override the __iadd__ method, or this class should do everything in a
    # transaction with locks.

    def __init__(self, name=None, value=None):
        self.value = value
        self.name = name

    def __repr__(self):
        # Read the instance attributes (bare ``name``/``value`` are not in
        # scope here) and compare against None rather than relying on
        # truthiness, so falsy payloads such as 0 or '' are still shown.
        has_name = self.name is not None
        has_value = self.value is not None
        if has_name and has_value:
            return f'<Dummy {self.name}({self.value!r})>'
        if has_name:
            return f'<Dummy {self.name} object>'
        if has_value:
            return f'<Dummy object with value={self.value!r}>'
        return f'<Dummy object with id={id(self)}>'
class ClassAttr(type):
    """Metaclass that synthesizes the class attribute ``Value`` on demand.

    Python only calls a metaclass ``__getattr__`` after normal lookup on the
    class has failed, so real class attributes are unaffected.
    """

    def __getattr__(cls, name):
        # Anything other than ``Value`` is genuinely missing.
        if name != 'Value':
            raise AttributeError(name)
        # ``Value`` is a factory: calling it builds a named Dummy record.
        return partial(Dummy, "cooldown_Value")
class CD:
    """Cooldown-speed markers meant for ``Backend(rate_cooldown_speed=...)``.

    Each marker is a named ``Dummy`` instance.
    """

    #: Rate is reduced by requests not fitting in window.
    not_fitting = Dummy('cooldown_not_fitting')

    #: Rate is reduced by requests not fitting in window, but not
    #: faster than `window_rate_limit` per `time_window_seconds`.
    #: NOTE(review): ``Backend.update_rate`` never reads the stored
    #: cooldown speed, so this mode does not appear to be implemented yet.
    not_fitting_limited = Dummy('cooldown_not_fitting_limited')
class Backend:
    """In-memory demo backend for an averaged request-rate limiter.

    It keeps exactly one rate record (``last_data``) and one set of limits
    (``limits_data``). The ``request`` arguments of the accessor methods are
    ignored in this demo. Each ``update_rate`` call counts one request.
    """
    def __init__(self, time_window_seconds=3600, window_rate_limit=500,
            cooldown_steps=4, rate_cooldown_speed=CD.not_fitting):
        """Create an empty rate record and the throttling limits.

        :param time_window_seconds: length of the rate window in seconds.
        :param window_rate_limit: records whose ``current_rate`` is above
            this value get ``rate_limit_exceeded = True`` (``update_data``).
        :param cooldown_steps: number of sub-periods the window is split into;
            ``update_rate`` keeps ``cooldown_steps + 1`` per-period counters.
            Must be non-zero because ``update_rate`` divides by it.
        :param rate_cooldown_speed: stored as ``limits_data.cooldown_speed``.
            NOTE(review): ``update_rate`` never reads it, so every marker
            currently behaves the same.
        """
        # Rate record for the single key tracked by this demo.
        self.last_data = Dummy()
        self.last_data.key = 'ignored in demo'
        self.last_data.last_timestamp = None
        # None marks "no request seen yet"; update_rate anchors the window
        # on the first request.
        self.last_data.last_cooldown_timestamp = None
        self.last_data.current_rate = 0
        self.last_data.cooldown_amounts = []
        self.last_data.rate_limit_exceeded = False
        # Throttling configuration.
        self.limits_data = Dummy()
        self.limits_data.window_rate_limit = window_rate_limit
        self.limits_data.time_window_seconds = time_window_seconds
        self.limits_data.cooldown_steps = cooldown_steps
        self.limits_data.cooldown_speed = rate_cooldown_speed
    def get_last_data(self, request):
        """Return the stored rate record (``request`` is ignored)."""
        return self.last_data
    def get_throttling(self, request):
        """Return the throttling limits (``request`` is ignored)."""
        return self.limits_data
    def copy(self, data_obj):
        """Return a deep copy of ``data_obj``.

        ``update_rate`` mutates the copy, so the stored record is only
        replaced when ``update_data`` is called.
        """
        return deepcopy(data_obj)
    def update_data(self, request, data_obj):
        """Store ``data_obj`` as the current record and flag limit excess.

        The flag is set on the same object that was just stored, so the
        stored record sees it too.
        """
        self.last_data = data_obj
        data_obj.rate_limit_exceeded = (
            data_obj.current_rate > self.limits_data.window_rate_limit
        )
    def update_rate(self, ts_now=None, delay=None):
        """Count one request and return the updated rate record.

        Each request adds 1 to ``current_rate`` and 1 to one counter in
        ``cooldown_amounts``. That attribute is a deque of
        ``cooldown_steps + 1`` counters. The counter is chosen by
        ``window_progress``: the time since ``last_cooldown_timestamp`` as a
        fraction of the window.

        * progress < 1/cooldown_steps -> counter 0 (also negative progress);
        * k/cooldown_steps <= progress < (k+1)/cooldown_steps -> counter k,
          for k = 1 .. cooldown_steps - 1;
        * progress >= 1 -> the last counter.

        When progress exceeds (cooldown_steps + 1)/cooldown_steps, the oldest
        counters are popped and subtracted from ``current_rate`` first (see
        the comments below).

        :param ts_now: ``datetime`` of the request; ``datetime.datetime.now()``
            is used when it is falsy.
        :param delay: optional offset added to ``ts_now``. ``int``/``float``
            values are seconds; anything else (e.g. a ``timedelta``) is added
            as-is. Falsy values are ignored.
        :returns: the new record (a deep copy of the previous one), already
            stored through ``update_data``.
        :raises ZeroDivisionError: if ``cooldown_steps`` is 0.
        """
        if not ts_now:
            ts_now = datetime.datetime.now()
        if delay:
            if isinstance(delay, (int, float)):
                delay = datetime.timedelta(seconds=delay)
            ts_now += delay
        request = None
        limits_data = self.get_throttling(request)
        cd_steps = limits_data.cooldown_steps
        prev_data = self.get_last_data(request)
        new_data = self.copy(prev_data)
        # ALWAYS increment rate by 1 because single request is processed.
        new_data.current_rate += 1
        new_data.last_timestamp = ts_now
        if not prev_data.last_cooldown_timestamp:
            # First request: anchor the window here and start with
            # cd_steps + 1 zeroed counters.
            inspected_period = 0
            window_progress = 0
            new_data.last_cooldown_timestamp = ts_now
            new_data.cooldown_amounts = [0 for __ in range(cd_steps+1)]
        else:
            inspected_period = (ts_now - prev_data.last_cooldown_timestamp).total_seconds()
            window_progress = inspected_period/limits_data.time_window_seconds
        # maxlen keeps the queue at cd_steps + 1 counters across pops/extends.
        new_data.cooldown_amounts = deque(new_data.cooldown_amounts, maxlen=cd_steps+1)
        cooldowns_queue = new_data.cooldown_amounts
        # Perform cooldown (examples below are for cd_steps=4)...
        # * on progress > (1+2*cd_steps)/cd_steps (e.g. >2.25): pop all
        #   cd_steps+1 counters, i.e. reset the cooldowns queue. This sets the
        #   rate to 1 as long as current_rate equals the sum of the counters,
        #   which holds when starting from the state built in __init__.
        # * on progress > (2*cd_steps)/cd_steps (e.g. >2.0): reduce rate by the
        #   values of the first four counters in the queue, shift the queue 4
        #   positions to the left, and append 4 zeros
        # ...
        # * on progress > (cd_steps+1)/cd_steps (e.g. >1.25): reduce rate by
        #   the first counter, shift the queue 1 position, and append 1 zero
        #
        # NOTE: frac_top is a numerator of fraction, and frac_bottom is denominator of
        # fraction. (Avoiding "numerator" because `enumerate` function is used).
        frac_bottom = cd_steps
        if window_progress > (cd_steps+1)/cd_steps:
            # Find the largest number of sub-periods to expire ("overkill")
            # whose threshold (cd_steps + overkill)/cd_steps is exceeded.
            for overkill in range(1+cd_steps, 0, -1):
                frac_top = overkill + cd_steps
                if window_progress > frac_top/frac_bottom:
                    new_data.current_rate -= sum(
                        cooldowns_queue.popleft() for __ in range(overkill)
                    )
                    cooldowns_queue.extend(0 for __ in range(overkill))
                    # Re-anchor the window so this request counts as landing
                    # exactly one window after the anchor (last counter).
                    # NOTE(review): the anchor is derived from ts_now, not
                    # advanced by overkill/cd_steps windows, so counter
                    # boundaries are approximate. The -1 microsecond makes a
                    # later request at the same instant compute progress
                    # slightly below 1.0. The intent is not documented.
                    window_progress = 1.0
                    ts_shift = datetime.timedelta(
                        seconds=window_progress*limits_data.time_window_seconds,
                        microseconds=-1,
                    )
                    new_data.last_cooldown_timestamp = ts_now - ts_shift
                    break
            else:
                # Unreachable: overkill=1 always matches under the outer check.
                raise RuntimeError('Not possible!')
        # For different `window_progress` values, increment different items of
        # `cooldowns_queue`. For example: >=4/4 → last item, >=3/4 → pre-last, ...
        #
        for i, frac_top in enumerate(range(limits_data.cooldown_steps, 0, -1), 1):
            cd_index = -i
            if window_progress >= frac_top/frac_bottom:
                cooldowns_queue[cd_index] += 1
                break
        else:
            # No break → window_progress < 1/cd_steps (e.g. <1/4); this also
            # covers negative progress from out-of-order timestamps.
            cooldowns_queue[0] += 1
        self.update_data(request, new_data)
        #return new_data.__dict__
        return new_data
class TestThrottleIdea(unittest.TestCase):
    """Scenario tests for ``Backend.update_rate``.

    Every test starts from a backend with a 3600 s window, a limit of 500
    requests and 4 cooldown steps, with one request already counted.
    """

    def setUp(self):
        super().setUp()
        self.backend = Backend(
            time_window_seconds=3600,
            window_rate_limit=500,
            cooldown_steps=4,
        )
        # The first request anchors the timeline for the scenarios below.
        self.first_result = self.backend.update_rate()
        self.now = self.first_result.last_timestamp

    def test_simple(self):
        start = self.now
        backend = self.backend
        limits = backend.limits_data
        step = limits.time_window_seconds / limits.window_rate_limit

        # Spread (limit - 5) requests evenly over the first window.
        for idx in range(limits.window_rate_limit - 5):
            offset = idx * step
            result = backend.update_rate(start, offset)
        self.assertEqual(result.current_rate, 500 + 1 - 5)
        self.assertFalse(result.rate_limit_exceeded)
        base_offset = offset

        # One more request a second later keeps us within the limit.
        result = backend.update_rate(start, offset + 1)
        self.assertEqual(result.current_rate, 497)
        self.assertFalse(result.rate_limit_exceeded)

        # A burst of 300 requests within 4 seconds exceeds the limit.
        for idx in range(300):
            offset = base_offset + idx * 4 / 300
            result = backend.update_rate(start, offset)
        self.assertEqual(result.current_rate, 797)
        self.assertTrue(result.rate_limit_exceeded)

        # Past 1.25 windows the oldest sub-period is cooled down.
        result = backend.update_rate(start, int(3600 * 1.26))
        self.assertEqual(result.current_rate, 797 - 125)
        self.assertTrue(result.rate_limit_exceeded)

    def test_multiple_cooldown(self):
        clock = self.now
        backend = self.backend
        cd_period = 3600 / 4
        reqs = (100, 133, 222, 333, 444, 555, 30, 42, 1)
        # Each entry: (time interval, num_requests, expected_rate, exceeded)
        dataset = (
            (cd_period / 2, reqs[0], 1 + reqs[0], False),
            (cd_period / 2 - 1, reqs[1], 1 + sum(reqs[:2]), False),
            (cd_period + 1, reqs[2], 1 + sum(reqs[:3]), False),
            (cd_period, reqs[3], 1 + sum(reqs[:4]), True),
            (cd_period, reqs[4], 1 + sum(reqs[:5]), True),
            (cd_period - 1, reqs[5], 1 + sum(reqs[:6]), True),
            # Cooling down first time
            (cd_period + 1, reqs[6], 1 + sum(reqs[:7]) - sum(reqs[:2]) - 1, True),
            # Cooling down second time
            (cd_period, reqs[7], 1 + sum(reqs[:8]) - sum(reqs[:3]), True),
            # Cool for a long time to reduce the rate by several periods at
            # once; a single request adds "increased complexity".
            # The rate is good again afterwards.
            (cd_period * 3, reqs[8], 1 + sum(reqs[:9]) - sum(reqs[:6]) - 1, False),
        )
        for step_no, (interval, count, expected_rate, exceeded) in enumerate(dataset):
            with self.subTest(step=step_no, interval=interval, num_requests=count):
                gap = interval / count
                for _ in range(count):
                    result = backend.update_rate(clock, gap)
                    expected_ts = clock + datetime.timedelta(seconds=gap)
                    self.assertEqual(result.last_timestamp, expected_ts)
                    clock = result.last_timestamp
                self.assertEqual(result.current_rate, expected_rate)
                self.assertEqual(result.rate_limit_exceeded, exceeded)
if __name__ == '__main__':
    # Only when executed as a script (not on import): run the test suite.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment