Skip to content

Instantly share code, notes, and snippets.

@ambv
Last active May 26, 2016 08:22
Show Gist options
  • Save ambv/109cc45a8a905aac72fbfce0b03f03a9 to your computer and use it in GitHub Desktop.
Save ambv/109cc45a8a905aac72fbfce0b03f03a9 to your computer and use it in GitHub Desktop.
Per-instance memoization in Python. The per-instance lazy binder is thread-safe.
from functools import lru_cache
import threading
import time
def per_instance(factory, *factory_args, **factory_kwargs):
    """Apply a decorator to a method separately for every instance.

    ``factory(*factory_args, **factory_kwargs)`` is called to build a fresh
    decorator the first time the method is invoked on a given instance. That
    decorator wraps the instance's bound method, and the result is stored as
    an instance attribute under the method's name. From then on it shadows the
    class-level method, so each instance gets e.g. its own ``lru_cache``.

    Binding happens under a lock, so concurrent first calls on the same
    instance build the decorator only once. There is one lock per decorated
    method, shared by all instances, so binding different instances is
    serialized. The method call itself runs outside the lock.

    Args:
        factory: Callable returning a decorator (e.g. ``functools.lru_cache``).
        *factory_args: Positional arguments passed to ``factory``.
        **factory_kwargs: Keyword arguments passed to ``factory``.

    Returns:
        A method decorator.
    """
    from functools import wraps

    def lazy_binder(method):
        """Replace *method* just in time, when it is first invoked on an instance."""
        name = method.__name__
        lock = threading.Lock()

        @wraps(method)
        def _lazy_binder(self, *args, **kwargs):
            with lock:
                try:
                    # Another thread may have bound this instance while we
                    # waited for the lock; reuse its result instead of
                    # rebuilding. Looking in the instance dict (rather than
                    # comparing __name__) works for decorators whose result
                    # has no __name__, e.g. functools.partial objects.
                    decorated = self.__dict__[name]
                except KeyError:
                    bound_method = method.__get__(self, type(self))
                    decorator = factory(*factory_args, **factory_kwargs)
                    decorated = decorator(bound_method)
                    setattr(self, name, decorated)
            # Call outside the lock so slow method bodies don't serialize callers.
            return decorated(*args, **kwargs)

        return _lazy_binder

    return lazy_binder
class C:
    """Demo class whose ``heavy_method`` gets its own LRU cache per instance."""

    # NOTE: the per-instance binding itself is guarded by a lock, but the LRU
    # cache is still not thread-safe in the stronger sense: it does not stop
    # several threads from all computing a missing result at once (thundering
    # herd), and .cache_clear() calls from other threads can make results
    # unpredictable.
    @per_instance(lru_cache, typed=False, maxsize=128)
    def heavy_method(self):
        """Sleep three seconds, then return the current ``time.time()``."""
        time.sleep(3)
        now = time.time()
        return now
def slow_binding(wait):
    """Sleep for *wait* seconds, then hand back a default ``lru_cache()`` decorator.

    Used as a deliberately slow decorator factory, so that racing threads
    would expose any redundant binding.
    """
    time.sleep(wait)
    decorator = lru_cache()
    return decorator
class D:
    """Stress-test target whose per-instance cache takes three seconds to bind."""

    # Exercises the lazy binder under contention: the slow factory should run
    # only once per instance, however many threads race on the first call.
    @per_instance(slow_binding, 3)
    def heavy_method(self):
        """Return the current ``time.time()`` (memoized per instance)."""
        now = time.time()
        return now
class TestDPerInstance(threading.Thread):
    """Worker thread that repeatedly calls three shared ``D`` instances."""

    # Class attributes: created once and shared by every worker thread.
    d1 = D()
    d2 = D()
    d3 = D()

    def run(self):
        """Call ``heavy_method`` three times on d1, then on d2, then on d3."""
        for target in (self.d1, self.d2, self.d3):
            for _ in range(3):
                target.heavy_method()
def main():
    """Race many threads on lazy binding and check it happens once per instance.

    Each ``TestDPerInstance`` thread calls ``heavy_method`` three times on
    each of three shared ``D`` instances. ``D`` binds through
    ``slow_binding(3)``. If binding runs once per instance, the whole run
    should finish well under the time limit. If every thread rebound under
    the lock, the run would take far longer. Each instance's cache must also
    record exactly one miss, i.e. ``3 * NUM_THREADS - 1`` hits.
    """
    NUM_THREADS = 1000
    threads = [TestDPerInstance() for _ in range(NUM_THREADS)]
    t1 = time.monotonic()
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    # Measure after join(): start() returns as soon as a thread is running,
    # so timing only the start loop never observed the binding cost at all.
    t2 = time.monotonic()
    assert t2 - t1 < 11, "There's a thundering herd on lazy binding."
    for instance in (TestDPerInstance.d1, TestDPerInstance.d2, TestDPerInstance.d3):
        assert instance.heavy_method.cache_info().hits == 3 * NUM_THREADS - 1
# Run the thread-race check only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment