@Susensio
Last active October 26, 2023 16:10
Make function of numpy array cacheable

How to cache slow functions that take a numpy.array as a parameter in Python

TL;DR

from numpy_lru_cache_decorator import np_cache

@np_cache()
def function(array):
    ...

Explanation

Processing numpy arrays can be slow, especially in image analysis. Simply using functools.lru_cache won't work because numpy.array is mutable and not hashable. This workaround allows caching functions that take an arbitrary numpy.array as their first parameter; the other parameters are passed as is, so they must be hashable themselves. The decorator accepts the standard lru_cache parameters (maxsize=128, typed=False).
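
To see why the plain decorator is not enough, here is a minimal sketch of the failure. The function name `total` is only an illustration and does not come from the gist.

```python
from functools import lru_cache

import numpy as np


@lru_cache(maxsize=128)
def total(array):
    """Toy function standing in for a slow computation."""
    return array.sum()


try:
    # lru_cache hashes its arguments to build the cache key,
    # and numpy.ndarray does not support hashing.
    total(np.array([1, 2, 3]))
except TypeError as error:
    message = str(error)
    print(message)  # the message names numpy.ndarray as an unhashable type
```

The error is raised before `total` runs at all, which is why the workaround converts the array to something hashable first.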

Example:

>>> array = np.array([[1, 2, 3], [4, 5, 6]])

>>> @np_cache(maxsize=256)
... def multiply(array, factor):
...     print("Calculating...")
...     return factor*array

>>> product = multiply(array, 2)
Calculating...
>>> product
array([[ 2,  4,  6],
       [ 8, 10, 12]])

>>> multiply(array, 2)
array([[ 2,  4,  6],
       [ 8, 10, 12]])

Warning: lru_cache decorator caveats

Be very careful when the cached function returns a mutable object (list, dict, numpy.array...). The cache returns a reference to the same object in memory each time, not a copy. If that object is modified, the cached value is corrupted for every later call.

Example of this caveat:

>>> array = np.array([1, 2, 3])

>>> @np_cache()
... def to_list(array):
...     print("Calculating...")
...     return array.tolist()

>>> result = to_list(array)
Calculating...
>>> result
[1, 2, 3]

>>> result.append("this shouldn't be here")  # WARNING, DO NOT do this
>>> result
[1, 2, 3, "this shouldn't be here"]

>>> new_result = to_list(array)
>>> new_result
[1, 2, 3, "this shouldn't be here"]  # CACHE BROKEN!!

To avoid this mutability problem, copy the cached value before modifying it. In this case, either list(result) or result[:] will create a (shallow) copy. If result were a nested list, copy.deepcopy would be needed. For numpy.array, use array.copy(): array[:] only returns a view on the same data. numpy.array(array) also copies by default, whereas numpy.asarray(array) does not.
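
A quick way to check which operations really copy is `np.shares_memory`. The sketch below uses made-up variable names and only standard NumPy calls.

```python
import numpy as np

original = np.array([1, 2, 3])

view = original[:]              # basic slicing returns a view
alias = np.asarray(original)    # no copy for an existing ndarray
copy_a = original.copy()        # explicit copy
copy_b = np.array(original)     # np.array copies by default (copy=True)

shares = {
    "view": np.shares_memory(original, view),
    "alias": np.shares_memory(original, alias),
    "copy_a": np.shares_memory(original, copy_a),
    "copy_b": np.shares_memory(original, copy_b),
}

# Writing through the view changes the original; writing to a copy does not.
view[0] = 100
copy_a[1] = 200
print(shares)
print(original)
```

Only the two real copies are safe to modify when `original` is a value handed out by a cache.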

from functools import lru_cache, wraps
import numpy as np


def np_cache(*args, **kwargs):
    """LRU cache implementation for functions whose FIRST parameter is a numpy array

    >>> array = np.array([[1, 2, 3], [4, 5, 6]])
    >>> @np_cache(maxsize=256)
    ... def multiply(array, factor):
    ...     print("Calculating...")
    ...     return factor*array
    >>> multiply(array, 2)
    Calculating...
    array([[ 2,  4,  6],
           [ 8, 10, 12]])
    >>> multiply(array, 2)
    array([[ 2,  4,  6],
           [ 8, 10, 12]])
    >>> multiply.cache_info()
    CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)
    """
    def decorator(function):
        @wraps(function)
        def wrapper(np_array, *args, **kwargs):
            hashable_array = array_to_tuple(np_array)
            return cached_wrapper(hashable_array, *args, **kwargs)

        @lru_cache(*args, **kwargs)
        def cached_wrapper(hashable_array, *args, **kwargs):
            array = np.array(hashable_array)
            return function(array, *args, **kwargs)

        def array_to_tuple(np_array):
            """Iterates recursively."""
            try:
                return tuple(array_to_tuple(_) for _ in np_array)
            except TypeError:
                return np_array

        # copy lru_cache attributes over too
        wrapper.cache_info = cached_wrapper.cache_info
        wrapper.cache_clear = cached_wrapper.cache_clear
        return wrapper

    return decorator
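
The key step in the decorator is the conversion of the array into nested tuples, which are hashable. The sketch below pulls that helper out on its own so the round trip can be inspected; the names `key`, `restored` and `lookup` are illustrative only.

```python
import numpy as np


def array_to_tuple(np_array):
    """Recursively turn an array into nested tuples of scalars."""
    try:
        return tuple(array_to_tuple(item) for item in np_array)
    except TypeError:
        # Scalars are not iterable, so the recursion stops here.
        return np_array


array = np.array([[1, 2, 3], [4, 5, 6]])

key = array_to_tuple(array)          # hashable stand-in for the array
restored = np.array(key)             # what the cached function receives
lookup = {key: "cached value"}       # a tuple key works in a dict, like in lru_cache

print(key)
print(np.array_equal(restored, array))
```

Note that the conversion visits every element on every call, including cache hits, so the decorator only pays off when the wrapped function is considerably slower than that traversal.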
@kuchi

kuchi commented Jul 16, 2021

Thanks for the helpful writeup. One additional tip: if you are returning a numpy array from the cached function, you can make the array immutable by clearing its WRITEABLE flag. If your return variable is x, then:

@np_cache()
def my_cached_func(array):
    ...
    x.flags['WRITEABLE'] = False  # prevent it from being changed
    return x
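
A runnable sketch of this tip follows. To keep it self-contained it uses plain `functools.lru_cache` with a hashable integer argument rather than `np_cache`, and the function `ramp` is a made-up example.

```python
from functools import lru_cache

import numpy as np


@lru_cache(maxsize=None)
def ramp(length):
    """Hypothetical cached function that returns a numpy array."""
    x = np.arange(length)
    x.flags["WRITEABLE"] = False  # mark the cached array as read-only
    return x


cached = ramp(4)

try:
    cached[0] = 99                # in-place writes are now rejected
except ValueError as error:
    write_error = str(error)

editable = cached.copy()          # a copy is writeable again
editable[0] = 99

print(write_error)
print(cached, editable)
```

Callers who need to modify the result take an explicit copy, so the object held by the cache stays intact.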

@mcleantom

Is there any way to modify this for a method of a class, so that the second argument is the numpy array?

@MischaPanch

MischaPanch commented Nov 30, 2021

@Susensio Thanks for the snippet! @mcleantom I also needed this functionality, so here is a slightly extended version of the snippet in which the position of the numpy array is configurable. I also used @Argysh's suggestion to optimize runtime for low-dimensional arrays. You should switch back to the original version if you need it to work with higher-dimensional arrays.

from functools import lru_cache, wraps

import numpy as np


def np_cache(*lru_args, array_argument_index=0, **lru_kwargs):
    """
    LRU cache implementation for functions whose parameter at ``array_argument_index`` is a numpy array of dimensions <= 2

    Example:
    >>> from sem_env.utils.cache import np_cache
    >>> array = np.array([[1, 2, 3], [4, 5, 6]])
    >>> @np_cache(maxsize=256)
    ... def multiply(array, factor):
    ...     return factor * array
    >>> multiply(array, 2)
    array([[ 2,  4,  6],
           [ 8, 10, 12]])
    >>> multiply(array, 2)
    array([[ 2,  4,  6],
           [ 8, 10, 12]])
    >>> multiply.cache_info()
    CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)
    """

    def decorator(function):
        @wraps(function)
        def wrapper(*args, **kwargs):
            np_array = args[array_argument_index]
            if len(np_array.shape) > 2:
                raise RuntimeError(
                    f"np_cache is currently only supported for arrays of dim. less than 3 but got shape: {np_array.shape}"
                )
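            # map(tuple, ...) needs iterable rows, so it only works for 2-D input;
            # the elements of a 1-D array are scalars and go into a flat tuple.
            if np_array.ndim == 2:
                hashable_array = tuple(map(tuple, np_array))
            else:
                hashable_array = tuple(np_array)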
            args = list(args)
            args[array_argument_index] = hashable_array
            return cached_wrapper(*args, **kwargs)

        @lru_cache(*lru_args, **lru_kwargs)
        def cached_wrapper(*args, **kwargs):
            hashable_array = args[array_argument_index]
            array = np.array(hashable_array)
            args = list(args)
            args[array_argument_index] = array
            return function(*args, **kwargs)

        # copy lru_cache attributes over too
        wrapper.cache_info = cached_wrapper.cache_info
        wrapper.cache_clear = cached_wrapper.cache_clear
        return wrapper

    return decorator
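
For the method case asked about above, `array_argument_index=1` skips `self`. The sketch below repeats a condensed variant of the decorator so that it runs on its own; the `Scaler` class is hypothetical, and it assumes the instance is hashable (true by default for plain classes).

```python
from functools import lru_cache, wraps

import numpy as np


def np_cache(*lru_args, array_argument_index=0, **lru_kwargs):
    """Condensed variant of the decorator above, for 1-D and 2-D arrays."""

    def decorator(function):
        @wraps(function)
        def wrapper(*args, **kwargs):
            args = list(args)
            np_array = args[array_argument_index]
            # Nested tuples for 2-D input, a flat tuple for 1-D input.
            if np_array.ndim == 2:
                args[array_argument_index] = tuple(map(tuple, np_array))
            else:
                args[array_argument_index] = tuple(np_array)
            return cached_wrapper(*args, **kwargs)

        @lru_cache(*lru_args, **lru_kwargs)
        def cached_wrapper(*args, **kwargs):
            args = list(args)
            args[array_argument_index] = np.array(args[array_argument_index])
            return function(*args, **kwargs)

        wrapper.cache_info = cached_wrapper.cache_info
        wrapper.cache_clear = cached_wrapper.cache_clear
        return wrapper

    return decorator


class Scaler:
    """Hypothetical class whose method takes the array as second argument."""

    def __init__(self, factor):
        self.factor = factor
        self.calls = 0

    @np_cache(maxsize=32, array_argument_index=1)
    def scale(self, array):
        self.calls += 1  # counts real computations, not cache hits
        return self.factor * array


scaler = Scaler(2)
data = np.array([[1, 2, 3], [4, 5, 6]])
first = scaler.scale(data)
second = scaler.scale(data)   # served from the cache
info = Scaler.scale.cache_info()
print(scaler.calls, info)
```

One design consequence worth knowing: the cache lives on the function, so it is shared by all instances and keeps a reference to each `self` it has seen until the entry is evicted or `cache_clear()` is called.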

@domvwt

domvwt commented Dec 7, 2021

FYI, the joblib library implements caching arrays and dataframes to disk. It's one of the key libraries that scikit-learn depends on so it's well tried and tested.

@mcleantom

@domvwt Thanks, I hadn't heard of joblib before. It looks very useful.
