Skip to content

Instantly share code, notes, and snippets.

@user01
Created November 17, 2023 02:19
Show Gist options
  • Save user01/45277dce95f70df472ef36e94b7347db to your computer and use it in GitHub Desktop.
Save user01/45277dce95f70df472ef36e94b7347db to your computer and use it in GitHub Desktop.
Simple SQLite3 Local cache decorator
# CC0 License
# Usage:
# @cache_result("./some_database_path.sqlite")
# def expensive_function():
#     pass
import inspect
import logging
import hashlib
import sqlite3
from functools import wraps
import pickle
from datetime import datetime
from typing import Any, List
# Shared logger for this module; records are emitted under the name "cache".
logger = logging.getLogger("cache")

# Upper bound on how many times SQLiteCache.store_result attempts its INSERT.
RETRYS = 3
def calculate_md5(data):
    """Return the MD5 digest of *data* (a bytes-like object) as a hex string."""
    return hashlib.md5(data).hexdigest()
class SQLiteCache:
    """Store pickled results in a single SQLite table.

    Each row of ``function_results`` holds a caller-supplied text id (the
    primary key), the pickled result as a BLOB, a timestamp string and a
    duration in seconds.

    NOTE: results are read back with ``pickle.loads``, which can execute
    arbitrary code. Only open database files that you trust.
    """

    def __init__(self, database_path):
        """Open (or create) the database at *database_path* and ensure the table exists."""
        self.database_path = database_path
        self.conn = sqlite3.connect(self.database_path)
        self.conn.execute(
            """CREATE TABLE IF NOT EXISTS function_results
            (id TEXT PRIMARY KEY, result BLOB, timestamp TEXT, duration REAL)"""
        )
        # Harmless when nothing is pending; makes sure the DDL is persisted.
        self.conn.commit()

    def cached_count(self) -> int:
        """Return the number of rows currently stored in the cache table."""
        cursor = self.conn.execute("SELECT COUNT(*) FROM function_results")
        row = cursor.fetchone()
        return row[0]

    def __del__(self):
        """Close the connection, if one was ever opened."""
        # ``conn`` is missing when sqlite3.connect() raised inside __init__
        # (or when __init__ was bypassed); closing blindly would raise here.
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()

    def store_result(self, unique_id, pickled_result, timestamp, duration):
        """Insert one row, retrying on ``sqlite3.OperationalError``.

        Up to ``RETRYS`` attempts are made (always at least one).

        Raises:
            sqlite3.OperationalError: if every attempt failed with it; the
                error from the last attempt is raised.
            sqlite3.IntegrityError: if a row with *unique_id* already exists.
        """
        last_error = None
        # max(1, ...) guarantees one attempt, so ``last_error`` is never None
        # when the loop finishes without returning.
        for _ in range(max(1, RETRYS)):
            try:
                # The connection context manager commits on success and rolls
                # back on any exception, so a failed INSERT does not leave a
                # transaction open.
                with self.conn:
                    self.conn.execute(
                        "INSERT INTO function_results (id, result, timestamp, duration) VALUES (?, ?, ?, ?)",
                        (unique_id, sqlite3.Binary(pickled_result), timestamp, duration),
                    )
                return
            except sqlite3.OperationalError as why:
                last_error = why
        raise last_error

    def retrieve_result(self, unique_id):
        """Look up one row by id.

        Returns:
            ``(result, timestamp, duration)`` with the result unpickled, or
            ``(None, None, None)`` when no row has that id.
        """
        cursor = self.conn.execute(
            "SELECT result, timestamp, duration FROM function_results WHERE id = ?",
            (unique_id,),
        )
        row = cursor.fetchone()
        if row:
            return pickle.loads(row[0]), row[1], row[2]
        return None, None, None

    def retrieve_results(self):
        """Return every stored row as a list of ``(result, timestamp, duration)`` tuples."""
        cursor = self.conn.execute(
            "SELECT result, timestamp, duration FROM function_results"
        )
        rows = cursor.fetchall()
        return [(pickle.loads(row[0]), row[1], row[2]) for row in rows]
def cache_result(database_path, verbose: bool = False):
    """Build a decorator that caches a function's results in an SQLite file.

    The cache key combines the function's name, an MD5 of its source code and
    an MD5 of the pickled ``(args, kwargs)``. Arguments and return values must
    therefore be picklable, and the function's source must be retrievable via
    ``inspect.getsource``.

    Args:
        database_path: Path handed to ``SQLiteCache`` for the backing database.
        verbose: When true, log cache hits, misses and skipped writes.

    Returns:
        A decorator. The wrapped function also exposes ``call_count()`` (the
        number of rows in the cache table) and ``all_results()`` (every stored
        ``(result, timestamp, duration)`` tuple).
    """
    cache = SQLiteCache(database_path)
    if verbose:
        logger.info(f"Starting with {cache.cached_count():,} cache items")

    def decorator(func):
        function_name = func.__name__
        # Hashing the source text makes the key change when the function's
        # code changes, so entries written by an older version are not reused.
        func_source = inspect.getsource(func)
        func_md5 = calculate_md5(func_source.encode())

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Build a unique identifier for this exact call.
            arguments = pickle.dumps((args, kwargs))
            unique_id = f"{function_name}__{func_md5}__{calculate_md5(arguments)}"

            result, timestamp, duration = cache.retrieve_result(unique_id)
            # A stored row always carries the timestamp written below, so also
            # test it: checking the result alone would treat a cached ``None``
            # as a miss and re-run the function on every call.
            if result is not None or timestamp is not None:
                if verbose:
                    logger.info(
                        f"Cache hit. Original at {timestamp} took {duration:4.2f} seconds"
                    )
                return result

            # Not cached: run the function and time it.
            start_time = datetime.now()
            result = func(*args, **kwargs)
            duration = (datetime.now() - start_time).total_seconds()

            pickled_result = pickle.dumps(result)
            timestamp = str(datetime.now())
            try:
                cache.store_result(unique_id, pickled_result, timestamp, duration)
            except sqlite3.IntegrityError as exception:
                # Only a duplicate key (e.g. another writer stored the same
                # call first) is safe to ignore; anything else is re-raised.
                if "unique" not in str(exception).lower():
                    raise
                if verbose:
                    logger.warning("Attempting to write existing result into database. Skipping write")
            if verbose:
                logger.info(
                    f"Cache miss. Run at {timestamp} took {duration:4.2f} seconds"
                )
            return result

        def call_count() -> int:
            """Return the number of rows in the cache table."""
            return cache.cached_count()

        wrapper.call_count = call_count

        def all_results() -> List[Any]:
            """Return every stored ``(result, timestamp, duration)`` tuple."""
            return cache.retrieve_results()

        wrapper.all_results = all_results
        return wrapper

    return decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment