
@sunfkny
Last active March 31, 2024 05:03
Originally from https://github.com/sweepai/sweep/blob/main/docs/public/file_cache.py, with type hints added and code quality improved.
import functools
import hashlib
import inspect
import os
import pathlib
import pickle
import tempfile
from typing import Any, Callable, ParamSpec, TypeVar
from loguru import logger
P = ParamSpec("P")
R = TypeVar("R")

MAX_DEPTH = 6  # maximum recursion depth for recursive_hash
DISABLE_CACHE = False  # set to True to bypass caching entirely
def recursive_hash(value: Any, depth: int = 0, ignore_params: list[str] | None = None):
    """Hash primitives recursively with maximum depth."""
    if ignore_params is None:
        ignore_params = []
    if depth > MAX_DEPTH:
        return hashlib.md5("max_depth_reached".encode()).hexdigest()
    if isinstance(value, (int, float, str, bool, bytes)):
        return hashlib.md5(str(value).encode()).hexdigest()
    elif isinstance(value, (list, tuple)):
        return hashlib.md5("".join([recursive_hash(item, depth + 1, ignore_params) for item in value]).encode()).hexdigest()
    elif isinstance(value, dict):
        return hashlib.md5(
            "".join(
                [
                    recursive_hash(key, depth + 1, ignore_params) + recursive_hash(val, depth + 1, ignore_params)
                    for key, val in value.items()
                    if key not in ignore_params
                ]
            ).encode()
        ).hexdigest()
    elif hasattr(value, "__dict__") and value.__class__.__name__ not in ignore_params:
        return recursive_hash(value.__dict__, depth + 1, ignore_params)
    else:
        return hashlib.md5("unknown".encode()).hexdigest()


def hash_code(code: str):
    return hashlib.md5(code.encode()).hexdigest()
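
# Illustrative note (not part of the original gist): recursive_hash produces a stable
# digest for nested structures and skips dict keys listed in ignore_params, so e.g.
#     recursive_hash({"q": "hi", "timestamp": 1}, ignore_params=["timestamp"])
# and
#     recursive_hash({"q": "hi", "timestamp": 2}, ignore_params=["timestamp"])
# yield the same digest.
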
def cache(ignore_params: list[str] | None = None, verbose: bool = False):
    """Decorator to cache function output based on its inputs, ignoring specified parameters.

    Ignored parameters are used to avoid caching on non-deterministic inputs, such as timestamps.
    We can also ignore parameters that are slow to serialize or constant across runs, such as large objects.
    """
    if ignore_params is None:
        ignore_params = []

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        if DISABLE_CACHE:
            if verbose:
                logger.info(f"Cache is disabled for function: {func.__name__}")
            return func

        func_source_code_hash = hash_code(inspect.getsource(func))

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs):
            cache_dir = pathlib.Path(tempfile.gettempdir()) / "file_cache"
            os.makedirs(cache_dir, exist_ok=True)

            # Convert args to a dictionary based on the function's signature
            args_names = func.__code__.co_varnames[: func.__code__.co_argcount]
            args_dict = dict(zip(args_names, args))

            # Remove ignored params
            kwargs_clone = kwargs.copy()
            for param in ignore_params:
                args_dict.pop(param, None)
                kwargs_clone.pop(param, None)

            # Create hash based on argument names, argument values, and function source code
            args_hash = recursive_hash(args_dict, ignore_params=ignore_params)
            kwargs_hash = recursive_hash(kwargs_clone, ignore_params=ignore_params)
            func_hash = func_source_code_hash + args_hash + kwargs_hash
            cache_file = cache_dir / f"{func.__module__}_{func.__name__}_{func_hash}.pickle"

            try:
                # If cache exists, load and return it
                if cache_file.is_file():
                    if verbose:
                        logger.info(f"Used cache for function: {func.__name__}")
                    with open(cache_file, "rb") as f:
                        return pickle.load(f)
            except Exception as e:
                logger.warning(f"Unpickling failed: {e}")

            # Otherwise, call the function and save its result to the cache
            result = func(*args, **kwargs)
            try:
                with open(cache_file, "wb") as f:
                    pickle.dump(result, f)
            except Exception as e:
                logger.warning(f"Pickling failed: {e}")
            return result

        return wrapper

    return decorator
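
# Usage sketch (illustrative, not part of the original gist); slow_square and the
# sleep are made up for the example. Results are pickled under <tempdir>/file_cache,
# keyed by the decorated function's source code plus its (non-ignored) arguments.
if __name__ == "__main__":
    import time

    @cache(verbose=True)
    def slow_square(x: int) -> int:
        time.sleep(1)  # stand-in for expensive work
        return x * x

    print(slow_square(4))  # first call runs the function and writes the pickle
    print(slow_square(4))  # second call logs "Used cache" and loads the pickle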