Skip to content

Instantly share code, notes, and snippets.

@dlebech
Created March 20, 2016 16:51
Show Gist options
  • Save dlebech/c16a34f735c0c4e9b604 to your computer and use it in GitHub Desktop.
Save dlebech/c16a34f735c0c4e9b604 to your computer and use it in GitHub Desktop.
Python LRU cache that works with coroutines (asyncio)
"""Global LRU caching utility. For that little bit of extra speed.

The caching utility provides a single wrapper function, ``memoize``, that can
be used to gain a bit of extra speed for frequently called functions. The
cache is an LRU cache whose entries also expire after a timeout (see the
``LRUCacheDict(max_size=..., expiration=...)`` construction in ``memoize``).

Usage::

    import cache

    @cache.memoize
    def myfun(x, y):
        return x + y

Asyncio coroutine functions are supported as well::

    @cache.memoize
    async def myfun(x, y):
        x_result = await fetch_x(x)
        return x_result + y

The cache can be manually cleared with `myfun.cache.clear()`.
NOTE(review): this only works if ``memoize`` attaches its cache dict to the
returned wrapper as a ``cache`` attribute -- confirm before relying on it.
"""
import asyncio
from functools import wraps
from lru import LRUCacheDict
# Only the decorator is part of the public API; the _wrap_* helpers are private.
__all__ = ['memoize']
def _wrap_coroutine_storage(cache_dict, key, future):
async def wrapper():
val = await future
cache_dict[key] = val
return val
return wrapper()
def _wrap_value_in_coroutine(val):
async def wrapper():
return val
return wrapper()
def memoize(f=None, *, max_size=256, expiration=60):
    """An in-memory LRU cache wrapper that can be used on any function,
    including coroutine functions.

    Can be applied bare (``@memoize``) or with options
    (``@memoize(max_size=..., expiration=...)``).

    :param f: the function to wrap. When omitted, a decorator configured
        with the given options is returned instead.
    :param max_size: maximum number of entries kept by the LRU cache.
    :param expiration: forwarded to ``LRUCacheDict`` as ``expiration``;
        presumably seconds before an entry expires -- verify against the
        ``lru`` package.
    :returns: the wrapped function. The underlying cache is exposed as
        ``wrapper.cache`` so callers can do ``myfun.cache.clear()``.
    """
    if f is None:
        # Used as @memoize(...): return a decorator bound to these options.
        def decorator(func):
            return memoize(func, max_size=max_size, expiration=expiration)
        return decorator

    cache_dict = LRUCacheDict(max_size=max_size, expiration=expiration)
    # Whether callers of this function expect an awaitable back. Starts from
    # the static check and is switched on if a plain function turns out to
    # return a coroutine, so cache hits stay awaitable exactly like misses.
    returns_awaitable = asyncio.iscoroutinefunction(f)

    @wraps(f)
    def wrapper(*args, **kwargs):
        nonlocal returns_awaitable
        # Simple repr()-based key. Keyword arguments are sorted so that call
        # site ordering does not cause spurious misses. There are still no
        # guarantees for arguments whose repr is unstable (e.g. dicts built in
        # different orders). NOTE(review): objects using the default repr
        # embed their address, so a new object reusing a collected object's
        # address can hit that object's entry until it expires.
        key = (f.__module__ + '#' + f.__name__ + '#' +
               repr((args, sorted(kwargs.items()))))
        try:
            val = cache_dict[key]
        except KeyError:
            pass  # Miss (or expired entry); computed below.
        else:
            if returns_awaitable:
                return _wrap_value_in_coroutine(val)
            return val

        # Call f outside the except clause so that exceptions raised by f are
        # not reported as "during handling of the above KeyError".
        val = f(*args, **kwargs)
        if asyncio.iscoroutine(val):
            # Store the awaited result, not the one-shot coroutine object.
            returns_awaitable = True
            return _wrap_coroutine_storage(cache_dict, key, val)
        # Otherwise just store and return the value directly.
        cache_dict[key] = val
        return val

    # Expose the cache so it can be inspected or cleared by callers.
    wrapper.cache = cache_dict
    return wrapper
"""Tests the caching module."""
import asyncio
import unittest
import cache
# How many times the real body of `wrapped` has run (cache misses only).
called = 0


@cache.memoize
def wrapped():
    """Memoized module-level fixture: records each real call, returns 10."""
    global called
    called = called + 1
    return 10
class MemoizeClass:
    """Fixture exercising ``cache.memoize`` on sync/async class and instance
    methods. Every method bumps a counter so tests can tell whether the real
    body ran or the cached value was returned."""

    # Real-invocation counters for the two memoized classmethods.
    cls_called = 0
    cls_async_called = 0

    def __init__(self):
        # Per-instance counter shared by `my_fun` and `my_async_fun`.
        self.called = 0

    @classmethod
    @cache.memoize
    def my_class_fun(cls):
        cls.cls_called = cls.cls_called + 1
        return 20

    @classmethod
    @cache.memoize
    async def my_async_classmethod(cls):
        cls.cls_async_called = cls.cls_async_called + 1
        return 40

    @cache.memoize
    def my_fun(self):
        self.called = self.called + 1
        return 30

    @cache.memoize
    async def my_async_fun(self):
        self.called = self.called + 1
        return 50
class TestMemoize(unittest.TestCase):
    """Tests for ``cache.memoize`` on a module-level function, a classmethod,
    an instance method, and their async counterparts."""

    def setUp(self):
        # A fresh event loop per test keeps async tests isolated.
        self.loop = asyncio.new_event_loop()

    def tearDown(self):
        # Close the per-test loop; otherwise each test leaks an open event
        # loop and asyncio emits an "unclosed event loop" ResourceWarning.
        self.loop.close()

    def test_memoize_fun(self):
        """It should work for a module level method"""
        self.assertEqual(called, 0)
        val = wrapped()
        self.assertEqual(val, 10)
        self.assertEqual(called, 1)
        val = wrapped()
        self.assertEqual(val, 10)
        self.assertEqual(called, 1)

    def test_memoize_class_method(self):
        """It should work for a classmethod"""
        self.assertEqual(MemoizeClass.cls_called, 0)
        val = MemoizeClass.my_class_fun()
        self.assertEqual(val, 20)
        self.assertEqual(MemoizeClass.cls_called, 1)
        val = MemoizeClass.my_class_fun()
        self.assertEqual(val, 20)
        self.assertEqual(MemoizeClass.cls_called, 1)

    def test_memoize_instance_method(self):
        """It should work for an instance method"""
        mc = MemoizeClass()
        self.assertEqual(mc.called, 0)
        val = mc.my_fun()
        self.assertEqual(val, 30)
        self.assertEqual(mc.called, 1)
        val = mc.my_fun()
        self.assertEqual(val, 30)
        self.assertEqual(mc.called, 1)

    def test_memoize_async_classmethod(self):
        """It should work with an async coroutine as classmethod."""
        self.assertEqual(MemoizeClass.cls_async_called, 0)

        async def go():
            val_fut1 = await MemoizeClass.my_async_classmethod()
            val_fut2 = await MemoizeClass.my_async_classmethod()
            self.assertEqual(val_fut1, 40)
            self.assertEqual(val_fut2, 40)

        self.loop.run_until_complete(go())
        self.assertEqual(MemoizeClass.cls_async_called, 1)

    def test_memoize_async(self):
        """It should work with an async coroutine instance method."""
        mc = MemoizeClass()
        self.assertEqual(mc.called, 0)

        async def go():
            val_fut1 = await mc.my_async_fun()
            val_fut2 = await mc.my_async_fun()
            self.assertEqual(val_fut1, 50)
            self.assertEqual(val_fut2, 50)

        self.loop.run_until_complete(go())
        self.assertEqual(mc.called, 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment