Skip to content

Instantly share code, notes, and snippets.

@austospumanto
Created August 3, 2019 02:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save austospumanto/9d1383362f957d7409dbdee63ccfda32 to your computer and use it in GitHub Desktop.
Save austospumanto/9d1383362f957d7409dbdee63ccfda32 to your computer and use it in GitHub Desktop.
Diskcached: Like functools.lru_cache, but pickled to disk
"""
from .diskcached import clear_all, diskcached
@diskcached()
def cached_fn(x):
    print(x)
    return x + 2
clear_all()
res1 = cached_fn(7)
res2 = cached_fn(70)
res3 = cached_fn(7)
res4 = cached_fn(70)
print((res1, res2, res3, res4))
Prints
--
7
70
(9, 72, 9, 72)
"""
import hashlib
import logging
import os
import shutil
from base64 import b64encode
from functools import wraps
from pathlib import Path
from typing import Callable, Hashable, Any
import pickle5
from .decorators import retry
from .timing import Timer
# Maps each decorated function's ``__qualname__`` to its caching wrapper so
# that ``clear_all`` can reset every registered cache in one call.
_REGISTRY = {}
def pdumps(thing):
    """Serialize ``thing`` to bytes with ``pickle5``, using pickle protocol 4."""
    serialized = pickle5.dumps(thing, protocol=4)
    return serialized
def clear_all():
    """Clear the cache of every function recorded in ``_REGISTRY``.

    Each wrapper's ``clear()`` runs inside a ``Timer`` context labelled with
    the name the wrapper was registered under.
    """
    for funcname in _REGISTRY:
        cached_wrapper = _REGISTRY[funcname]
        label = f"Clearing diskcached cache for {repr(funcname)}"
        with Timer(label):
            cached_wrapper.clear()
def get_cache_filename(funcname):
    """Return the cache name for ``funcname``: the name prefixed by ``diskcached:``."""
    cache_name = f"diskcached:{funcname}"
    return cache_name
def diskcached(
    keyfunc: Callable[..., Hashable] = lambda *a, **kw: a
    + tuple(sorted(kw.items())),
):
    """Build a decorator that memoizes a function in memory and on disk.

    Results are kept in a per-function dict (exposed as ``wrapper.thecache``)
    and pickled to ``/tmp/diskcached:<qualname>/<key>``, where ``<key>`` is
    the MD5 hex digest of the base64-encoded pickle of ``keyfunc``'s output.
    The wrapper is recorded in ``_REGISTRY`` under the function's
    ``__qualname__`` and gains a ``clear()`` attribute that drops both the
    in-memory entries and the on-disk directory.

    A result of ``None`` is cached like any other value.

    Args:
        keyfunc: Called with the wrapped function's arguments; its picklable
            return value identifies the cache entry. The default combines the
            positional arguments with the keyword items sorted by name.

    Returns:
        A decorator that wraps a function with the caching behaviour above.

    NOTE(review): entries are unpickled from a directory under ``/tmp``.
    Unpickling can execute arbitrary code, so only use this where that
    directory is trusted.
    """

    def decorator(func):
        funcname = func.__qualname__
        logger = logging.getLogger(f"diskcached::{funcname}")
        cache_dir = Path("/tmp") / get_cache_filename(funcname)

        def add(key, value):
            """Store ``value`` under ``key`` in memory, then on disk."""
            wrapper.thecache[key] = value
            # The directory is created lazily: ``clear()`` removes it.
            os.makedirs(str(cache_dir), exist_ok=True)
            (cache_dir / key).write_bytes(pdumps(value))
            return value

        # ``retry`` comes from ``.decorators``; its exact policy is not visible
        # here. A miss is reported as ``None`` and a hit as a 1-tuple, so a
        # cached ``None`` value is still distinguishable from "no entry".
        @retry(delay=0.1, handled_exceptions=(Exception,))
        def get(key):
            """Return ``(value,)`` for a cached ``key``, or ``None`` on a miss."""
            if key in wrapper.thecache:
                return (wrapper.thecache[key],)
            entry = cache_dir / key
            if not entry.exists():
                return None
            val = pickle5.loads(entry.read_bytes())
            # Remember disk hits so later calls skip the read and unpickle.
            wrapper.thecache[key] = val
            return (val,)

        @wraps(func)
        def wrapper(*args, **kwargs):
            items = keyfunc(*args, **kwargs)
            # The base64 step is kept so keys match files written previously.
            key = hashlib.md5(b64encode(pdumps(items))).hexdigest()
            hit = get(key)
            if hit is not None:
                return hit[0]
            try:
                return add(key, func(*args, **kwargs))
            except Exception as e:
                logger.exception(repr(e))
                raise

        def clear():
            """Drop every in-memory entry and delete the on-disk directory."""
            wrapper.thecache = {}
            if cache_dir.exists():
                shutil.rmtree(cache_dir)

        wrapper.thecache = {}
        wrapper.clear = clear
        _REGISTRY[funcname] = wrapper
        return wrapper

    return decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment