Last active: January 15, 2017 16:19
-
-
Save luhn/81e4445582bbaaf9f5e06417ae18c13f to your computer and use it in GitHub Desktop.
Memoize method decorator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
class memoize_method:
    """
    A decorator to memoize a method. The cache is stored as a dictionary
    inside the object.

    This functionality is similar to stdlib's :func:`functools.lru_cache`.
    When using lru_cache on a method, the object will be used as part of the
    cache key. This means a reference to the object will linger until all
    related cache entries have been cleared, which will never happen if the
    cache is unbounded. By storing a cache within each object, this problem
    is avoided.

    Details:

    * The cache for a method named ``name`` lives in the instance attribute
      ``_<name>_cache``. Deleting that attribute clears the cache; it is
      recreated on the next call. The instance must allow setting that
      attribute (e.g. a ``__slots__`` class that omits it will raise
      ``AttributeError``).
    * All arguments must be hashable, since they form a dict key.
    * Keyword order does not matter, but passing a value positionally in one
      call and by keyword in another produces two separate cache entries.
    * Accessing the method on the class returns the undecorated function, so
      ``Cls.method(obj, ...)`` bypasses the cache.
    * NOTE: the attribute name is derived only from ``func.__name__``, so a
      memoized method and a memoized override of it in a subclass share the
      same cache dict on an instance.
    """

    def __init__(self, func):
        self.func = func
        # Name of the per-instance attribute holding this method's cache.
        self.cache_attr = '_{}_cache'.format(func.__name__)
        # Copy __name__, __doc__, __wrapped__, ... so the decorator object
        # looks like the function it wraps.
        functools.update_wrapper(self, func)

    def _call(self, obj, *args, **kwargs):
        """Return the cached result for ``obj`` and these arguments.

        On a cache miss, call the wrapped function, store its result and
        return it. Exceptions raised by the wrapped function are not cached.
        """
        key = self._gen_key(args, kwargs)
        cache = self._get_cache(obj)
        # Single lookup on the hit path. The miss path falls out of the
        # except clause before calling func, so errors from func are not
        # chained onto this KeyError.
        try:
            return cache[key]
        except KeyError:
            pass
        val = self.func(obj, *args, **kwargs)
        cache[key] = val
        return val

    def _get_cache(self, obj):
        """Return ``obj``'s cache dict for this method, creating it if absent."""
        try:
            return getattr(obj, self.cache_attr)
        except AttributeError:
            pass
        cache = {}
        setattr(obj, self.cache_attr, cache)
        return cache

    def _gen_key(self, args, kwargs):
        """Build a hashable cache key from positional and keyword arguments.

        Keyword arguments are sorted by name, so call-site order does not
        matter. Positional and keyword parts live in separate tuples, so
        e.g. ``m(1)`` and ``m(a=1)`` never share an entry.
        """
        # Keyword names are unique, so sorting never compares the values.
        # tuple(...) is linear, unlike sum(..., tuple()), which re-copies the
        # accumulated tuple for every pair (quadratic).
        return (tuple(args), tuple(sorted(kwargs.items())))

    def __get__(self, obj, type=None):
        """Descriptor hook.

        Instance access returns a memoizing callable bound to ``obj``.
        Class access returns the plain undecorated function.
        """
        if obj is None:
            return self.func
        bound = functools.partial(self._call, obj)
        # Make the bound callable report the wrapped method's name/docstring.
        functools.update_wrapper(bound, self.func)
        return bound
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from unittest.mock import Mock, call as Call | |
from revenue.utils import memoize_method | |
def test_memoize():
    """Results are cached per instance and per (args, kwargs) combination."""
    backend = Mock(side_effect=[1, 2, 3, 4])

    class Holder:
        @memoize_method
        def m(self, *args, **kwargs):
            return backend(self, *args, **kwargs)

    first = Holder()
    second = Holder()

    # (instance, positional args, keyword args, expected result), in call order.
    scenario = [
        (first, (1, 2), {}, 1),
        (first, (1, 2), {'a': 3}, 2),
        (first, (1, 2), {}, 1),          # cache hit
        (first, (1, 2), {'a': 3}, 2),    # cache hit
        (first, (1, 2), {'b': 3}, 3),
        (first, (1, 2), {'a': 3}, 2),    # cache hit
        (second, (1, 2), {}, 4),         # each instance has its own cache
    ]
    for target, args, kwargs, expected in scenario:
        assert target.m(*args, **kwargs) == expected

    # Only cache misses should have reached the underlying function.
    assert backend.call_args_list == [
        Call(first, 1, 2),
        Call(first, 1, 2, a=3),
        Call(first, 1, 2, b=3),
        Call(second, 1, 2),
    ]
def test_memoize_wraps():
    """The decorator preserves the wrapped method's name and docstring."""
    class Sample:
        @memoize_method
        def mymethod(self):
            "My docstring"

    instance = Sample()
    # Check both class-level access and instance-bound access.
    for accessor in (Sample.mymethod, instance.mymethod):
        assert accessor.__name__ == 'mymethod'
        assert accessor.__doc__ == 'My docstring'
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment