Skip to content

Instantly share code, notes, and snippets.

@maharjun
Created May 4, 2023 06:20
Show Gist options
  • Save maharjun/86764ce33230d684c884aa3b3d0a4a33 to your computer and use it in GitHub Desktop.
Save maharjun/86764ce33230d684c884aa3b3d0a4a33 to your computer and use it in GitHub Desktop.
Function Disk Cache
"""
This module provides utilities for caching the results of time-consuming functions
to disk for faster retrieval in future runs. The main classes and functions are:
- hash_det: A function that generates a deterministic hash value for an object
(immune to changes in its id).
- function_takes_no_arguments: A helper function that checks if a given function
takes no arguments.
- CachedResultCallable: A class that implements a disk-cached version of a callable.
The result is cached in a specified location and retrieved from the disk cache in
future runs. This class can also handle dependencies to other caches, updating
the cache if the dependency cache files have been updated more recently.
- cache_result: A decorator that exposes the CachedResultCallable functionality,
allowing for a more readable syntax when using disk caching for functions.
Usage examples can be found in the documentation of the CachedResultCallable and
cache_result classes.
Note: CachedResultCallable can only be initialized with a callable that takes no
arguments. This is because its purpose is to retrieve the results of long
computations from the disk, and we don't expect multiple function calls with
different parameters.
"""
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2023, maharjun
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
###############################################################################
from __future__ import annotations
import os
from os.path import join as opj
import dill
from typing import Any, Iterable
from typing import Callable
import inspect
import logging
# Module-level logger, registered under a fixed dotted name.
logger = logging.getLogger('utils.generic.cacheutils')
# Full paths of the cache files claimed by CachedResultCallable instances created
# in this process; CachedResultCallable.__init__ consults and extends this set to
# reject two instances that would point at the same cache file.
cached_callable_name_set = set()
def hash_det(object_to_hash, n_hex_digits=10):
    """
    Compute a short, reproducible hash of an object.

    The object is serialized with dill and an md5 digest is taken over the
    serialized bytes, so the value does not depend on the object's ``id``.

    Parameters
    ----------
    object_to_hash: Any
        Any object that dill is able to serialize.
    n_hex_digits: int (default: 10)
        Number of leading hexadecimal characters of the digest to keep.

    Returns
    -------
    str
        The first ``n_hex_digits`` characters of the md5 hex digest.
    """
    from utils.generic.dillshim import dill
    import hashlib

    serialized = dill.dumps(object_to_hash)
    full_digest = hashlib.md5(serialized).hexdigest()
    return full_digest[:n_hex_digits]
def function_takes_no_arguments(func: Callable):
    """
    Check whether a callable declares no parameters at all.

    The signature is obtained with ``inspect.signature``, which already leaves
    out the bound ``self``/``cls`` of a bound method. Any parameter that remains
    (positional, keyword-only, defaulted, ``*args`` or ``**kwargs``) makes the
    callable count as "taking arguments".

    Parameters
    ----------
    func: Callable
        The callable to inspect.

    Returns
    -------
    bool
        True if ``func`` declares no parameters, False otherwise.

    Raises
    ------
    TypeError, ValueError
        Propagated from ``inspect.signature`` when no signature can be
        determined for ``func``.
    """
    # inspect.getargspec was removed in Python 3.11 (and rejected annotated or
    # keyword-only functions before that); inspect.signature is the supported
    # replacement and handles bound methods without a manual __self__ check.
    return not inspect.signature(func).parameters
class CachedResultCallable:
    """
    This class implements a disk cached version of a callable where the result is
    cached in the specified location and retrieved from the disk cache in the
    future. Additionally one can specify dependencies to other caches in which
    case, the cache is updated if the dependency cache files have been updated more
    recently.

    Example
    -------

    The following example is a basic example::

        A = np.random.rand(100, 100)
        B = np.random.rand(100, 100)

        def long_computation():
            return np.sum(A @ B, axis=0)

        cached_long_computation = CachedResultCallable(long_computation,
                                                       cache_dir='.',
                                                       cache_name='rand_multiply_sum',
                                                       key_to_hash=(A, B))
        D = cached_long_computation()  # takes a while the first time. If this code
                                       # is rerun, simply reads result from disk cache

    NOTE: CachedResultCallable can only be initialized with a callable that takes
    no arguments. This is because its purpose is to retrieve the results of long
    computations from the disk. Here we don't expect multiple function calls with
    different parameters. This is more applicable to cache closures that
    encapsulate the logic of these involved computations.

    A more readable way of doing the same thing is to use the decorator
    cache_result::

        A = np.random.rand(100, 100)
        B = np.random.rand(100, 100)

        @cache_result(cache_dir='.', cache_name='rand_multiply_sum', key_to_hash=(A, B))
        def long_computation():
            return np.sum(A @ B, axis=0)

        D = long_computation()  # takes a while the first time. If this code is
                                # rerun, simply reads result from disk cache

    The various options, and the specifications of dependencies can be seen in the
    documentation of __init__()

    NOTE: every instance registers its cache file path in the module-level set
    ``cached_callable_name_set``; creating a second instance with the same path
    in the same process raises a ValueError.

    NOTE: cache files are read back with ``dill.load``, which can execute
    arbitrary code. Only point ``cache_dir`` at directories you trust.
    """

    def __init__(self, function: Callable,
                 cache_dir: str, cache_name: str, key_to_hash: Any = None,
                 perform_cache: bool = True,
                 dependency_callables: Iterable[CachedResultCallable] = ()):
        """
        Parameters
        ----------
        function: Callable[[], Any]
            This must be a callable that takes no arguments. The return value of
            this callable is cached
        cache_dir: str
            The directory where the cache file will be created. It must already
            exist.
        cache_name: str
            The name of the cache file that will be created is <cache_name>.p
        key_to_hash: Any
            Can be any serializable object. This object is serialized and an md5
            hash is computed and appended to the filename. If unspecified, the file
            name is entirely specified by cache_name
        perform_cache: bool (default: True)
            If False, a cache file that already exists on disk is ignored, so the
            first call always recomputes the value. Note that even then the freshly
            computed value is still written to the cache file, and later calls on
            the same object read it back from there.
        dependency_callables: Iterable[CachedResultCallable] (default: ())
            If this cache depends on any other cached callables they can be
            specified here. In case any of those caches are updated more recently
            than the current cache, the cache is invalidated and the next function
            evaluation will recalculate the cache value.

        Raises
        ------
        TypeError
            If cache_dir or cache_name is not a string, or if key_to_hash cannot be
            serialized using dill.
        ValueError
            If cache_dir is not an existing directory, or if the resulting cache
            path is already used by another CachedResultCallable.
        """
        assert function_takes_no_arguments(function), \
            "CachedResultCallable only accepts functions that take no arguments"
        self._function = function

        if not isinstance(cache_dir, str):
            raise TypeError("The cache_dir must be a string")
        if not os.path.isdir(cache_dir):
            raise ValueError("The cache_dir must point to a directory that already exists")
        self._cache_dir = cache_dir

        if not isinstance(cache_name, str):
            raise TypeError("The cache_name must be a string")
        self._cache_name = cache_name

        if key_to_hash is not None:
            try:
                self._hash = hash_det(key_to_hash)
            except dill.PicklingError as err:
                # Keep the original pickling failure attached as the cause.
                raise TypeError("The 'key_to_hash' must be serializable using dill") from err
            self._key_to_hash = key_to_hash
        else:
            self._key_to_hash = None

        # Copy so that later changes to the caller's iterable do not affect us.
        self._dependency_callables = list(dependency_callables)

        self._perform_cache = perform_cache
        # Snapshot of the cache state at construction time (see is_recomputed).
        self._is_consistent_at_init = self.cache_exists_and_is_consistent

        # Verify unique path
        if self.full_path in cached_callable_name_set:
            raise ValueError(f'It appears that the cache path {self.full_path} is already in use in another CachedCallable')
        cached_callable_name_set.add(self.full_path)

    @property
    def is_recomputed(self):
        """True if no consistent cache was found when this object was created."""
        return not self._is_consistent_at_init

    @property
    def time_of_update(self):
        """
        Modification time of the cache file, or None.

        The value is looked up on first access and then remembered; it is None
        when perform_cache is False or the cache file did not exist at that
        moment. __exit__ overwrites it after writing a new cache file.
        """
        if not hasattr(self, '_time_of_update'):
            if self._perform_cache and os.path.isfile(self.full_path):
                self._time_of_update = os.path.getmtime(self.full_path)
            else:
                self._time_of_update = None
        return self._time_of_update

    @property
    def full_name(self):
        """File name of the cache: <cache_name>.p or <cache_name>_<hash>.p."""
        if self._key_to_hash is None:
            return f"{self._cache_name}.p"
        else:
            return f"{self._cache_name}_{self._hash}.p"

    @property
    def full_path(self):
        """Path of the cache file inside cache_dir."""
        return opj(self._cache_dir, self.full_name)

    @property
    def cache_dir(self):
        """Directory in which the cache file lives."""
        return self._cache_dir

    @property
    def cached_value(self):
        """
        The value currently held in memory.

        Raises AttributeError if nothing has been loaded or assigned yet.
        """
        if not hasattr(self, '_cached_value'):
            raise AttributeError("The cached value hasn't been assigned / doesn't exist")
        return self._cached_value

    @cached_value.setter
    def cached_value(self, cvalue):
        # A consistent cache on disk must not be silently overridden.
        if self.cache_exists_and_is_consistent:
            raise AttributeError("Cannot reassign cached_value for cache that already contains a cached value")
        self._cached_value = cvalue

    @property
    def cache_exists_and_is_consistent(self):
        """
        True if a cache timestamp is known and no dependency is missing or newer.
        """
        if self.time_of_update is None:
            return False
        if any(x.time_of_update is None or x.time_of_update > self.time_of_update
               for x in self._dependency_callables):
            return False
        return True

    def __enter__(self):
        """Load the cached value from disk if a consistent cache exists."""
        if self.cache_exists_and_is_consistent:
            with open(self.full_path, 'rb') as fin:
                self._cached_value = dill.load(fin)
        return self

    def __exit__(self, *args):
        """
        Write an assigned value to the cache file if the cache was not consistent.

        Emits a warning if the cache was not consistent and no value was assigned
        inside the ``with`` block.
        """
        import warnings
        if not self.cache_exists_and_is_consistent:
            if hasattr(self, '_cached_value'):
                # Write to a temporary file first and move it into place, so a
                # failed or interrupted dump never leaves a truncated cache file
                # that a later run would mistake for a valid cache.
                tmp_path = self.full_path + '.tmp'
                try:
                    with open(tmp_path, 'wb') as fout:
                        dill.dump(self._cached_value, fout, protocol=-1)
                    os.replace(tmp_path, self.full_path)
                finally:
                    # Only still present if the dump or the move failed.
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)
                self._time_of_update = os.path.getmtime(self.full_path)  # overwrite None value
                del self._cached_value
            else:
                warnings.warn("No cache was created by the CachedResultCallable since the cached value wasn't assigned")

    def __call__(self):
        """
        Calls the callable and caches the result
        """
        with self as C:
            if not C.cache_exists_and_is_consistent:
                C.cached_value = self._function()
            return C.cached_value
class cache_result:
    """
    Exposes the CachedResultCallable as a decorator where the function argument is
    passed in with the call function.

    The constructor only stores the caching options; applying the instance to a
    function builds a CachedResultCallable from that function and the stored
    options.
    """

    def __init__(self, cache_dir: str, cache_name: str, key_to_hash: Any = None,
                 dependency_callables: "Iterable[CachedResultCallable]" = (),
                 perform_cache: bool = True):
        """
        Parameters
        ----------
        cache_dir: str
            Forwarded as ``cache_dir`` to CachedResultCallable.
        cache_name: str
            Forwarded as ``cache_name`` to CachedResultCallable.
        key_to_hash: Any
            Forwarded as ``key_to_hash`` to CachedResultCallable.
        dependency_callables: Iterable[CachedResultCallable] (default: ())
            Forwarded as ``dependency_callables`` to CachedResultCallable.
        perform_cache: bool (default: True)
            Forwarded as ``perform_cache`` to CachedResultCallable.
        """
        self.cache_dir = cache_dir
        self.cache_name = cache_name
        self.key_to_hash = key_to_hash
        # Copy into a fresh list: decorators must not share one default list, and
        # a one-shot iterable must survive the decorator being applied repeatedly.
        self.dependency_callables = list(dependency_callables)
        self.perform_cache = perform_cache

    def __call__(self, func: Callable) -> "CachedResultCallable":
        """Wrap ``func`` in a CachedResultCallable using the stored options."""
        return CachedResultCallable(func,
                                    cache_dir=self.cache_dir,
                                    cache_name=self.cache_name,
                                    key_to_hash=self.key_to_hash,
                                    dependency_callables=self.dependency_callables,
                                    perform_cache=self.perform_cache)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment