Last active
December 29, 2021 16:02
-
-
Save peterbe/fd6ffc23325df849b27c549e769ce570 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import hashlib | |
from functools import wraps | |
from django.core.cache import cache | |
from django.utils.encoding import force_text, force_bytes | |
def cache_memoize(
    timeout,
    prefix='',
    args_rewrite=None,
    hit_callable=None,
    miss_callable=None,
    store_result=True,
):
    """Decorator for memoizing function calls in the Django cache.

    :arg int timeout: Number of seconds to store the result for. Passed
        as-is to ``cache.set()``.
    :arg string prefix: Namespace for this function's cache keys. When empty
        (the default), the decorated function's module and qualified name are
        used instead. This stops two different memoized functions that are
        called with equal arguments from sharing cache entries. Pass the same
        explicit prefix to several functions only if you really want them to
        share entries.
    :arg function args_rewrite: Callable that rewrites the positional args
        before they go into the cache key. Useful if your function needs
        nontrivial types but you know a simple way to re-represent them for
        the sake of the cache key. Keyword arguments are not passed through
        it.
    :arg function hit_callable: Called with the call's ``*args, **kwargs``
        when the key was found in the cache.
    :arg function miss_callable: Called with the call's ``*args, **kwargs``
        when the key was *not* in the cache, after the function has run.
    :arg bool store_result: If you know the result is not important, just
        that the cache blocked it from running repeatedly, set this to False.
        ``True`` is then stored instead of the result, so calls answered from
        the cache return ``True`` rather than the original return value.

    A result of ``None`` is never stored, because ``None`` is also what a
    cache miss looks like. Calls that return ``None`` therefore always
    re-run the function, unless ``store_result=False``.

    Arguments enter the cache key through their string form, so they need a
    stable ``str()`` (or an ``args_rewrite`` that provides one).

    Usage::

        import binascii

        @cache_memoize(
            300,  # 5 min
            args_rewrite=lambda user: user.email,
            hit_callable=lambda user: print("Cache hit!"),
            miss_callable=lambda user: print("Cache miss :("),
        )
        def hash_user_email(user):
            dk = hashlib.pbkdf2_hmac(
                'sha256', force_bytes(user.email), b'salt', 100000
            )
            return binascii.hexlify(dk)

    Or, when you don't actually need the result, useful if you know it's not
    valuable to store the execution result::

        @cache_memoize(
            300,  # 5 min
            store_result=False,
        )
        def send_email(email):
            somelib.send(email, subject="You rock!", ...)

    Whatever gets cached can also be undone. For example::

        @cache_memoize(100)
        def callmeonce(arg1):
            print(arg1)

        callmeonce('peter')  # will print 'peter'
        callmeonce('peter')  # nothing printed
        callmeonce.invalidate('peter')
        callmeonce('peter')  # will print 'peter'

    To bypass the cache for a single call, pass the extra keyword argument
    ``_refresh=True``. The function is then executed even on a hit, and its
    result is written back to the cache. The argument is consumed by the
    decorator and never reaches the wrapped function. For example::

        @cache_memoize(100)
        def callmeonce(arg1):
            print(arg1)

        callmeonce('peter')  # will print 'peter'
        callmeonce('peter')  # nothing printed
        callmeonce('peter', _refresh=True)  # will print 'peter'
    """
    import json

    if args_rewrite is None:
        def noop(*args):
            return args
        args_rewrite = noop

    def decorator(func):
        # Without an explicit prefix, namespace the keys by the function
        # itself; otherwise every un-prefixed memoized function would share
        # one key space. getattr() keeps odd callables (e.g. partials, which
        # lack __qualname__) decoratable.
        namespace = prefix or '{}.{}'.format(
            getattr(func, '__module__', ''),
            getattr(func, '__qualname__', getattr(func, '__name__', '')),
        )

        def _make_cache_key(*args, **kwargs):
            # JSON encoding is unambiguous, unlike joining with ':'. That
            # keeps ('a:b',) and ('a', 'b') apart, and a positional 'k=v'
            # string apart from the keyword k=v. sort_keys makes the key
            # independent of the order keyword arguments were passed in.
            # NOTE(review): force_text is deprecated in Django 3.0 and
            # removed in 4.0; switch to force_str when upgrading.
            payload = json.dumps(
                [
                    namespace,
                    [force_text(x) for x in args_rewrite(*args)],
                    {k: force_text(v) for k, v in kwargs.items()},
                ],
                sort_keys=True,
            )
            return hashlib.md5(
                force_bytes('cache_memoize' + payload)
            ).hexdigest()

        @wraps(func)
        def inner(*args, **kwargs):
            # Pop before building the key so _refresh never affects it and
            # is never forwarded to func.
            refresh = kwargs.pop('_refresh', False)
            cache_key = _make_cache_key(*args, **kwargs)
            result = None if refresh else cache.get(cache_key)
            if result is None:
                result = func(*args, **kwargs)
                if not store_result:
                    # Then the result isn't valuable/important to store but
                    # we want to store something, just to remember that it
                    # has been done.
                    cache.set(cache_key, True, timeout)
                elif result is not None:
                    cache.set(cache_key, result, timeout)
                if miss_callable:
                    miss_callable(*args, **kwargs)
            elif hit_callable:
                hit_callable(*args, **kwargs)
            return result

        def invalidate(*args, **kwargs):
            """Delete the cached entry for these exact arguments."""
            cache_key = _make_cache_key(*args, **kwargs)
            cache.delete(cache_key)

        inner.invalidate = invalidate
        return inner

    return decorator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_cache_memoize():
    """Equal arguments only execute the memoized function once."""
    calls = []

    @decorators.cache_memoize(10)
    def runmeonce(a, b, k='bla'):
        calls.append((a, b, k))
        return '{} {} {}'.format(a, b, k)  # sample implementation

    # The second identical call is served from the cache.
    for _ in range(2):
        runmeonce(1, 2)
    assert len(calls) == 1

    # Different arguments mean a different cache entry.
    runmeonce(1, 3)
    assert len(calls) == 2

    # Works with most basic types...
    for _ in range(2):
        runmeonce(1.1, 'foo')
    assert len(calls) == 3

    # ...and with more "advanced" ones passed by keyword.
    for _ in range(2):
        runmeonce(1.1, 'foo', k=list('åäö'))
    assert len(calls) == 4

    # Really long arguments are no problem either.
    long_mapping = {'C' * 100: 'D' * 100}
    runmeonce('A' * 200, 'B' * 200, long_mapping)
    assert len(calls) == 5

    # The same argument under two different prefixes is two cache entries.
    for prefix in ('first', 'second'):
        @decorators.cache_memoize(10, prefix=prefix)
        def ho(value):
            calls.append(value)
            return 'ho'

        ho('hey')
    assert len(calls) == 7

    # Even when the result doesn't matter, the call is still made only once.
    @decorators.cache_memoize(10, store_result=False, prefix='different')
    def returnnothing(a, b, k='bla'):
        calls.append((a, b, k))
        # note it returns None

    for _ in range(2):
        returnnothing(1, 2)
    assert len(calls) == 8
def test_cache_memoize_refresh():
    """``_refresh=True`` skips the cache read, and ``invalidate()`` drops it."""
    calls_made = []

    @decorators.cache_memoize(10)
    def runmeonce(a):
        calls_made.append(a)
        return a * 2

    assert runmeonce(10) == 20
    assert len(calls_made) == 1
    assert runmeonce(10) == 20
    assert len(calls_made) == 1

    # _refresh forces execution and is not forwarded to the function
    # (runmeonce would raise TypeError if it were).
    assert runmeonce(10, _refresh=True) == 20
    assert len(calls_made) == 2

    # The refreshed result was written back, so this is a cache hit again.
    assert runmeonce(10) == 20
    assert len(calls_made) == 2

    # invalidate() removes the entry, so the next call executes once more.
    runmeonce.invalidate(10)
    assert runmeonce(10) == 20
    assert len(calls_made) == 3
def test_cache_memoize_hit_miss_callables():
    """hit_callable / miss_callable fire exactly on cache hits / misses."""
    hits, misses, calls_made = [], [], []

    @decorators.cache_memoize(
        10,
        hit_callable=hits.append,
        miss_callable=misses.append,
    )
    def runmeonce(arg):
        calls_made.append(arg)
        return arg * 2

    # Each row: argument, expected result, and then the expected totals of
    # real calls, hits and misses after making that call.
    expectations = [
        (100, 200, 1, 0, 1),
        (100, 200, 1, 1, 1),
        (100, 200, 1, 2, 1),
        (200, 400, 2, 2, 2),
    ]
    for arg, expected, n_calls, n_hits, n_misses in expectations:
        assert runmeonce(arg) == expected
        assert len(calls_made) == n_calls
        assert len(hits) == n_hits
        assert len(misses) == n_misses
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment