Skip to content

Instantly share code, notes, and snippets.

@faustomorales
Last active October 28, 2022 19:27
Show Gist options
  • Save faustomorales/3aabc705228e68ed932f8ed3c67ee287 to your computer and use it in GitHub Desktop.
Save faustomorales/3aabc705228e68ed932f8ed3c67ee287 to your computer and use it in GitHub Desktop.
import functools
import hashlib
import json
import os

import numpy as np
import pandas as pd
# Taken from https://bobbyhadz.com/blog/python-typeerror-object-of-type-ndarray-is-not-json-serializable
class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy values to native Python types.

    Pass it as ``cls=NpEncoder`` to ``json.dumps`` so that NumPy scalars
    and arrays, which the stock encoder rejects, can be serialized.
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of ``obj``.

        Args:
            obj: A value the standard encoder could not serialize.

        Returns:
            ``bool``, ``int`` or ``float`` for the matching NumPy scalar
            kinds, or a (possibly nested) list for an ``np.ndarray``.

        Raises:
            TypeError: Raised by the base class for any other type.
        """
        # np.bool_ is neither a Python bool nor an np.integer, so without
        # this branch NumPy booleans fall through to the TypeError below.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
def CacheResult(cache_dir: str, **pandas_kwargs):
    """Build a decorator that caches a function's DataFrame result as CSV.

    The cache key is the MD5 digest of the JSON-encoded keyword arguments
    (sorted by key, encoded with ``NpEncoder``) followed by ``str(args)``.
    Each result is stored at ``<cache_dir>/<digest>.csv``. Calling the
    decorated function again with the same arguments loads that file with
    ``pd.read_csv`` instead of running the function. Passing
    ``overwrite_cache=True`` forces the function to run and the file to be
    rewritten.

    Example::

        import time
        import pandas as pd

        @CacheResult("slow-sequence", index=False)
        def get_slow_sequence_dataframe(length):
            time.sleep(5)
            return pd.DataFrame({"foo": list(range(length))})

        df1 = get_slow_sequence_dataframe(51)  # Takes 5 seconds.
        df2 = get_slow_sequence_dataframe(32)  # Takes 5 seconds.
        df3 = get_slow_sequence_dataframe(51)  # Very fast!
        df4 = get_slow_sequence_dataframe(51, overwrite_cache=True)  # Takes 5 seconds.

    Notes:
        * The key does not include the decorated function itself, so two
          functions sharing one ``cache_dir`` and called with identical
          arguments share cache entries. Use one directory per function.
        * A cache miss returns the function's DataFrame unchanged, while a
          cache hit returns whatever ``pd.read_csv`` parses with default
          options; the two can differ (for example, the index is written
          as a column unless ``index=False`` is passed).
        * Positional arguments are keyed by ``str(args)``, so arguments
          whose string form is not stable or not unique do not make
          reliable keys.

    Args:
        cache_dir: The directory in which to store dataframe files. It is
            created on first call if it does not exist.
        **pandas_kwargs: Forwarded to ``DataFrame.to_csv`` when writing.

    Returns:
        A decorator. The wrapped function accepts one extra keyword-only
        argument, ``overwrite_cache`` (default ``False``), which is neither
        forwarded to the function nor included in the cache key.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, overwrite_cache=False, **kwargs):
            os.makedirs(cache_dir, exist_ok=True)

            # Key derivation is kept exactly as before so that cache files
            # written by earlier versions are still found.
            h = hashlib.md5()
            h.update(json.dumps(kwargs, sort_keys=True, cls=NpEncoder).encode())
            h.update(str(args).encode())
            filepath = os.path.join(cache_dir, h.hexdigest() + ".csv")

            if os.path.isfile(filepath) and not overwrite_cache:
                return pd.read_csv(filepath)

            df = func(*args, **kwargs)

            # Write to a sibling temp file and move it into place so that a
            # failed or interrupted write never leaves a truncated CSV at
            # the final path, where it would later be read as a cache hit.
            tmp_path = f"{filepath}.{os.getpid()}.tmp"
            try:
                df.to_csv(tmp_path, **pandas_kwargs)
                os.replace(tmp_path, filepath)
            finally:
                # Only present here if the write or the move failed.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            return df

        return wrapped

    return decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment