Last active December 28, 2021 03:03
Save ciaranchen/d8a98c5ac773ec126668bffbfce3b007 to your computer and use it in GitHub Desktop.
cachelib
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv
import functools
import hashlib
import io
import json
import logging
import os
import pickle
# Configure root logging at import time (INFO level, timestamped lines) and
# create this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# manipulate the list input arguements | |
def input_args_decorator(self, i, cached_obj): | |
def _cache_decorator(func): | |
def res_func(*args, **kwargs): | |
# FIXME: 强制要求args的长度大于 i | |
# 当i为-1时,直接将整个args作为cached_obj | |
if i == -1: | |
args = cached_obj | |
else: | |
args[i] = cached_obj | |
func(*args, **kwargs) | |
return res_func | |
return _cache_decorator | |
# manipulate the keywords input arguements | |
def input_kwargs_decorator(self, key, cached_obj): | |
def _cache_decorator(func): | |
def res_func(*args, **kwargs): | |
# 当key为None时,将整个cached_obj作为kwargs | |
if not key: | |
kwargs = cached_obj | |
else: | |
kwargs[key] = cached_obj | |
func(*args, **kwargs) | |
return res_func | |
return _cache_decorator | |
# manipulate any input arguements with lambda. | |
def input_lambda_decorator(self, l_func, cache_obj): | |
def _cache_decorator(func): | |
def res_func(*args, **kwargs): | |
args, kwargs = l_func(cache_obj, *args, **kwargs) | |
func(*args, **kwargs) | |
return res_func | |
return _cache_decorator | |
def cache_wrapper(pc):
    """Return a decorator that routes every call through ``pc.load``.

    :param pc: a cache object exposing ``load(func, *args, **kwargs)``
        (e.g. ``PickleCacheDir``); it decides whether to read a stored result
        or compute and store a new one.
    :returns: a decorator; the decorated function returns ``pc.load``'s result.
        ``functools.wraps`` keeps the original ``__name__``/``__doc__``.
    """
    def _cache_decorator(func):
        @functools.wraps(func)
        def res_func(*args, **kwargs):
            return pc.load(func, *args, **kwargs)
        return res_func
    return _cache_decorator
def dump_decorator(pc):
    """Return a decorator that routes every call through ``pc.dump``.

    :param pc: a cache object exposing ``dump(func, *args, **kwargs)``; in this
        module that always recomputes the result and writes it to the cache.
    :returns: a decorator; the decorated function returns ``pc.dump``'s result.
        ``functools.wraps`` keeps the original ``__name__``/``__doc__``.
    """
    def _cache_decorator(func):
        @functools.wraps(func)
        def res_func(*args, **kwargs):
            return pc.dump(func, *args, **kwargs)
        return res_func
    return _cache_decorator
class PickleCacheDir(object):
    """Cache function results as pickle files, one file per call signature.

    Each ``(function name, args, kwargs)`` combination is turned into an id
    string (see ``get_id_name``), hashed with MD5, and stored as
    ``<cache_dir>/<md5>.pickle``.

    NOTE(review): the id is built from ``str()`` of each argument joined by
    ``-``. Calls such as ``f('a-b')`` and ``f('a', 'b')`` therefore share one
    file, and the same kwargs given in a different order get different files.
    The format is left unchanged so existing cache files stay valid.

    Only load cache directories you trust: ``pickle.load`` can execute
    arbitrary code embedded in a crafted file.
    """

    @staticmethod
    def cache_decorator(*args, **kwargs):
        """Return a decorator caching through ``PickleCacheDir(*args, **kwargs)``."""
        return cache_wrapper(PickleCacheDir(*args, **kwargs))

    def __init__(self, cache_dir='cache'):
        """Create (if needed) the directory ``cache/<cache_dir>``.

        When *cache_dir* is not a str (e.g. ``None``) plain ``cache`` is used.
        Note that the default argument therefore resolves to ``cache/cache``.
        """
        # Separators used by get_id_name to build the per-call id string.
        self.split_args = '-'
        self.split_kwargs = '-'
        self.split_kv = ':'
        self.split_char = '@'
        if isinstance(cache_dir, str):
            cache_dir = 'cache' + os.sep + cache_dir
        else:
            cache_dir = 'cache'
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

    def get_id_name(self, func, *args, **kwargs):
        """Return ``name@arg1-arg2@k1:v1-k2:v2`` for this call (before hashing)."""
        args_strings = self.split_args.join(str(a) for a in args)
        kwargs_strings = self.split_kwargs.join(
            str(k) + self.split_kv + str(v) for k, v in kwargs.items())
        return self.split_char.join([func.__name__, args_strings, kwargs_strings])

    def get_filename(self, func, *args, **kwargs):
        """Return the cache path: the MD5 hex digest of the call id plus ``.pickle``."""
        final_id = self.get_id_name(func, *args, **kwargs)
        digest = hashlib.md5(final_id.encode("utf-8")).hexdigest()
        return os.path.join(self.cache_dir, digest + '.pickle')

    def _dump(self, filename, result):
        """Pickle *result* into *filename*.

        The data is first written to ``<filename>.tmp`` and then moved into
        place with ``os.replace``. A failed pickling (e.g. an unpicklable
        result) therefore never leaves a truncated cache file that ``load``
        would later try to read.
        """
        tmp_name = os.fspath(filename) + '.tmp'
        try:
            with open(tmp_name, 'wb') as fp:
                pickle.dump(result, fp)
            os.replace(tmp_name, filename)
        except BaseException:
            # Best-effort cleanup of the partial temp file; re-raise the real error.
            try:
                os.remove(tmp_name)
            except OSError:
                pass
            raise

    def _load(self, filename):
        """Unpickle and return the object stored in *filename*."""
        with open(filename, 'rb') as fp:
            return pickle.load(fp)

    def dump(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)``, store the result, and return it."""
        filename = self.get_filename(func, *args, **kwargs)
        logger.info('dump %s to %s',
                    self.get_id_name(func, *args, **kwargs), filename)
        result = func(*args, **kwargs)
        self._dump(filename, result)
        return result

    def load(self, func, *args, **kwargs):
        """Return the cached result for this call, computing it via ``dump`` on a miss."""
        filename = self.get_filename(func, *args, **kwargs)
        if not os.path.exists(filename):
            return self.dump(func, *args, **kwargs)
        logger.info('Loading %s from %s',
                    self.get_id_name(func, *args, **kwargs), filename)
        return self._load(filename)
class PickleCacheFile(PickleCacheDir):
    """Cache a function's result in one fixed pickle file.

    Unlike ``PickleCacheDir``, the call arguments do not affect the file
    name: whatever was stored at *path* is returned for every later call.
    """

    @staticmethod
    def cache_decorator(*args, **kwargs):
        """Return a decorator caching through ``PickleCacheFile(*args, **kwargs)``."""
        cache = PickleCacheFile(*args, **kwargs)
        return cache_wrapper(cache)

    def __init__(self, path):
        # The parent __init__ is intentionally not called: no cache directory
        # is created, and the separator attributes are unused because both
        # get_id_name and get_filename are overridden below.
        self.path = path

    def get_id_name(self, func, *args, **kwargs):
        """Identify the entry by the function name only (used in log messages)."""
        return func.__name__

    def get_filename(self, func, *args, **kwargs):
        """Always return the fixed cache path, regardless of the arguments."""
        return self.path
class CsvCacheFile(PickleCacheFile):
    """Cache a function's list-of-rows result in a single CSV file.

    Rows read back from the file are lists of strings; the original cell
    value types are not preserved.
    """

    @staticmethod
    def cache_decorator(*args, **kwargs):
        """Return a decorator caching through ``CsvCacheFile(*args, **kwargs)``.

        Overrides the inherited version, which builds a ``PickleCacheFile``
        (writing pickle data and rejecting ``skip_header``).
        """
        return cache_wrapper(CsvCacheFile(*args, **kwargs))

    @staticmethod
    def dump_decorator(*args, **kwargs):
        """Return a decorator that always recomputes the result and rewrites the CSV."""
        # Inside a method, this name resolves to the module-level dump_decorator.
        return dump_decorator(CsvCacheFile(*args, **kwargs))

    def __init__(self, path, skip_header=False):
        """
        :param path: target file; ``.csv`` is appended if missing, and
            ``cache.csv`` is used when *path* is not a str.
        :param skip_header: if true, the first row in the file is dropped
            when loading.
        """
        if not isinstance(path, str):
            path = 'cache.csv'
        if not path.endswith('.csv'):
            path += '.csv'
        self.path = path
        self.skip_header = skip_header

    def _dump(self, filename, result):
        """Write *result* (a list of row iterables) to *filename* as CSV.

        :raises TypeError: if *result* is not a list.
        """
        if not isinstance(result, list):
            raise TypeError('should receive list obj: %r' % (result,))
        # Render in memory first so a bad row cannot leave a half-written file.
        buf = io.StringIO(newline='')
        csv.writer(buf).writerows(result)
        # newline='' as the csv module requires, so '\r\n' is not translated again.
        with open(filename, 'w', newline='') as fp:
            fp.write(buf.getvalue())

    def _load(self, filename):
        """Read *filename* and return its rows as lists of strings."""
        # csv.reader needs a text-mode file in Python 3 (binary mode fails).
        with open(filename, 'r', newline='') as fp:
            reader = csv.reader(fp)
            if self.skip_header:
                next(reader, None)  # tolerate an empty file
            return list(reader)
class JsonCacheFile(PickleCacheFile):
    """Cache a function's JSON-serialisable result in a single JSON file."""

    @staticmethod
    def cache_decorator(*args, **kwargs):
        """Return a decorator caching through ``JsonCacheFile(*args, **kwargs)``.

        Overrides the inherited version, which builds a ``PickleCacheFile``
        (writing pickle data and rejecting ``skip_header``).
        """
        return cache_wrapper(JsonCacheFile(*args, **kwargs))

    def __init__(self, path, skip_header=False):
        """
        :param path: target file; ``.json`` is appended if missing, and
            ``cache.json`` is used when *path* is not a str.
        :param skip_header: accepted for signature parity with
            ``CsvCacheFile``; unused here.
        """
        if not isinstance(path, str):
            path = 'cache.json'
        if not path.endswith('.json'):
            path += '.json'
        self.path = path

    def get_id_name(self, func, *args, **kwargs):
        """Identify the entry by the function name only (used in log messages)."""
        return func.__name__

    def _dump(self, filename, result):
        """Serialise *result* as JSON into *filename*.

        A non-dict result is still written but triggers a warning.
        """
        if not isinstance(result, dict):
            # Lazy %-formatting: string concatenation crashed for non-str results.
            logger.warning('maybe you should use PickleCacheFile: %s', result)
        # Serialise before opening the file so an unserialisable result
        # cannot leave a partial, unreadable cache file behind.
        text = json.dumps(result)
        with open(filename, 'w') as fp:
            fp.write(text)

    def _load(self, filename):
        """Parse and return the JSON stored in *filename*."""
        # Binary mode lets json auto-detect the UTF-8/16/32 encoding.
        with open(filename, 'rb') as fp:
            return json.load(fp)
# Backward-compatible module-level aliases kept for older code.
cache_decorator = PickleCacheDir.cache_decorator
cache_with_filename = PickleCacheFile.cache_decorator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import numpy as np | |
from cachelib import PickleCacheDir, PickleCacheFile, CsvCacheFile, JsonCacheFile | |
import logging | |
# Configure root logging at import time (INFO level, timestamped lines) and
# create this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# TODO: add cache support for numpy arrays and pandas objects.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.