DiskCache sqlerror when throttled / memoized
dc_utils.py
##################################################################
### Overriding DiskCache's decorators to make them pickle-able ###
##################################################################
import diskcache as dc
from functools import update_wrapper
import time

ENOVAL = dc.core.ENOVAL


class Memoized:
    NOTICE = (
        'NOTE: This function is memoised.'
        '\n The results for a given combination of parameters are kept'
        '\n for {expire} seconds in the cache at {cache_dest}'
    )

    def __init__(self, func, cache, name=None, typed=False, expire=None, tag=None) -> None:
        """Memoizing cache decorator.

        Decorator to wrap a callable with a memoizing function using the cache.
        Repeated calls with the same arguments look up the result in the cache
        and avoid function evaluation.

        This is the class version of [diskcache's own implementation](#1).
        The class version is required for a pickle-compatible callable.

        #1: https://github.com/grantjenks/python-diskcache/blob/c4ba1f78bb8494bcf6aba9d7d1c3aa49a1093508/diskcache/core.py#L1782
        """
        self.base = dc.core.full_name(func) if name is None else name
        self.base = (self.base,)  # 1-tuple due to use in args_to_key
        self.func = func
        # update name, qualname, etc., but do not copy __dict__! https://stackoverflow.com/q/53602825
        update_wrapper(self, func, updated=())
        notice = self.NOTICE.format(expire=expire, cache_dest=cache.directory)
        if self.__doc__:
            self.__doc__ += '\n\n' + notice
        else:
            self.__doc__ = notice
        self.cache = cache
        self.typed = typed
        self.expire = expire
        self.tag = tag

    def __cache_key__(self, *args, **kwargs):
        "Make key for cache given function arguments."
        return dc.core.args_to_key(self.base, args, kwargs, self.typed)

    def __call__(self, *args, **kwargs):
        "Wrapper for callable to cache arguments and return values."
        key = self.__cache_key__(*args, **kwargs)
        result = self.cache.get(key, default=ENOVAL, retry=True)
        if result is ENOVAL:
            result = self.func(*args, **kwargs)
            if self.expire is None or self.expire > 0:
                self.cache.set(key, result, self.expire, tag=self.tag, retry=True)
        return result


class Throttled:
    NOTICE = (
        'NOTE: This function is throttled!'
        '\n At most {count} calls can be made every {seconds} seconds.'
        '\n It synchronises using the cache at {cache_dest}.'
    )

    def __init__(self, func, cache, count, seconds, name=None, expire=None, tag=None,
                 time_func=time.time, sleep_func=time.sleep):
        """Decorator to throttle calls to a function.

        This is the exact implementation of DiskCache's own [`throttle()`](#1),
        but as a class, which is required for a pickle-compatible decorator.
        The only modification is that we don't allow starting at twice the rate (#2).

        #1: https://github.com/grantjenks/python-diskcache/blob/c4ba1f78bb8494bcf6aba9d7d1c3aa49a1093508/diskcache/recipes.py#L227
        #2: https://stackoverflow.com/a/6415181/786559
        """
        self.cache = cache
        self.count = count
        self.seconds = seconds
        self.expire = expire
        self.tag = tag
        self.time_func = time_func
        self.sleep_func = sleep_func
        self.func = func
        # update name, qualname, etc., but do not copy __dict__! https://stackoverflow.com/q/53602825
        update_wrapper(self, func, updated=())
        notice = self.NOTICE.format(count=count, seconds=seconds, cache_dest=cache.directory)
        if self.__doc__:
            self.__doc__ += '\n\n' + notice
        else:
            self.__doc__ = notice
        self.rate = self.count / float(self.seconds)
        self.key = dc.core.full_name(func) if name is None else name
        self.now = self.time_func()
        # in the original code, there is an allowance of `count` calls to start with,
        # but this results in double the calls for the first interval;
        # we set it to zero so that it is initialised with the time of the first call
        self.cache.set(self.key, (self.now, 0), expire=self.expire, tag=self.tag, retry=True)

    def __call__(self, *args, **kwargs):
        while True:
            with self.cache.transact(retry=True):
                last, tally = self.cache.get(self.key)
                now = self.time_func()
                tally += (now - last) * self.rate
                delay = 0

                if tally > self.count:
                    self.cache.set(self.key, (now, self.count - 1), self.expire)
                elif tally >= 1:
                    self.cache.set(self.key, (now, tally - 1), self.expire)
                else:
                    delay = (1 - tally) / self.rate

            if delay:
                self.sleep_func(delay)
            else:
                break

        return self.func(*args, **kwargs)
##################################################################
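A minimal usage sketch (not part of the gist; `double` and the cache path are illustrative) of what the class versions buy you: the wrapper survives a pickle round-trip, which the closure returned by diskcache's own `memoize()` decorator would not. It relies on `diskcache.Cache` pickling by its directory:

import pickle
import diskcache as dc
from dc_utils import Memoized

def double(x):
    return 2 * x

cache = dc.Cache('/tmp/demo-cache')
memo_double = Memoized(double, cache, expire=60)

assert memo_double(21) == 42                     # first call evaluates and stores
assert memo_double(21) == 42                     # second call is served from the cache
clone = pickle.loads(pickle.dumps(memo_double))  # works: an instance, not a closure
assert clone(21) == 42                           # the clone hits the same cache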
dummy_api.py
import diskcache as dc
import time
import random

from dc_utils import Throttled, Memoized


class API:
    def __init__(self, cache_dir, expire=7*24*60*60) -> None:
        self.expire = expire
        self.cache = dc.Cache(cache_dir)
        self.__setup_memoization_throttling()

    def __setup_memoization_throttling(self):
        """
        The following lines are equivalent to these decorators:

            @memoized(expire=0)
            @throttled(count=1, seconds=2)
            @memoized(expire=duration)
            def _check_online(self):
        """
        # inner memoization -- read/write
        check_memoized_rw = Memoized(self._check_online, self.cache, expire=self.expire)

        # throttle the API -- the API spec allows at most 30 calls per 60 s,
        # but it averages in some unusual way, and large bursts result in
        # "API Limit Exceeded" errors; 1 call / 2 s always respects that limit.
        # Since we are mocking the API here, we can increase the rate -- the error persists.
        check_30per_min = Throttled(check_memoized_rw, self.cache, count=20, seconds=2)

        # outer memoization -- read-only
        check_memoized_ro = Memoized(check_30per_min, self.cache, expire=0)

        self._check_online = check_memoized_ro

    def _check_online(self, n):
        # here we would query the API using requests
        time.sleep(random.random())
        return n

    def get_n(self, n):
        """Retrieve data associated with a given SIREN number.

        These calls are memoized and throttled! See `__init__`.
        """
        # we need this indirection to avoid dead-locking when used in parallel
        return self._check_online(n)
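A note on why the outer `Memoized(expire=0)` layer is read-only yet still effective: `Memoized.__call__` (in dc_utils above) only calls `cache.set` when `expire is None or expire > 0`, and since `dc.core.full_name` builds the key from `__module__` and `__qualname__`, which `update_wrapper` propagates through every layer, the outer layer looks up the very entries the inner read/write layer stores. A sketch (the cache path is illustrative):

from dummy_api import API

api = API('./api-cache')
api.get_n(7)  # outer miss -> throttle -> inner miss -> real call; inner layer writes
api.get_n(7)  # outer layer hits the shared key directly and skips the throttle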
test script (filename not shown)
from dummy_api import API
from concurrent.futures import ProcessPoolExecutor
import sys

import numpy as np
from random import choices

mapping = None

def init_worker():
    global mapping
    mapping = {'my_api': API('./api-cache')}

def check(n, api_name):
    api = mapping[api_name]
    return (api.get_n(n), api)

def run_serial(items):
    init_worker()
    check_names = ['my_api'] * len(items)
    results = list(map(check, items, check_names))
    return results

def run_parallel(num_threads, items):
    check_names = ['my_api'] * len(items)
    with ProcessPoolExecutor(num_threads, initializer=init_worker) as pool:
        results = pool.map(check, items, check_names)
    return results

if __name__ == "__main__":
    threads = int(sys.argv[1])
    nums = 1800
    tests = 60000
    ws = np.abs(np.random.normal(0.5, 0.3, size=nums))  # approx. normally distributed item weights
    items = choices(range(nums), k=tests, weights=ws)
    if threads == 1:
        results = run_serial(items)
    else:
        results = run_parallel(threads, items)
    if len(sys.argv) > 2:
        # use the returned `api` from the main thread
        with open('/tmp/dc_results.out', 'w') as f:
            for result, api in results:
                print(result, file=f)
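To reproduce: the driver takes the worker count as its first argument and dumps results when given any second argument, e.g. `python repro.py 1` for the serial run and `python repro.py 8 dump` for a parallel one (the script's filename isn't shown in the gist; `repro.py` is a placeholder).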
Thanks!
Indeed, it is very different! I had removed too many comments when stripping down the example 😅. Although the specification is 30 calls / min, it errors out on bursts, so I had to reduce it to 1 call every 2 s.
I've now increased it to 20 calls / 2 s; this still errors for me, while filling the cache much faster.
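For context on the burst behaviour, a back-of-the-envelope sketch of the token-bucket arithmetic in `Throttled.__call__`, comparing the original recipe's full-bucket start (#2 in dc_utils) with this gist's empty-bucket start, at the API's nominal 30 calls / 60 s:

count, seconds = 30, 60
rate = count / seconds                 # 0.5 tokens per second

# original recipe: the bucket starts full, so `count` calls pass immediately,
# plus another rate * seconds refill during the first window
first_window_original = count + rate * seconds   # ~60 calls in the first minute

# this gist: the bucket starts empty, so only the refill is available
first_window_gist = 0 + rate * seconds           # ~30 calls, matching the spec

print(first_window_original, first_window_gist)  # 60.0 30.0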