Last active
March 31, 2024 05:03
-
-
Save sunfkny/8cd6f767134a56968ceaa0498a8593e8 to your computer and use it in GitHub Desktop.
Originally from https://github.com/sweepai/sweep/blob/main/docs/public/file_cache.py, with type hints added and code quality improved.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib
import functools
import hashlib
import inspect
import os
import pathlib
import pickle
import tempfile
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar

from loguru import logger
P = ParamSpec("P") | |
R = TypeVar("R") | |
MAX_DEPTH = 6 | |
DISABLE_CACHE = False | |
def recursive_hash(value: Any, depth: int = 0, ignore_params: list[str] | None = None) -> str:
    """Hash primitives recursively with maximum depth.

    Returns a 32-character MD5 hex digest that is stable across interpreter
    runs, so it can be used as part of an on-disk cache key. Every digest is
    tagged with the value's concrete type name, so values that print the same
    but have different types (``1`` vs ``"1"``, ``[]`` vs ``()`` vs ``{}``)
    get different hashes.

    Args:
        value: Object to hash. Handled explicitly: ``None``, ``int``,
            ``float``, ``str``, ``bool``, ``bytes``, lists/tuples,
            sets/frozensets, dicts, and objects exposing ``__dict__`` (hashed
            through their attributes). Anything else hashes to a constant
            per type, so two such values of the same type are
            indistinguishable.
        depth: Current recursion depth. Values nested deeper than
            ``MAX_DEPTH`` all hash to the same constant.
        ignore_params: Dict keys skipped at every nesting level, and class
            names whose instances are not descended into (they hash to a
            constant per type).

    Returns:
        The hex digest string.
    """
    if ignore_params is None:
        ignore_params = []
    if depth > MAX_DEPTH:
        return hashlib.md5("max_depth_reached".encode()).hexdigest()

    def tagged(payload: str) -> str:
        # Prefix with the concrete type name so equal payloads of different
        # types (e.g. int 1 and str "1") do not collide.
        return hashlib.md5(f"{type(value).__name__}:{payload}".encode()).hexdigest()

    def child(item: Any) -> str:
        return recursive_hash(item, depth + 1, ignore_params)

    if value is None or isinstance(value, (int, float, str, bool, bytes)):
        return tagged(str(value))
    if isinstance(value, (list, tuple)):
        return tagged("".join(child(item) for item in value))
    if isinstance(value, (set, frozenset)):
        # Set iteration order can differ between runs (str hashing is
        # randomized per process), so sort the element digests for a stable key.
        return tagged("".join(sorted(child(item) for item in value)))
    if isinstance(value, dict):
        return tagged(
            "".join(
                child(key) + child(val)
                for key, val in value.items()
                if key not in ignore_params
            )
        )
    if hasattr(value, "__dict__") and value.__class__.__name__ not in ignore_params:
        return tagged(recursive_hash(value.__dict__, depth + 1, ignore_params))
    # Unsupported types and instances of ignored classes: constant per type.
    return tagged("unknown")
def hash_code(code: str): | |
return hashlib.md5(code.encode()).hexdigest() | |
def cache(
    ignore_params: list[str] | None = None, verbose: bool = False
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator to cache function output based on its inputs, ignoring specified parameters.

    Ignore parameters are used to avoid caching on non-deterministic inputs, such as timestamps.
    We can also ignore parameters that are slow to serialize/constant across runs, such as large objects.

    Results are pickled into ``<system temp dir>/file_cache``. A cache entry is
    keyed on a hash of the decorated function's own source code (so editing the
    function invalidates its entries) plus hashes of the call's arguments. Only
    the function's own source is hashed: changes to helpers it calls do NOT
    invalidate the cache.

    Args:
        ignore_params: Parameter names to leave out of the cache key. They are
            also passed to ``recursive_hash``, where they skip matching dict
            keys and class names at any nesting level.
        verbose: Log when caching is disabled or a cached result is returned.

    Returns:
        A decorator that adds the on-disk cache to a function. ``DISABLE_CACHE``
        is read when the decorator is applied, so changing it later does not
        affect functions that are already decorated.
    """
    if ignore_params is None:
        ignore_params = []

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        if DISABLE_CACHE:
            if verbose:
                logger.info(f"Cache is disabled for function: {func.__name__}")
            return func

        # Evaluated once per decorated function. inspect.getsource raises
        # OSError when the source cannot be retrieved.
        func_source_code_hash = hash_code(inspect.getsource(func))
        # Names of the positional parameters, used to label positional values.
        positional_names = func.__code__.co_varnames[: func.__code__.co_argcount]

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            cache_dir = pathlib.Path(tempfile.gettempdir()) / "file_cache"
            os.makedirs(cache_dir, exist_ok=True)

            # Convert args to a dictionary based on the function's signature.
            args_dict = dict(zip(positional_names, args))
            # Positional values beyond the named parameters (a *args tail) get
            # no entry in args_dict; keep them so calls that differ only there
            # do not share a cache entry.
            extra_args = list(args[len(positional_names):])

            # Remove ignored params.
            kwargs_clone = kwargs.copy()
            for param in ignore_params:
                args_dict.pop(param, None)
                kwargs_clone.pop(param, None)

            # Create hash based on argument names, argument values, and function source code.
            args_hash = recursive_hash(args_dict, ignore_params=ignore_params)
            if extra_args:
                # Fold into a single digest to keep the file name length unchanged.
                args_hash = hash_code(args_hash + recursive_hash(extra_args, ignore_params=ignore_params))
            kwargs_hash = recursive_hash(kwargs_clone, ignore_params=ignore_params)
            func_hash = func_source_code_hash + args_hash + kwargs_hash
            cache_file = cache_dir / f"{func.__module__}_{func.__name__}_{func_hash}.pickle"

            # SECURITY NOTE(review): pickle.load can execute arbitrary code, and
            # this cache lives in the shared system temp directory under
            # predictable names. Only rely on it where that directory is trusted.
            try:
                if cache_file.is_file():
                    with open(cache_file, "rb") as f:
                        cached = pickle.load(f)
                    if verbose:
                        logger.info(f"Used cache for function: {func.__name__}")
                    return cached
            except Exception as e:
                logger.warning(f"Unpickling failed {e}")

            # Otherwise, call the function and save its result to the cache.
            result = func(*args, **kwargs)

            # Write to a temp file in the same directory, then atomically
            # rename it into place. A failed dump then never leaves a truncated
            # cache file that would fail to unpickle on every later call, and
            # concurrent readers never observe a partially written file.
            tmp_path: str | None = None
            try:
                fd, tmp_path = tempfile.mkstemp(dir=cache_dir, suffix=".tmp")
                with os.fdopen(fd, "wb") as f:
                    pickle.dump(result, f)
                os.replace(tmp_path, cache_file)
            except Exception as e:
                logger.warning(f"Pickling failed: {e}")
                if tmp_path is not None:
                    # Best-effort cleanup; a leftover .tmp file is harmless.
                    with contextlib.suppress(OSError):
                        os.remove(tmp_path)
            return result

        return wrapper

    return decorator
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment