Skip to content

Instantly share code, notes, and snippets.

@luhn
Last active January 15, 2017 16:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save luhn/81e4445582bbaaf9f5e06417ae18c13f to your computer and use it in GitHub Desktop.
Save luhn/81e4445582bbaaf9f5e06417ae18c13f to your computer and use it in GitHub Desktop.
Memoize method decorator
import functools
class memoize_method:
    """
    A decorator to memoize a method. The cache is stored as a dictionary
    inside the object.

    This functionality is similar to stdlib's :func:`functools.lru_cache`.
    When using lru_cache on a method, the object will be used as part of the
    cache key. This means a reference to the object will linger until all
    related cache entries have been cleared, which will never happen if the
    cache is unbounded. By storing a cache within each object, this problem
    is avoided.

    The per-object cache is an unbounded dict stored in the attribute
    ``_<method name>_cache`` (e.g. ``_m_cache`` for a method ``m``), so the
    instance must accept new attributes. Deleting or clearing that attribute
    discards the cached results for that object. All positional and keyword
    argument values must be hashable; otherwise the cache lookup raises
    :exc:`TypeError`.

    Accessing the method on the class (``MyClass.m``) returns the undecorated
    function, so calls made that way bypass the cache.

    NOTE(review): the cache attribute name is derived only from the method
    name, so a memoized override and a memoized base-class method of the same
    name share one cache on an instance (``super().m()`` may return the
    subclass's cached value). Confirm this is acceptable for callers.
    """

    def __init__(self, func):
        # Copy __name__, __doc__, __wrapped__ and func.__dict__ onto the
        # descriptor first, so that attributes stored on func can never
        # overwrite the two attributes this class relies on.
        functools.update_wrapper(self, func)
        self.func = func
        self.cache_attr = '_{}_cache'.format(func.__name__)

    def _call(self, obj, *args, **kwargs):
        """Return the cached result for these arguments, computing it on a miss."""
        key = self._gen_key(args, kwargs)
        cache = self._get_cache(obj)
        try:
            return cache[key]
        except KeyError:
            pass
        # Miss: call the real method outside the try so its own KeyErrors
        # propagate untouched, then remember the result (including None).
        val = self.func(obj, *args, **kwargs)
        cache[key] = val
        return val

    def _get_cache(self, obj):
        """Return obj's cache dict for this method, creating it on first use."""
        try:
            return getattr(obj, self.cache_attr)
        except AttributeError:
            pass
        cache = {}
        setattr(obj, self.cache_attr, cache)
        return cache

    def _gen_key(self, args, kwargs):
        """Build a hashable cache key from positional and keyword arguments.

        Keyword arguments are sorted by name so that their call order does
        not matter, then flattened into ``(name1, value1, name2, value2, ...)``.
        The key is ``(positional_args_tuple, flattened_kwargs_tuple)``.
        """
        # Linear flatten; sum(..., tuple()) re-copied the growing tuple on
        # every addition, which is quadratic in the number of kwargs.
        kwtuple = tuple(
            item for pair in sorted(kwargs.items()) for item in pair
        )
        return (tuple(args), kwtuple)

    def __get__(self, obj, type=None):
        """Bind the memoized method to obj.

        Class-level access (obj is None) returns the plain undecorated
        function. The parameter name ``type`` is kept for interface
        compatibility even though it shadows the builtin.
        """
        if obj is None:
            return self.func
        bound = functools.partial(self._call, obj)
        # Make the bound callable report the wrapped method's metadata
        # (__name__, __doc__, __wrapped__, ...).
        functools.update_wrapper(bound, self.func)
        return bound
from unittest.mock import Mock, call as Call
from revenue.utils import memoize_method
def test_memoize():
    """Repeated argument combinations are served from a per-instance cache."""
    f = Mock(side_effect=[1, 2, 3, 4])

    class MyClass:
        @memoize_method
        def m(self, *args, **kwargs):
            return f(self, *args, **kwargs)

    first, second = MyClass(), MyClass()

    # (instance, positional args, keyword args, expected result), in call
    # order; repeats of an earlier combination must hit the cache.
    expectations = [
        (first, (1, 2), {}, 1),
        (first, (1, 2), {'a': 3}, 2),
        (first, (1, 2), {}, 1),
        (first, (1, 2), {'a': 3}, 2),
        (first, (1, 2), {'b': 3}, 3),
        (first, (1, 2), {'a': 3}, 2),
        (second, (1, 2), {}, 4),
    ]
    for instance, args, kwargs, expected in expectations:
        assert instance.m(*args, **kwargs) == expected

    # Only cache misses may reach the underlying function.
    assert f.call_args_list == [
        Call(first, 1, 2),
        Call(first, 1, 2, a=3),
        Call(first, 1, 2, b=3),
        Call(second, 1, 2),
    ]
def test_memoize_wraps():
    """The decorated method keeps its name and docstring on every access path."""
    class MyClass:
        @memoize_method
        def mymethod(self):
            "My docstring"

    instance = MyClass()
    # Check both the class-level attribute and the instance-bound callable.
    for accessor in (MyClass.mymethod, instance.mymethod):
        assert accessor.__name__ == 'mymethod'
        assert accessor.__doc__ == 'My docstring'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment