Skip to content

Instantly share code, notes, and snippets.

@kcsaff
Last active August 29, 2015 14:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kcsaff/eced956b92698c7b98c0 to your computer and use it in GitHub Desktop.
Save kcsaff/eced956b92698c7b98c0 to your computer and use it in GitHub Desktop.
A simple Python cache class that lets you take care to cache only valid data.
"""
Simple caching.
This module contains a `Cache` class that provides simple caching: a cached value is invalidated either by a
timeout or by an exception raised while the user attempts to apply the value.
"""
import sqlite3
import time
import threading
from contextlib import closing, contextmanager
from queue import Queue
class Cache(object):
    """SQLite-backed cache whose entries expire after a timeout or when the user rejects them.

    Values are stored in one table of a sqlite database, keyed by a string. A stored value is reused
    until it is older than the timeout period. If the user rejects a value (by raising an exception
    inside a `context` block) then a freshly calculated value will not be stored and an already cached
    value is marked invalid, so it will be recalculated on the next attempt.

    Limitation: a value that deserializes to `None` cannot be told apart from a cache miss, so such a
    value is recalculated on every `get` / `context` call.

    For our example, let's set up a temporary directory really quick.

    >>> import tempfile
    >>> tempdir = tempfile.TemporaryDirectory()
    >>> dirname = tempdir.name

    We'll mock out the time, and our cached function will return a string indicating the time it was cached.

    >>> current_time = 10000
    >>> def cached_function():
    ...     return 'cached at {0}'.format(current_time)
    >>>
    >>> cache = Cache(dirname + '/temp.db', 'example_table', 1000, time_function=(lambda: current_time))

    We need to provide any calculated value with a unique string key that determines the value: when we run this
    there's no value stored in the database yet so it will run the calculation and return the new value.

    >>> cache.get('cached_function', cached_function)
    'cached at 10000'

    If we advance time less than the timeout period, re-running the cached function won't actually recalculate
    the value.

    >>> current_time = 10500
    >>> cache.get('cached_function', cached_function)
    'cached at 10000'

    If we advance time beyond the timeout period, the value must be recalculated.

    >>> current_time = 15000
    >>> cache.get('cached_function', cached_function)
    'cached at 15000'

    Now let's look at the context version of caching. The only difference here is that
    the value will only be cached if there is no exception thrown during the context.
    The following mirror the steps taken above.

    >>> current_time = 50000
    >>> with cache.context('cached_function', cached_function) as value:
    ...     value
    'cached at 50000'
    >>> current_time = 50500
    >>> with cache.context('cached_function', cached_function) as value:
    ...     value
    'cached at 50000'
    >>> current_time = 55000
    >>> with cache.context('cached_function', cached_function) as value:
    ...     value
    'cached at 55000'

    However if we throw an exception in the context, even though the newest calculated
    value was returned, this value will _not_ be cached. This allows us to use the cache
    in such a way to guarantee that only valid data is stored in the cache.

    >>> current_time = 60000
    >>> try:
    ...     with cache.context('cached_function', cached_function) as value:
    ...         value
    ...         raise Exception('Raising an exception prevents the value from being cached.')
    ... except Exception:
    ...     pass
    'cached at 60000'

    We saw the new value inside the context, but since we threw an error that value won't be cached.

    >>> cache.get_cached('cached_function')
    'cached at 55000'

    We could instead invalidate cached values manually if we got them using `get` instead of using a context.

    >>> current_time = 70000
    >>> cache.get('cached_function', cached_function)
    'cached at 70000'
    >>> cache.invalidate('cached_function')
    >>> cache.get_cached('cached_function') is None
    True
    """

    def __init__(self, path, table, timeout, serialize=repr, deserialize=eval, deserialize_kwargs=None,
                 time_function=time.time):
        """Create the cache.

        No database access happens here; the table is created lazily on first use.

        :param path: Filename to store the sqlite database.
        :param table: Table to store this cache under in the database. The name is interpolated directly
            into the SQL text, so it must be a trusted, valid SQL identifier.
        :param timeout: Timeout (in seconds, if the time function returns seconds.)
        :param serialize: Function used to serialize values into the database. `repr` by default.
        :param deserialize: Function used to deserialize values from the database. `eval` by default.
            SECURITY: `eval` executes whatever text is stored in the database, so only use the default with
            database files you trust; otherwise pass a safer `serialize` / `deserialize` pair.
        :param deserialize_kwargs: keyword arguments given to `deserialize` on every call. This can be used to
            provide `**{'locals': ...}` to `eval` if the `repr` of the value contains names that must be
            provided (the built-in `eval` only accepts keyword arguments on Python 3.13+).
        :param time_function: Returns a float representing the current time -- this is just time.time() by default.
        """
        self._lock = threading.RLock()
        self._path = path
        self._table = table
        self._timeout = timeout
        self._serialize = serialize
        self._deserialize = deserialize
        self._deserialize_kwargs = deserialize_kwargs or {}
        self._time_function = time_function
        self._table_known_created = False
        # Pool of open connections; only set while at least one `connected()` context is active.
        self._connection_pool = None
        self._connection_count = 0

    @contextmanager
    def connected(self, connections=1):
        """Keep connections open to the sqlite database instead of opening on each operation.

        This could improve performance in certain cases. If you want to keep connections open
        and are accessing this with multiple threads, pass a parameter indicating the number of
        connections to open. Threads will use connections as necessary from the connection pool.
        All connections opened by this context will be closed at the end of this context; leaving
        the context waits until borrowed connections have been returned to the pool.

        :param connections: indicates the number of connections to make available in a pool.
        """
        with self._lock:
            self._connection_count += connections
            if self._connection_pool is None:
                self._connection_pool = Queue()
            pool = self._connection_pool
        opened = 0
        try:
            for _ in range(connections):
                # Pooled connections are handed to whichever thread asks for one, so sqlite3's
                # same-thread check must be disabled. The queue guarantees that a connection is
                # only ever used by one thread at a time.
                pool.put(sqlite3.connect(self._path, check_same_thread=False))
                opened += 1
            yield self
        finally:
            try:
                # Only close as many connections as were actually opened, so a failure part-way
                # through opening does not leak connections or block forever on an empty pool.
                for _ in range(opened):
                    pool.get().close()
            finally:
                with self._lock:
                    self._connection_count -= connections
                    if self._connection_count <= 0:
                        self._connection_count = 0
                        self._connection_pool = None

    def get(self, key, calculate):
        """Get cached value if available & within timeout period; otherwise recalculate & cache the value.

        :param key: ID to store the value in the cache.
        :param calculate: function taking 0 parameters that can calculate the desired value.
        """
        with self.context(key, calculate) as value:
            return value

    @contextmanager
    def context(self, key, calculate):
        """Use a context to determine whether to cache a calculated value.

        :param key: ID to store the value in the cache.
        :param calculate: function taking 0 parameters that can calculate the desired value.

        This provides the cached or calculated value as the `as` term for a `with` statement.
        If an exception is thrown in the context, then the value will be invalidated or otherwise
        not stored in the cache. The exception itself still propagates to the caller.

        For example, imagine you are retrieving some JSON through an http API. Then you could use:

            with cache.context("my_json_key", (lambda: requests.get(url).json())) as json_data:
                if json_data.get('error'):
                    raise Exception("Don't store this data, we will need to try again later.")
                else:
                    use_the_json_data(json_data)
        """
        # Find cached value if possible
        cached_value = self.get_cached(key, check_time=True)
        if cached_value is not None:
            try:
                yield cached_value
            except BaseException:
                # Any exception in the `with` body counts as a rejection of the cached value.
                self.invalidate(key)
                raise
        else:
            calculated_value = calculate()
            yield calculated_value
            # If we get here without an error, we can store this in the db
            self.set_cached(key, calculated_value)

    def get_cached(self, key, check_time=False, check_valid=True):
        """Returns the cached value (if available, else `None`).

        You can use this if you want a version of your routine that will only use whatever cached values are
        available, for instance an "offline mode". In that case, leave `check_time=False` so you get the most
        recent value even if it's older than the timeout period.
        If you want to retrieve even invalid values from the database then you can set `check_valid=False`.

        A stored value that cannot be deserialized is treated as missing.

        :param key: ID where value is stored in the cache.
        :param check_time: Indicate whether to only return cached value if it's still fresh.
        :param check_valid: Indicate whether to only return cached value if it's valid.
        """
        cached_value = None
        with self._conn_cursor() as (conn, cursor):
            self._create_table(conn, cursor)
            cursor.execute('SELECT value, time, valid FROM {table} WHERE key=?'.format(table=self._table), (key,))
            fetched = cursor.fetchone()
            if fetched:
                (cached_serialized_value, cached_time, cached_valid) = fetched
                # Fresh means: stored no later than "now" and less than one timeout period ago.
                time_ok = not check_time or cached_time <= self._time_function() < cached_time + self._timeout
                valid_ok = not check_valid or cached_valid > 0
                if time_ok and valid_ok:
                    try:
                        cached_value = self._deserialize(cached_serialized_value, **self._deserialize_kwargs)
                    except Exception:
                        pass  # Deliberate best effort: don't use a cached value we can't read
        return cached_value

    def set_cached(self, key, value):
        """Just update the cached value in the database.

        (This generally won't need to be called by a user, since `get` and `context` will automatically cache
        valid values.) The stored row is stamped with the current time and marked valid.

        :param key: ID to store the value in the cache.
        :param value: value to store.
        """
        with self._conn_cursor() as (conn, cursor):
            # Make sure the table exists so this also works as the first operation on a new database.
            self._create_table(conn, cursor)
            cursor.execute(
                'INSERT OR REPLACE INTO {table} (key, value, time, valid) VALUES (?, ?, ?, 1)'.format(table=self._table),
                (key, self._serialize(value), self._time_function())
            )
            conn.commit()

    def invalidate(self, key):
        """Invalidate the cached value for the given key.

        You could use this if later processing proves that the value wasn't any good, so you want to invalidate
        it in the database so later runs will attempt to recalculate it. The row is kept but flagged invalid;
        it can still be read with `get_cached(key, check_valid=False)`.

        :param key: ID where value is stored in cache.
        """
        with self._conn_cursor() as (conn, cursor):
            # Make sure the table exists so this also works as the first operation on a new database.
            self._create_table(conn, cursor)
            cursor.execute('UPDATE {table} SET valid=0 WHERE key=?'.format(table=self._table), (key,))
            conn.commit()

    @contextmanager
    def _conn(self, block=True, timeout=None):
        """Obtain a connection that will be closed (or returned to the pool) automatically.

        :param block: passed to `Queue.get` when borrowing from the connection pool.
        :param timeout: passed to `Queue.get` when borrowing from the connection pool.
        """
        # Read the attribute once: `connected()` may reset it to None while we hold a connection.
        pool = self._connection_pool
        if pool is not None:
            conn = pool.get(block, timeout)
            try:
                yield conn
            finally:
                pool.put(conn)
        else:
            with closing(sqlite3.connect(self._path)) as conn:
                yield conn

    @contextmanager
    def _conn_cursor(self, block=True, timeout=None):
        """Obtain a connection, cursor pair that will be closed automatically."""
        with self._conn(block, timeout) as conn:
            with closing(conn.cursor()) as cursor:
                yield (conn, cursor)

    def _create_table(self, conn, cursor):
        """Create the table if it doesn't exist.

        The outcome is remembered on the instance so the `CREATE TABLE` statement is only issued once.
        """
        if not self._table_known_created:
            # Re-check under the lock so concurrent callers don't both issue the statement.
            with self._lock:
                if not self._table_known_created:
                    cursor.execute(
                        '''CREATE TABLE IF NOT EXISTS {table} (
                            key TEXT PRIMARY KEY,
                            value TEXT,
                            time REAL,
                            valid INTEGER DEFAULT 1
                        ) WITHOUT ROWID'''.format(table=self._table)
                    )
                    conn.commit()
                    self._table_known_created = True
if __name__ == '__main__':
    # Executed as a script: run the doctest examples embedded in this module's docstrings.
    from doctest import testmod
    testmod()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment