@mshonichev
Forked from Susensio/numpy_lru_cache.md
Created April 13, 2022 11:29
Make function of numpy array cacheable

How to cache slow functions that take a numpy.array as a parameter in Python

TL;DR

from numpy_lru_cache_decorator import np_cache

@np_cache()
def function(array):
    ...

Explanation

Processing numpy arrays can be slow, especially in image analysis. Simply using functools.lru_cache won't work, because a numpy.array is mutable and not hashable. This workaround allows caching functions that take an arbitrary numpy.array as their first parameter; the other parameters are passed as is and must themselves be hashable. The decorator accepts the standard lru_cache parameters (maxsize=128, typed=False).
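
A minimal sketch of the underlying problem, assuming a one-dimensional array and a hypothetical function `total` (neither is part of the gist). It shows lru_cache rejecting an array outright and accepting the same values once they are converted to a tuple, which is the idea the decorator builds on:

```python
from functools import lru_cache

import numpy as np


@lru_cache(maxsize=128)
def total(values):
    """Hypothetical example function; lru_cache needs `values` to be hashable."""
    return sum(values)


array = np.array([1, 2, 3])

# Passing the array directly fails: lru_cache must hash its arguments,
# and numpy arrays are unhashable.
try:
    total(array)
    failed = False
except TypeError:
    failed = True

# Converting to a tuple first gives lru_cache a hashable key.
key = tuple(array.tolist())
first = total(key)   # computed
second = total(key)  # served from the cache
hits = total.cache_info().hits
```

The decorator in this gist does the same conversion recursively, so it also handles multi-dimensional arrays, and rebuilds the array before calling the wrapped function.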

Example:

>>> array = np.array([[1, 2, 3], [4, 5, 6]])

>>> @np_cache(maxsize=256)
... def multiply(array, factor):
...     print("Calculating...")
...     return factor*array

>>> product = multiply(array, 2)
Calculating...
>>> product
array([[ 2,  4,  6],
       [ 8, 10, 12]])

>>> multiply(array, 2)
array([[ 2,  4,  6],
       [ 8, 10, 12]])

Warning: lru_cache decorator caveats

Be very careful when the cached function returns a mutable object (list, dict, numpy.array...). Each cache hit returns a reference to the same object in memory, not a copy. If that object is modified, the cached value is modified with it and the cache loses its validity.

Example of this caveat:

>>> array = np.array([1, 2, 3])

>>> @np_cache()
... def to_list(array):
...     print("Calculating...")
...     return array.tolist()

>>> result = to_list(array)
Calculating...
>>> result
[1, 2, 3]

>>> result.append("this shouldn't be here")  # WARNING, DO NOT do this
>>> result
[1, 2, 3, "this shouldn't be here"]

>>> new_result = to_list(array)
>>> new_result  # CACHE BROKEN!!
[1, 2, 3, "this shouldn't be here"]

To avoid this mutability problem, copy the cached result before modifying it. In this case, either list(result) or result[:] will create a (shallow) copy. If result were a nested list, copy.deepcopy would be needed. For a numpy.array, use array.copy(): array[:] only returns a view onto the same data, and numpy.asarray(array) returns the array itself without copying. (numpy.array(array) does copy by default.)
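
One way to enforce this is to copy on the way out of the cache rather than relying on every caller to remember. The sketch below is an illustration only: `copying_cache` and `to_nested_list` are hypothetical names, not part of this gist, and the example uses plain lists so that it runs with the standard library alone.

```python
from copy import deepcopy
from functools import lru_cache, wraps


def copying_cache(function):
    """Hypothetical helper: cache results but hand out deep copies."""
    cached = lru_cache(maxsize=128)(function)

    @wraps(function)
    def wrapper(*args, **kwargs):
        # The cached object stays private; callers only ever see copies.
        return deepcopy(cached(*args, **kwargs))

    wrapper.cache_info = cached.cache_info
    return wrapper


@copying_cache
def to_nested_list(n):
    return [[i, i * 2] for i in range(n)]


result = to_nested_list(2)
result[0].append("mutated by caller")  # only changes the caller's copy

fresh = to_nested_list(2)  # cache hit, still the original value
```

The trade-off is that every call, including cache hits, pays for a deep copy, so this only helps when the copy is much cheaper than recomputing the result.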

from functools import lru_cache, wraps

import numpy as np


def np_cache(*args, **kwargs):
    """LRU cache implementation for functions whose FIRST parameter is a numpy array

    >>> array = np.array([[1, 2, 3], [4, 5, 6]])

    >>> @np_cache(maxsize=256)
    ... def multiply(array, factor):
    ...     print("Calculating...")
    ...     return factor*array

    >>> multiply(array, 2)
    Calculating...
    array([[ 2,  4,  6],
           [ 8, 10, 12]])
    >>> multiply(array, 2)
    array([[ 2,  4,  6],
           [ 8, 10, 12]])
    >>> multiply.cache_info()
    CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)

    """
    def decorator(function):

        @wraps(function)
        def wrapper(np_array, *args, **kwargs):
            # convert the array to nested tuples so lru_cache can hash it
            hashable_array = array_to_tuple(np_array)
            return cached_wrapper(hashable_array, *args, **kwargs)

        @lru_cache(*args, **kwargs)
        def cached_wrapper(hashable_array, *args, **kwargs):
            # rebuild the array before calling the wrapped function
            array = np.array(hashable_array)
            return function(array, *args, **kwargs)

        def array_to_tuple(np_array):
            """Iterates recursively."""
            try:
                return tuple(array_to_tuple(_) for _ in np_array)
            except TypeError:
                # scalars are not iterable: return them unchanged
                return np_array

        # copy lru_cache attributes over too
        wrapper.cache_info = cached_wrapper.cache_info
        wrapper.cache_clear = cached_wrapper.cache_clear

        return wrapper

    return decorator