Created
November 17, 2023 02:19
-
-
Save user01/45277dce95f70df472ef36e94b7347db to your computer and use it in GitHub Desktop.
Simple SQLite3 Local cache decorator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CC0 License
# usage:
# @cache_result("./some_database_path.sqlite")
# def expensive_function():
#     pass
import inspect | |
import logging | |
import hashlib | |
import sqlite3 | |
from functools import wraps | |
import pickle | |
from datetime import datetime | |
from typing import Any, List | |
# Number of attempts store_result makes when SQLite raises OperationalError.
RETRYS = 3

# Module-wide logger; all cache messages are emitted under the 'cache' name.
logger = logging.getLogger('cache')
def calculate_md5(data):
    """Return the MD5 digest of ``data`` as a lowercase hexadecimal string.

    ``data`` must be a bytes-like object (``str`` has to be encoded first).
    """
    return hashlib.md5(data).hexdigest()
class SQLiteCache:
    """Store pickled function results in a single-table SQLite database.

    Each row of ``function_results`` describes one cached call: a
    caller-supplied unique ``id``, the pickled ``result``, the ``timestamp``
    at which it was produced and the ``duration`` of the original call.

    NOTE: rows are deserialised with ``pickle.loads``. Unpickling can execute
    arbitrary code, so only open database files you trust.
    """

    def __init__(self, database_path):
        """Open (or create) the database and make sure the table exists.

        Args:
            database_path: Path handed to ``sqlite3.connect``.
        """
        self.database_path = database_path
        self.conn = sqlite3.connect(self.database_path)
        self.conn.execute(
            """CREATE TABLE IF NOT EXISTS function_results
            (id TEXT PRIMARY KEY, result BLOB, timestamp TEXT, duration REAL)"""
        )

    def cached_count(self) -> int:
        """Return the number of rows currently stored in the cache table."""
        cursor = self.conn.execute("SELECT COUNT(*) FROM function_results")
        return cursor.fetchone()[0]

    def close(self):
        """Close the underlying connection; safe to call more than once."""
        # ``conn`` is missing when ``sqlite3.connect`` raised inside __init__.
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()

    def __del__(self):
        # Delegate to close() so a half-initialised instance does not raise
        # AttributeError during garbage collection.
        self.close()

    def store_result(self, unique_id, pickled_result, timestamp, duration):
        """Insert one cached result, retrying on ``sqlite3.OperationalError``.

        Args:
            unique_id: Primary key of the row.
            pickled_result: Bytes produced by ``pickle.dumps``.
            timestamp: Text describing when the result was produced.
            duration: Run time of the original call, stored as REAL.

        Raises:
            sqlite3.OperationalError: If every attempt failed.
            sqlite3.IntegrityError: If ``unique_id`` is already stored.
        """
        last_error = None
        # Always try at least once so a non-positive RETRYS cannot end in
        # ``raise None``.
        for _ in range(max(1, RETRYS)):
            try:
                # The connection context manager commits on success and rolls
                # back on failure, so a failed INSERT never leaves an open
                # transaction behind.
                with self.conn:
                    self.conn.execute(
                        "INSERT INTO function_results (id, result, timestamp, duration) VALUES (?, ?, ?, ?)",
                        (unique_id, sqlite3.Binary(pickled_result), timestamp, duration),
                    )
                return
            except sqlite3.OperationalError as why:
                last_error = why
        raise last_error

    def retrieve_result(self, unique_id):
        """Look up one cached call.

        Returns:
            ``(result, timestamp, duration)`` for the row whose id is
            ``unique_id``, or ``(None, None, None)`` when no such row exists.
        """
        cursor = self.conn.execute(
            "SELECT result, timestamp, duration FROM function_results WHERE id = ?",
            (unique_id,),
        )
        row = cursor.fetchone()
        if row:
            # See the class docstring: pickle.loads on untrusted data is unsafe.
            return pickle.loads(row[0]), row[1], row[2]
        return None, None, None

    def retrieve_results(self):
        """Return every cached call as a list of ``(result, timestamp, duration)``."""
        cursor = self.conn.execute(
            "SELECT result, timestamp, duration FROM function_results"
        )
        return [
            (pickle.loads(blob), timestamp, duration)
            for blob, timestamp, duration in cursor
        ]
def cache_result(database_path, verbose: bool = False):
    """Build a decorator that caches a function's results in SQLite.

    The cache key combines the function name, an MD5 of its source code and an
    MD5 of the pickled ``(args, kwargs)``, so editing the function or calling
    it with different arguments yields a different entry. Arguments and
    results must therefore be picklable.

    Args:
        database_path: Path of the SQLite file backing the cache.
        verbose: When true, log cache hits, misses and skipped writes.

    Returns:
        A decorator. The wrapped function gains two attributes:
        ``call_count()`` (number of rows in the whole cache table) and
        ``all_results()`` (every stored ``(result, timestamp, duration)``).
    """
    cache = SQLiteCache(database_path)
    if verbose:
        logger.info(f"Starting with {cache.cached_count():,} cache items")

    def decorator(func):
        function_name = func.__name__
        func_source = inspect.getsource(func)
        func_md5 = calculate_md5(func_source.encode())

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Fingerprint the call from its pickled arguments.
            arguments = pickle.dumps((args, kwargs))
            unique_id = f"{function_name}__{func_md5}__{calculate_md5(arguments)}"

            result, timestamp, duration = cache.retrieve_result(unique_id)
            # A timestamp is written with every stored row, so it marks a hit
            # even when the cached value itself is None.
            if timestamp is not None:
                if verbose:
                    logger.info(
                        f"Cache hit. Original at {timestamp} took {duration:4.2f} seconds"
                    )
                return result

            # Cache miss: run the function and time it.
            start_time = datetime.now()
            result = func(*args, **kwargs)
            duration = (datetime.now() - start_time).total_seconds()

            pickled_result = pickle.dumps(result)
            timestamp = str(datetime.now())
            try:
                cache.store_result(unique_id, pickled_result, timestamp, duration)
            except sqlite3.IntegrityError:
                # The table's only constraint is the primary key, so this means
                # the id was stored in the meantime; the fresh result is still
                # returned to the caller.
                if verbose:
                    logger.warning("Attempting to write existing result into database. Skipping write")
            if verbose:
                logger.info(
                    f"Cache miss. Run at {timestamp} took {duration:4.2f} seconds"
                )
            return result

        def call_count() -> int:
            # Despite the name, this is the row count of the whole cache table.
            return cache.cached_count()

        wrapper.call_count = call_count

        def all_results() -> List[Any]:
            return cache.retrieve_results()

        wrapper.all_results = all_results
        return wrapper

    return decorator
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment