
@Susensio
Last active October 26, 2023 16:10
Make function of numpy array cacheable

How to cache slow functions that take a numpy.array as a parameter in Python

TL;DR

from numpy_lru_cache_decorator import np_cache

@np_cache()
def function(array):
    ...

Explanation

Processing numpy arrays can be slow, especially in image analysis. Simply using functools.lru_cache won't work, because numpy.array is mutable and therefore not hashable. This workaround allows caching functions that take an arbitrary numpy.array as their first parameter. The other parameters are passed to lru_cache as is, so they must be hashable themselves. The decorator accepts the standard lru_cache parameters (maxsize=128 and typed=False by default).
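
The problem, and the idea behind the workaround, can be seen by calling functools.lru_cache directly. In the sketch below, the total function and the to_nested_tuple helper are hypothetical and only for illustration: passing an ndarray raises TypeError, while the same data converted to nested tuples is hashable and works as a cache key. (The gist's own helper, array_to_tuple, iterates over the array instead of going through tolist().)

```python
from functools import lru_cache

import numpy as np


@lru_cache(maxsize=128)
def total(values):
    # Hypothetical cached function: lru_cache hashes `values` to build its key.
    return sum(sum(row) for row in values)


def to_nested_tuple(obj):
    # Hypothetical helper: turn the nested lists from ndarray.tolist()
    # into nested tuples, which are immutable and hashable.
    if isinstance(obj, list):
        return tuple(to_nested_tuple(item) for item in obj)
    return obj


array = np.array([[1, 2, 3], [4, 5, 6]])

try:
    total(array)
except TypeError as error:
    print(error)  # unhashable type: 'numpy.ndarray'

key = to_nested_tuple(array.tolist())
print(key)         # ((1, 2, 3), (4, 5, 6))
print(total(key))  # 21 (cache miss, computed)
print(total(key))  # 21 (cache hit)
```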

Example:

>>> array = np.array([[1, 2, 3], [4, 5, 6]])

>>> @np_cache(maxsize=256)
... def multiply(array, factor):
...     print("Calculating...")
...     return factor*array

>>> product = multiply(array, 2)
Calculating...
>>> product
array([[ 2,  4,  6],
       [ 8, 10, 12]])

>>> multiply(array, 2)
array([[ 2,  4,  6],
       [ 8, 10, 12]])

Warning: lru_cache decorator caveats

Users must be very careful when the cached function returns a mutable object (list, dict, numpy.array...). The cache returns a reference to the same object in memory every time, not a copy. If that object is then modified, the cached value is corrupted for every later call.

Example of this caveat:

>>> array = np.array([1, 2, 3])

>>> @np_cache()
... def to_list(array):
...     print("Calculating...")
...     return array.tolist()

>>> result = to_list(array)
Calculating...
>>> result
[1, 2, 3]

>>> result.append("this shouldn't be here")  # WARNING, DO NOT do this
>>> result
[1, 2, 3, "this shouldn't be here"]

>>> new_result = to_list(array)
>>> new_result
[1, 2, 3, "this shouldn't be here"]  # CACHE BROKEN!!

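To avoid this mutability problem, copy the result before modifying it. For a flat list, either list(result) or result[:] creates a (shallow) copy; if result were a nested list, copy.deepcopy would be needed. For a numpy.array, use array.copy(). array[:] does not make a copy, since it returns a view that shares memory with the original, and numpy.asarray(array) returns the very same object. (numpy.array(array) does copy by default, but array.copy() states the intent explicitly.)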
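
Copying can also be done automatically on every call. The sketch below uses a hypothetical returns_copy decorator (not part of the gist) stacked above the cache; with the gist's decorator it would sit above @np_cache(). To stay self-contained it wraps plain functools.lru_cache functions on tuple inputs, and it assumes every result can be handled by copy.deepcopy.

```python
import copy
from functools import lru_cache, wraps

import numpy as np


def returns_copy(function):
    """Hypothetical decorator: give each caller its own deep copy of the result."""
    @wraps(function)
    def wrapper(*args, **kwargs):
        return copy.deepcopy(function(*args, **kwargs))

    # Expose the cache controls of the wrapped function, when it has them.
    for name in ("cache_info", "cache_clear"):
        if hasattr(function, name):
            setattr(wrapper, name, getattr(function, name))
    return wrapper


@returns_copy
@lru_cache(maxsize=None)
def to_list(values):
    print("Calculating...")
    return list(values)


@returns_copy
@lru_cache(maxsize=None)
def doubled(values):
    return np.array(values) * 2


result = to_list((1, 2, 3))          # prints "Calculating..."
result.append("this shouldn't be here")
new_result = to_list((1, 2, 3))      # cache hit: nothing printed
print(new_result)                    # [1, 2, 3]

first = doubled((1, 2, 3))
first[0] = 100                       # only modifies the caller's copy
print(doubled((1, 2, 3)))            # [2 4 6]
```

The trade-off is an extra copy on every call, cache hits included, which may eat into the speed-up when results are large.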

from functools import lru_cache, wraps

import numpy as np


def np_cache(*lru_args, **lru_kwargs):
    """LRU cache implementation for functions whose FIRST parameter is a numpy array

    >>> array = np.array([[1, 2, 3], [4, 5, 6]])
    >>> @np_cache(maxsize=256)
    ... def multiply(array, factor):
    ...     print("Calculating...")
    ...     return factor*array
    >>> multiply(array, 2)
    Calculating...
    array([[ 2,  4,  6],
           [ 8, 10, 12]])
    >>> multiply(array, 2)
    array([[ 2,  4,  6],
           [ 8, 10, 12]])
    >>> multiply.cache_info()
    CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)
    """
    def decorator(function):
        @wraps(function)
        def wrapper(np_array, *args, **kwargs):
            # Convert the array into nested tuples so lru_cache can hash it
            hashable_array = array_to_tuple(np_array)
            return cached_wrapper(hashable_array, *args, **kwargs)

        @lru_cache(*lru_args, **lru_kwargs)
        def cached_wrapper(hashable_array, *args, **kwargs):
            # Rebuild a fresh array from the tuples before calling the function
            array = np.array(hashable_array)
            return function(array, *args, **kwargs)

        def array_to_tuple(np_array):
            """Iterates recursively."""
            try:
                return tuple(array_to_tuple(_) for _ in np_array)
            except TypeError:
                # Scalars are not iterable: they end the recursion
                return np_array

        # copy lru_cache attributes over too
        wrapper.cache_info = cached_wrapper.cache_info
        wrapper.cache_clear = cached_wrapper.cache_clear
        return wrapper

    return decorator
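
array_to_tuple walks every element in Python on each call, and the resulting key compares values only, so for example np.array([1, 2]) and np.array([1.0, 2.0]) end up sharing one cache entry. One alternative, not the gist's own method, is to key the cache on the array's shape, dtype and raw bytes. The sketch below uses a hypothetical np_cache_bytes decorator and assumes a non-object dtype (for dtype=object, tobytes() would capture pointers rather than values).

```python
from functools import lru_cache, wraps

import numpy as np


def np_cache_bytes(*lru_args, **lru_kwargs):
    """Hypothetical variant of np_cache keyed on (shape, dtype, raw bytes).

    Assumes a non-object dtype.
    """
    def decorator(function):
        @lru_cache(*lru_args, **lru_kwargs)
        def cached_wrapper(shape, dtype, data, *args, **kwargs):
            # Rebuild a writable array with the original shape and dtype.
            array = np.frombuffer(data, dtype=dtype).reshape(shape).copy()
            return function(array, *args, **kwargs)

        @wraps(function)
        def wrapper(np_array, *args, **kwargs):
            np_array = np.asarray(np_array)
            # shape (tuple), dtype and bytes are all hashable;
            # tobytes() serializes in C order, matching the reshape above.
            return cached_wrapper(np_array.shape, np_array.dtype,
                                  np_array.tobytes(), *args, **kwargs)

        wrapper.cache_info = cached_wrapper.cache_info
        wrapper.cache_clear = cached_wrapper.cache_clear
        return wrapper

    return decorator


@np_cache_bytes(maxsize=256)
def multiply(array, factor):
    print("Calculating...")
    return factor * array


array = np.array([[1, 2, 3], [4, 5, 6]])
multiply(array, 2)             # prints "Calculating..."
multiply(array, 2)             # served from the cache
print(multiply.cache_info())   # CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)
```

Because shape and dtype are part of the key, arrays that differ only in dtype get separate entries. The mutable-return caveat above applies unchanged.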
@mcleantom

@domvwt Thanks, haven't heard of joblib before, looks very useful
