Skip to content

Instantly share code, notes, and snippets.

@ciaranchen
Last active December 28, 2021 03:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ciaranchen/d8a98c5ac773ec126668bffbfce3b007 to your computer and use it in GitHub Desktop.
Save ciaranchen/d8a98c5ac773ec126668bffbfce3b007 to your computer and use it in GitHub Desktop.
cachelib
import csv
import functools
import hashlib
import json
import logging
import os
import pickle
# Configure root logging when this module is imported, then create the
# module-level logger that the cache classes below log through.
logging.basicConfig(
    datefmt='%m-%d %H:%M:%S',
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Decorator factory: replace a positional input argument with a cached object.
def input_args_decorator(self, i, cached_obj):
    """Build a decorator that injects ``cached_obj`` into the positional args.

    ``self`` is unused; it is kept only so existing three-argument call sites
    keep working.

    Args:
        i: Index of the positional argument to overwrite with ``cached_obj``.
            The special value ``-1`` means "replace *all* positional
            arguments": ``cached_obj`` must then be an iterable whose items
            become the new positional arguments.
        cached_obj: The object to inject.

    Returns:
        A decorator. The wrapped function returns whatever the decorated
        function returns.

    Raises:
        IndexError: At call time, if ``i`` is not a valid index into the
            positional arguments actually passed.
    """
    def _cache_decorator(func):
        @functools.wraps(func)
        def res_func(*args, **kwargs):
            if i == -1:
                # -1 is a sentinel for "replace everything", not Python's
                # "last element" index.
                new_args = list(cached_obj)
            else:
                # *args arrives as a tuple, which does not support item
                # assignment; copy into a list before replacing.
                new_args = list(args)
                new_args[i] = cached_obj
            return func(*new_args, **kwargs)
        return res_func
    return _cache_decorator
# Decorator factory: replace a keyword input argument with a cached object.
def input_kwargs_decorator(self, key, cached_obj):
    """Build a decorator that injects ``cached_obj`` into the keyword args.

    ``self`` is unused; it is kept only so existing three-argument call sites
    keep working.

    Args:
        key: Name of the keyword argument to set to ``cached_obj``. If falsy
            (e.g. ``None`` or ``''``), ``cached_obj`` must be a mapping and it
            replaces the caller's keyword arguments entirely.
        cached_obj: The object (or mapping, see ``key``) to inject.

    Returns:
        A decorator. The wrapped function returns whatever the decorated
        function returns.
    """
    def _cache_decorator(func):
        @functools.wraps(func)
        def res_func(*args, **kwargs):
            if not key:
                new_kwargs = cached_obj
            else:
                # Overrides any value the caller passed for the same key.
                new_kwargs = dict(kwargs)
                new_kwargs[key] = cached_obj
            return func(*args, **new_kwargs)
        return res_func
    return _cache_decorator
# Decorator factory: rewrite any input arguments via a user-supplied function.
def input_lambda_decorator(self, l_func, cache_obj):
    """Build a decorator that rewrites the call's arguments with ``l_func``.

    ``self`` is unused; it is kept only so existing three-argument call sites
    keep working.

    Args:
        l_func: Called as ``l_func(cache_obj, *args, **kwargs)``; must return
            a ``(new_args, new_kwargs)`` pair used to call the decorated
            function.
        cache_obj: The object handed to ``l_func`` as its first argument.

    Returns:
        A decorator. The wrapped function returns whatever the decorated
        function returns.
    """
    def _cache_decorator(func):
        @functools.wraps(func)
        def res_func(*args, **kwargs):
            new_args, new_kwargs = l_func(cache_obj, *args, **kwargs)
            return func(*new_args, **new_kwargs)
        return res_func
    return _cache_decorator
def cache_wrapper(pc):
    """Return a decorator that routes every call through ``pc.load``.

    Args:
        pc: A cache object exposing ``load(func, *args, **kwargs)``, which
            returns the cached result or computes and stores it.

    Returns:
        A decorator. ``functools.wraps`` preserves the decorated function's
        ``__name__``; cache ids are derived from ``func.__name__``, so without
        it stacked decorators would all be named ``res_func`` and collide.
    """
    def _cache_decorator(func):
        @functools.wraps(func)
        def res_func(*args, **kwargs):
            return pc.load(func, *args, **kwargs)
        return res_func
    return _cache_decorator
def dump_decorator(pc):
    """Return a decorator that routes every call through ``pc.dump``.

    Unlike ``cache_wrapper``, this calls ``dump`` rather than ``load``.

    Args:
        pc: A cache object exposing ``dump(func, *args, **kwargs)``.

    Returns:
        A decorator whose wrapper returns what ``pc.dump`` returns and keeps
        the decorated function's metadata (``__name__``, ``__doc__``).
    """
    def _cache_decorator(func):
        @functools.wraps(func)
        def res_func(*args, **kwargs):
            return pc.dump(func, *args, **kwargs)
        return res_func
    return _cache_decorator
class PickleCacheDir(object):
    """Disk cache that pickles each distinct call's result into its own file.

    Each (function name, positional args, keyword args) combination becomes
    a textual id (see ``get_id_name``). The id is MD5-hashed and the result
    is stored as ``<cache_dir>/<md5>.pickle``.

    NOTE(review): ``pickle.load`` can execute arbitrary code, so only use
    cache directories whose contents you trust.
    """

    @classmethod
    def cache_decorator(cls, *args, **kwargs):
        """Return a decorator that caches through a new ``cls`` instance.

        Arguments are forwarded to the constructor. Building ``cls`` rather
        than a hard-coded class lets subclasses inherit a decorator that
        produces an instance of the subclass.
        """
        return cache_wrapper(cls(*args, **kwargs))

    def __init__(self, cache_dir='cache'):
        """Set up id separators and create the cache directory if missing.

        Args:
            cache_dir: Sub-directory name, always nested under ``cache`` (so
                the default resolves to ``cache/cache``). A non-string value,
                including None, falls back to plain ``cache``.
        """
        self.split_args = '-'     # between positional args in the id
        self.split_kwargs = '-'   # between key:value pairs in the id
        self.split_kv = ':'       # between a keyword and its value
        self.split_char = '@'     # between name, args part and kwargs part
        if isinstance(cache_dir, str):
            cache_dir = 'cache' + os.sep + cache_dir
        else:
            cache_dir = 'cache'
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

    def get_id_name(self, func, *args, **kwargs):
        """Return a readable id for one call: ``name@a1-a2@k1:v1-k2:v2``.

        Keyword arguments are sorted by name so ``f(a=1, b=2)`` and
        ``f(b=2, a=1)`` share one cache entry.

        NOTE(review): values are rendered with ``str()``, so ``f(1)`` and
        ``f('1')`` map to the same id, as do ``f('a-b')`` and ``f('a', 'b')``.
        """
        args_strings = self.split_args.join(str(a) for a in args)
        kwargs_strings = self.split_kwargs.join(
            str(k) + self.split_kv + str(v)
            for k, v in sorted(kwargs.items()))
        return self.split_char.join(
            [func.__name__, args_strings, kwargs_strings])

    def get_filename(self, func, *args, **kwargs):
        """Return the cache file path: ``<cache_dir>/<md5 of id>.pickle``."""
        final_id = self.get_id_name(func, *args, **kwargs)
        digest = hashlib.md5(final_id.encode("utf-8")).hexdigest()
        return os.path.join(self.cache_dir, digest + '.pickle')

    def _dump(self, filename, result):
        """Pickle ``result`` into ``filename`` atomically.

        The data is written to a sibling temp file and then moved into place
        with ``os.replace``. A pickling error or interruption therefore never
        leaves a truncated file that ``load`` would later find and fail on.
        """
        tmp_name = '{}.{}.tmp'.format(filename, os.getpid())
        try:
            with open(tmp_name, 'wb') as fp:
                pickle.dump(result, fp)
            os.replace(tmp_name, filename)
        finally:
            # After a successful replace the temp file is already gone.
            if os.path.exists(tmp_name):
                os.remove(tmp_name)

    def _load(self, filename):
        """Unpickle and return the object stored in ``filename``."""
        with open(filename, 'rb') as fp:
            return pickle.load(fp)

    def dump(self, func, *args, **kwargs):
        """Call ``func``, store its result in the cache, and return it."""
        filename = self.get_filename(func, *args, **kwargs)
        logger.info('dump %s to %s',
                    self.get_id_name(func, *args, **kwargs), filename)
        result = func(*args, **kwargs)
        self._dump(filename, result)
        return result

    def load(self, func, *args, **kwargs):
        """Return the cached result for this call, computing it on a miss."""
        filename = self.get_filename(func, *args, **kwargs)
        if not os.path.exists(filename):
            return self.dump(func, *args, **kwargs)
        logger.info('Loading %s from %s',
                    self.get_id_name(func, *args, **kwargs), filename)
        return self._load(filename)
class PickleCacheFile(PickleCacheDir):
    """Pickle cache that stores a function's result in one fixed file.

    ``get_filename`` ignores the call arguments. Every call of the decorated
    function therefore shares the single file at ``self.path``: once that
    file exists, later calls return the stored result regardless of the
    arguments passed.
    """

    @classmethod
    def cache_decorator(cls, *args, **kwargs):
        """Return a decorator that caches through a new ``cls`` instance.

        Building ``cls`` rather than a hard-coded PickleCacheFile means
        subclasses (e.g. CSV/JSON caches) get a decorator of their own type.
        """
        return cache_wrapper(cls(*args, **kwargs))

    def __init__(self, path):
        # PickleCacheDir.__init__ is not called: no cache directory or id
        # separators are needed because get_filename/get_id_name are
        # overridden below.
        self.path = path

    def get_filename(self, func, *args, **kwargs):
        """Return the fixed cache path, whatever the call arguments."""
        return self.path

    def get_id_name(self, func, *args, **kwargs):
        """Return just the function name (used only for log messages)."""
        return func.__name__

    def dump(self, func, *args, **kwargs):
        """Compute and store the result, creating the parent directory first.

        Nothing else creates the directory for a path such as
        ``'cache/data.pickle'``. Without this, the first write would fail
        with FileNotFoundError.
        """
        parent = os.path.dirname(self.get_filename(func, *args, **kwargs))
        if parent:
            os.makedirs(parent, exist_ok=True)
        return super().dump(func, *args, **kwargs)
class CsvCacheFile(PickleCacheFile):
    """Cache a function's list-of-rows result in a single CSV file.

    Every call shares the one file at ``self.path`` (the filename does not
    depend on the call arguments). The csv module stringifies cells on
    write, so values always come back as strings on load.
    """

    @classmethod
    def cache_decorator(cls, *args, **kwargs):
        """Return a decorator that loads from / fills a CSV cache."""
        # Defined here so this class's decorator always builds a CSV cache,
        # independent of how the parent class's factory is written.
        return cache_wrapper(cls(*args, **kwargs))

    @classmethod
    def dump_decorator(cls, *args, **kwargs):
        """Return a decorator that recomputes and rewrites the CSV cache."""
        # `dump_decorator` here resolves to the module-level function: class
        # scope is not searched from inside a method body.
        return dump_decorator(cls(*args, **kwargs))

    def __init__(self, path, skip_header=False):
        """
        Args:
            path: CSV file path; ``.csv`` is appended if missing. A non-string
                value (including None) falls back to ``cache.csv``.
            skip_header: If true, ``_load`` discards the first row. ``_dump``
                writes no header itself, so with this set, the first row of
                the dumped list is the one dropped on load.
        """
        if not isinstance(path, str):
            path = 'cache.csv'
        if not path.endswith('.csv'):
            path += '.csv'
        self.path = path
        self.skip_header = skip_header

    def _dump(self, filename, result):
        """Write ``result`` (a list of row sequences) to ``filename``.

        Raises:
            TypeError: If ``result`` is not a list.
        """
        if not isinstance(result, list):
            # Format with {!r}: the original "..." + result raised an
            # unrelated TypeError for any non-str result.
            raise TypeError("should receive list obj: {!r}".format(result))
        # newline='' lets the csv module control line endings (otherwise
        # extra blank rows appear on Windows).
        with open(filename, 'w', newline='') as fp:
            csv.writer(fp).writerows(result)

    def _load(self, filename):
        """Read all rows of ``filename`` as lists of strings."""
        # csv.reader needs a text-mode file; 'rb' makes it fail on the first
        # row under Python 3.
        with open(filename, 'r', newline='') as fp:
            reader = csv.reader(fp)
            if self.skip_header:
                next(reader, None)  # tolerate an empty file
            return list(reader)
class JsonCacheFile(PickleCacheFile):
    """Cache a function's result as JSON in a single file.

    Every call shares the one file at ``self.path``. JSON round-trips are
    lossy for some Python types (e.g. tuples load as lists and dict keys load
    as strings), hence the warning for non-dict results.
    """

    @classmethod
    def cache_decorator(cls, *args, **kwargs):
        """Return a decorator that loads from / fills a JSON cache."""
        # Defined here so this class's decorator always builds a JSON cache,
        # independent of how the parent class's factory is written.
        return cache_wrapper(cls(*args, **kwargs))

    def __init__(self, path, skip_header=False):
        """
        Args:
            path: JSON file path; ``.json`` is appended if missing. A
                non-string value (including None) falls back to
                ``cache.json``.
            skip_header: Unused; accepted only for signature parity with
                CsvCacheFile.
        """
        if not isinstance(path, str):
            path = 'cache.json'
        if not path.endswith('.json'):
            path += '.json'
        self.path = path

    def get_id_name(self, func, *args, **kwargs):
        """Return just the function name (used only for log messages)."""
        return func.__name__

    def _dump(self, filename, result):
        """Serialize ``result`` as JSON into ``filename``."""
        if not isinstance(result, dict):
            # Lazy %-args: the original "..." + result raised TypeError for
            # non-str results, crashing exactly when it should only warn.
            logger.warning("maybe you should use PickleCacheFile: got %s",
                           type(result).__name__)
        with open(filename, 'w') as fp:
            json.dump(result, fp)

    def _load(self, filename):
        """Parse and return the JSON document stored in ``filename``."""
        # Binary mode is fine: json.load accepts bytes input (Python 3.6+).
        with open(filename, 'rb') as fp:
            return json.load(fp)
# Backward-compatible module-level aliases kept for older call sites:
# cache_with_filename -> single-file pickle cache factory,
# cache_decorator     -> per-call pickle directory cache factory.
cache_with_filename, cache_decorator = (
    PickleCacheFile.cache_decorator,
    PickleCacheDir.cache_decorator,
)
import os
import numpy as np
from cachelib import PickleCacheDir, PickleCacheFile, CsvCacheFile, JsonCacheFile
import logging
# Configure root logging for this script and create its module-level logger.
logging.basicConfig(
    datefmt='%m-%d %H:%M:%S',
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# TODO: code for numpy or pandas cache.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment