Last active
June 18, 2023 03:21
-
-
Save wassname/e3c0fb1c7b9119766a7a3b008d2010dc to your computer and use it in GitHub Desktop.
A simple pandas and pickle cache for complex situations, such as deep learning, where you can't easily bust the cache based on the model.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Implements on-disk caching of transformed dataframes.

Used on a function that returns a single pandas object,
this decorator will execute the function and cache the result as a pickle file.
The cache subdirectory is named after the function and a hash of its source code,
and the file name is a hash of the call arguments.
The next time the function runs, if the hashes match what is on disk, the
decorated function will simply load and return the pickled pandas object.

This can result in speedups of 10 to 100 times, or more, depending on the
complexity of the function that creates the dataframe.

The caveat is that previously cached dataframes will remain on disk.

modified from https://github.com/N2ITN/pandas_cache
src: https://gist.github.com/wassname/e3c0fb1c7b9119766a7a3b008d2010dc

Usage:
```
@pd_cache(Path('.pd_cache_0203'))
def read_my_csv(f):
    ...
    return df

read_my_csv('file.csv')
```
""" | |
from functools import wraps | |
import pandas as pd | |
import pickle | |
import hashlib | |
import inspect | |
from pathlib import Path | |
import logging | |
logger = logging.getLogger(__name__) | |
def md5hash(s: "str | bytes") -> str:
    """Return the hexadecimal MD5 digest of ``s``.

    Args:
        s: Data to hash. Bytes-like input is hashed as-is; a ``str`` is
            encoded as UTF-8 first, because ``hashlib.md5`` only accepts
            bytes-like objects.

    Returns:
        The 32-character lowercase hexadecimal digest.
    """
    if isinstance(s, str):
        s = s.encode("utf-8")
    return hashlib.md5(s).hexdigest()
def source_code(func) -> str:
    """Return the source text of ``func`` joined into one string."""
    # getsourcelines returns (list of source lines, starting line number);
    # only the lines are needed here.
    lines, _first_lineno = inspect.getsourcelines(func)
    return "".join(lines)
def pd_cache(cache_base: Path = Path(".pd_cache"), use_code: bool = True):
    """Create a decorator that caches a function's pandas result on disk.

    The decorated function's return value is pickled to
    ``<cache_base>/<function name>_<code hash>/<argument hash>.pkl.gz``.
    A later call with the same arguments loads that file instead of calling
    the function again.

    Args:
        cache_base: Directory under which the per-function cache directories
            are created. A plain ``str`` path is accepted as well.
        use_code: When true, the first 6 hex digits of the MD5 of the
            function's source code are part of the directory name, so editing
            the function starts a fresh cache. When false, the fixed marker
            ``-1`` is used instead.

    Returns:
        A decorator that wraps a function with the caching behavior.

    Notes:
        * Positional and keyword arguments must be picklable: the file name
          is the MD5 of their pickled form, truncated to 6 hex digits.
        * The wrapped function's result must provide ``to_pickle``.
        * Existing cache files are never removed.
    """
    base_dir = Path(cache_base)

    def _pd_cache(func):
        @wraps(func)
        def cache(*args, **kw):
            # The subdirectory contains the function name and, optionally,
            # a short hash of its source code.
            if use_code:
                f_hash = md5hash(source_code(func).encode("utf-8"))[:6]
            else:
                f_hash = "-1"
            cache_dir = base_dir / f"{func.__name__}_{f_hash}"
            if not cache_dir.exists():
                cache_dir.mkdir(exist_ok=True, parents=True)
                logger.info("created `%s` dir", cache_dir)

            # The file name contains the hash of the function's args and kwargs.
            key = pickle.dumps(args, 1) + pickle.dumps(kw, 1)
            hsh = md5hash(key)[:6]
            f = cache_dir / f"{hsh}.pkl.gz"

            if f.exists():
                df = pd.read_pickle(f)
                logger.debug("\t | read %s", f)
                return df

            # Cache miss: compute, then write to a temporary file and rename
            # it into place, so an interrupted write never leaves a truncated
            # pickle at the path that later calls would try to load.
            df = func(*args, **kw)
            # Keep the ``.pkl.gz`` suffix so the temporary file is written
            # with the same extension-based compression as the final file.
            tmp = cache_dir / f"{hsh}.tmp.pkl.gz"
            df.to_pickle(tmp)
            tmp.replace(f)
            logger.debug("\t | wrote %s", f)
            return df

        return cache

    return _pd_cache
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A simple way of caching, where some arguments are not easily picked, so use the string of them. | |
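The wrapped function must be called with keyword arguments only.
Results are stored in the `.pkl_cache` directory.

Usage:
```py
from functools import partial

@partial(cache_pickle, kwargs_str=['model'], kwargs_ignore=['batch_size'])
def preproc_data(model, batch_size, dataset):
    ...
    return data

preproc_data(model=model, batch_size=8, dataset=dataset)
```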
""" | |
from functools import wraps | |
import pickle | |
import hashlib | |
from pathlib import Path | |
import logging | |
# Directory (relative path) that holds every pickled result; it is created
# as soon as this module is imported.
cache_dir = Path(".pkl_cache")
cache_dir.mkdir(exist_ok=True, parents=True)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def md5hash(s: "str | bytes") -> str:
    """Return the hexadecimal MD5 digest of ``s``.

    Args:
        s: Data to hash. Bytes-like input is hashed as-is; a ``str`` is
            encoded as UTF-8 first, because ``hashlib.md5`` only accepts
            bytes-like objects.

    Returns:
        The 32-character lowercase hexadecimal digest.
    """
    if isinstance(s, str):
        s = s.encode("utf-8")
    return hashlib.md5(s).hexdigest()
def cache_pickle(func=None, kwargs_str=None, kwargs_ignore=None):
    """Cache the result of ``func`` as a pickle file under ``cache_dir``.

    Can be applied directly (``@cache_pickle`` or ``cache_pickle(func, ...)``)
    or with options only (``@cache_pickle(kwargs_str=[...])``), in which case
    a decorator is returned.

    The wrapped function must be called with keyword arguments only. The
    cache file name is the first 6 hex digits of the MD5 of the pickled
    function name and keyword arguments.

    Args:
        func: The function to wrap, or ``None`` to get a decorator back.
        kwargs_str: Names of keyword arguments that are hashed by their
            ``str()`` form instead of being pickled (for values that do not
            pickle easily).
        kwargs_ignore: Names of keyword arguments left out of the cache key;
            they are still passed to ``func``.

    Returns:
        The caching wrapper around ``func``, or a decorator when ``func`` is
        ``None``.
    """
    if func is None:
        # Options-only call: defer wrapping until the function is supplied.
        def decorator(inner):
            return cache_pickle(
                inner, kwargs_str=kwargs_str, kwargs_ignore=kwargs_ignore
            )

        return decorator

    str_keys = frozenset(kwargs_str or ())
    ignored_keys = frozenset(kwargs_ignore or ())
    # Part of the key so that different functions called with the same
    # keyword arguments do not read each other's cached results.
    func_id = getattr(func, "__qualname__", type(func).__qualname__)

    @wraps(func)
    def wrap(**kwargs):
        hash_kwargs = {
            k: (str(v) if k in str_keys else v)
            for k, v in kwargs.items()
            if k not in ignored_keys
        }
        logger.debug("kwargs %s", hash_kwargs)

        # The file name contains the hash of the function's kwargs; sorting
        # by name makes the key independent of the keyword order at the call.
        ordered = sorted(hash_kwargs.items(), key=lambda item: item[0])
        key = pickle.dumps((func_id, ordered), 1)
        hsh = md5hash(key)[:6]
        f = cache_dir / f"{hsh}.pkl"

        if f.exists():
            logger.info("loading hs from %s", f)
            # NOTE: unpickling runs code stored in the file; the cache
            # directory must only contain files written by this wrapper.
            with f.open("rb") as fh:
                return pickle.load(fh)

        res = func(**kwargs)
        logger.info("caching hs to %s", f)
        # Write to a temporary file and rename it into place so an
        # interrupted dump never leaves a truncated pickle at ``f``.
        tmp = f.with_name(f.name + ".tmp")
        with tmp.open("wb") as fh:
            pickle.dump(res, fh)
        tmp.replace(f)
        return res

    return wrap
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment