Skip to content

Instantly share code, notes, and snippets.

@typehorror
Last active February 15, 2024 14:44
Show Gist options
  • Star 58 You must be signed in to star a gist
  • Fork 15 You must be signed in to fork a gist
  • Save typehorror/f26e5ff9756cde90470f to your computer and use it in GitHub Desktop.
Save typehorror/f26e5ff9756cde90470f to your computer and use it in GitHub Desktop.
Flask SQLAlchemy Caching
# file: app.py
# Full article: http://www.debrice.com/flask-sqlalchemy-caching/
import random

from flask import Flask
# NOTE(review): the "flask.ext" import namespace was removed in Flask 1.0;
# importing the extension package directly is the supported spelling and
# also works with the old versions this gist originally targeted.
from flask_sqlalchemy import SQLAlchemy

from caching import CacheableMixin, regions, query_callable

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:////tmp/test.db'
app.debug = True
db = SQLAlchemy(app)
# To generate names and email in the DB.
# Sample name/domain pools consumed only by ``random_user`` below to seed
# the demo database with plausible-looking users.
FIRST_NAMES = (
    "JAMES", "JOHN", "ROBERT", "MICHAEL", "WILLIAM", "DAVID", "RICHARD", "CHARLES", "JOSEPH",
    "THOMAS", "CHRISTOPHER", "DANIEL", "PAUL", "MARK", "DONALD", "GEORGE", "KENNETH",
    "STEVEN", "EDWARD", "BRIAN", "RONALD", "ANTHONY", "KEVIN", "JASON", "MATTHEW", "GARY",
    "TIMOTHY", "JOSE", "LARRY", "JEFFREY", "FRANK", "SCOTT", "ERIC", "STEPHEN", "ANDREW",
    "RAYMOND", "GREGORY", "JOSHUA", "JERRY", "DENNIS", "WALTER", "PATRICK", "PETER", "HAROLD")
LAST_NAMES = (
    "SMITH", "JOHNSON", "WILLIAMS", "JONES", "BROWN", "DAVIS", "MILLER", "WILSON", "MOORE",
    "TAYLOR", "ANDERSON", "THOMAS", "JACKSON", "WHITE", "HARRIS", "MARTIN", "THOMPSON",
    "GARCIA", "MARTINEZ", "ROBINSON", "CLARK", "RODRIGUEZ", "LEWIS", "LEE", "WALKER", "HALL",
    "ALLEN", "YOUNG", "HERNANDEZ", "KING", "WRIGHT", "LOPEZ", "HILL", "SCOTT", "GREEN")
# Email domains paired with the names above.
DOMAINS = ['gmail.com', 'yahoo.com', 'msn.com', 'facebook.com', 'aol.com', 'att.com']
class User(CacheableMixin, db.Model):
    """Demo user model with dogpile-backed query caching enabled.

    ``CacheableMixin`` attaches a ``User.cache`` helper and registers
    before_insert/update/delete listeners that flush stale cache entries.
    """
    # Name of the dogpile region (in ``cache_regions``) this model uses.
    cache_label = "default"
    cache_regions = regions
    #cache_pk = "username" # for custom pk
    # Replace the default query class so plain queries can also be cached.
    query_class = query_callable(regions)

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80))
    email = db.Column(db.String(120))
    # Page-view counter; bumped by the /update/<id>/ route.
    views = db.Column(db.Integer, default=0)

    def __init__(self, username, email):
        self.username = username
        self.email = email

    def __repr__(self):
        return '<User %r>' % self.username
@app.route('/views/<int:views>/')
def user_with_x_views(views):
    """Render an HTML table of every user whose view counter equals *views*.

    Uses the cached listing (``User.cache.filter``) rather than a direct
    database query.
    """
    rows = []
    for user in User.cache.filter(views=views):
        rows.append(
            """<td>%s</td>
<td>%s</td>
<td>%s</td>
<td><a href="/update/%s/">update</a></td>
<td><a href="/%s/">view (%s)</a></td>""" % (
                user.id, user.username, user.email,
                user.id, user.id, user.views))
    return '<table><tr>%s</tr></table>' % '</tr><tr>'.join(rows)
@app.route('/')
def all_user():
    """Render an HTML table of all users, pulled through the cache.

    ``User.cache.filter()`` is the cached alternative to
    ``User.query.filter()``.  One could also use
    ``User.query.options(User.cache.from_cache("my_cache")).filter()``
    and invalidate manually with ``User.cache.flush("my_cache")``.
    """
    rows = []
    for user in User.cache.filter():
        rows.append("""
<td>%s</td>
<td>%s</td>
<td>%s</td>
<td><a href="/update/%s/">update</a></td>
<td><a href="/%s/">view (%s)</a></td>""" % (
            user.id, user.username, user.email,
            user.id, user.id, user.views))
    return '<table><tr>%s</tr></table>' % '</tr><tr>'.join(rows)
@app.route('/update/<int:user_id>/')
def update_user(user_id):
    """Increment a user's view counter and show the updated record.

    ``User.cache.get`` is the cached alternative to ``User.query.get``.
    """
    user = User.cache.get(user_id)
    # Updating the views count clears the listing keyed on the previous
    # views value, the one keyed on the new value, the unfiltered "all"
    # listing, and the cached object itself (via the before_update event).
    user.views += 1
    db.session.add(user)
    db.session.commit()
    page = '<h1>%s</h1><p>email: %s<br>views: %s</p><a href="/">back</a>'
    return page % (user.username, user.email, user.views)
@app.route('/<int:user_id>/')
def view_user(user_id):
    """Show a single user, fetched through the cache instead of the DB."""
    user = User.cache.get(user_id)
    page = '<h1>%s</h1><p>email: %s<br>views: %s</p><a href="/">back</a>'
    return page % (user.username, user.email, user.views)
def random_user():
    """Build (without persisting) a User with a random name and email."""
    first = random.choice(FIRST_NAMES)
    last = random.choice(LAST_NAMES)
    domain = random.choice(DOMAINS)
    return User(
        username="%s_%s" % (first, last),
        email="%s.%s@%s" % (first, last, domain),
    )
@app.route('/init_db/')
def init_db():
    """Drop and recreate every table, then seed 50 random demo users."""
    db.drop_all()
    db.create_all()
    for _ in range(50):
        db.session.add(random_user())
    db.session.commit()
    return 'DB initialized'
if __name__ == '__main__':
    # Run the Flask development server (debug mode enabled above).
    app.run()
# file: caching.py
# Full article: http://www.debrice.com/flask-sqlalchemy-caching/
import functools
import hashlib

# NOTE(review): the "flask.ext" import namespace was removed in Flask 1.0;
# import the extension package directly instead.
from flask_sqlalchemy import BaseQuery
from sqlalchemy import event, select
from sqlalchemy.orm.interfaces import MapperOption
from sqlalchemy.orm.attributes import get_history
from sqlalchemy.ext.declarative import declared_attr

from dogpile.cache.region import make_region
from dogpile.cache.api import NO_VALUE
def md5_key_mangler(key):
    """
    Shorten SELECT-statement cache keys to their MD5 hex digest.

    Keys that do not start with 'SELECT ' (e.g. the hand-built listing
    keys such as 'user.id[all]') pass through unchanged.

    :param key: cache key string produced by the caching layer.
    :return: the original key, or a 32-char hex digest for SELECT keys.
    """
    if key.startswith('SELECT '):
        # utf-8 instead of ascii: the key embeds rendered bound-parameter
        # values, which may contain non-ASCII characters.
        key = hashlib.md5(key.encode('utf-8')).hexdigest()
    return key
def memoize(obj):
    """
    Cache a callable's return values in a plain dict.

    The key is the string repr of the positional and keyword arguments,
    and the dict is exposed as ``obj.cache`` so callers can inspect or
    clear it.
    """
    storage = obj.cache = {}

    @functools.wraps(obj)
    def wrapper(*args, **kwargs):
        lookup = str(args) + str(kwargs)
        try:
            return storage[lookup]
        except KeyError:
            storage[lookup] = obj(*args, **kwargs)
            return storage[lookup]
    return wrapper
# Shared dogpile configuration: in-process memory backend, 60 s TTL.
cache_config = {
    'backend': 'dogpile.cache.memory',
    'expiration_time': 60,  # seconds
}

# Named cache regions; a model selects one via its ``cache_label``.
# SELECT-statement keys are shortened to an MD5 by ``md5_key_mangler``.
regions = dict(
    default=make_region(key_mangler=md5_key_mangler).configure(**cache_config)
)
class CachingQuery(BaseQuery):
    """
    A Query subclass which optionally loads full results from a dogpile
    cache region.

    The cache is consulted only when a ``FromCache`` option has been
    applied to the query (which sets ``self._cache_region``); otherwise
    this behaves exactly like ``BaseQuery``.
    """

    def __init__(self, regions, entities, *args, **kw):
        # Mapping of region name -> dogpile region; injected by
        # ``query_callable`` through functools.partial.
        self.cache_regions = regions
        BaseQuery.__init__(self, entities=entities, *args, **kw)

    def __iter__(self):
        """
        override __iter__ to pull results from dogpile
        if particular attributes have been configured.
        """
        if hasattr(self, '_cache_region'):
            # On a cache miss, materialize the full result list and store it.
            return self.get_value(createfunc=lambda: list(BaseQuery.__iter__(self)))
        else:
            return BaseQuery.__iter__(self)

    def _get_cache_plus_key(self):
        """
        Return a cache region plus key.
        """
        dogpile_region = self.cache_regions[self._cache_region.region]
        if self._cache_region.cache_key:
            key = self._cache_region.cache_key
        else:
            # No explicit key given: derive one from the compiled SELECT
            # statement plus its bound parameters.
            key = _key_from_query(self)
        return dogpile_region, key

    def invalidate(self):
        """
        Invalidate the cache value represented by this Query.
        """
        dogpile_region, cache_key = self._get_cache_plus_key()
        dogpile_region.delete(cache_key)

    def get_value(self, merge=True, createfunc=None,
                  expiration_time=None, ignore_expiration=False):
        """
        Return the value from the cache for this query.
        Raise KeyError if no value present and no
        createfunc specified.
        """
        dogpile_region, cache_key = self._get_cache_plus_key()

        # ignore_expiration bypasses the TTL on read, which contradicts
        # createfunc (regenerate on expiry) — the two are mutually exclusive.
        assert not ignore_expiration or not createfunc, \
            "Can't ignore expiration and also provide createfunc"

        if ignore_expiration or not createfunc:
            cached_value = dogpile_region.get(cache_key,
                                              expiration_time=expiration_time,
                                              ignore_expiration=ignore_expiration)
        else:
            cached_value = dogpile_region.get_or_create(
                cache_key,
                createfunc,
                expiration_time=expiration_time)

        if cached_value is NO_VALUE:
            raise KeyError(cache_key)
        if merge:
            # Re-attach the cached (detached) instances to this session.
            cached_value = self.merge_result(cached_value, load=False)
        return cached_value

    def set_value(self, value):
        """
        Set the value in the cache for this query.
        """
        dogpile_region, cache_key = self._get_cache_plus_key()
        dogpile_region.set(cache_key, value)
def query_callable(regions, query_cls=CachingQuery):
    """Return a factory producing *query_cls* instances bound to *regions*.

    Suitable as a model's ``query_class`` attribute.
    """
    factory = functools.partial(query_cls, regions)
    return factory
def _key_from_query(query, qualifier=None):
    """
    Given a Query, create a cache key.

    The key is the labeled SELECT text followed by the bound parameter
    values, sorted by parameter name for determinism.
    """
    statement = query.with_labels().statement
    compiled = statement.compile()
    parts = [str(compiled)]
    for name in sorted(compiled.params):
        parts.append(str(compiled.params[name]))
    return " ".join(parts)
class FromCache(MapperOption):
    """Specifies that a Query should load results from a cache."""

    # Do not carry this option onto lazy loads of related objects.
    propagate_to_loaders = False

    def __init__(self, region="default", cache_key=None):
        """Construct a new FromCache.

        :param region: the cache region. Should be a
        region configured in the dictionary of dogpile
        regions.

        :param cache_key: optional. A string cache key
        that will serve as the key to the query. Use this
        if your query has a huge amount of parameters (such
        as when using in_()) which correspond more simply to
        some other identifier.
        """
        self.region = region
        self.cache_key = cache_key

    def process_query(self, query):
        """Process a Query during normal loading operation."""
        # Tag the query; ``CachingQuery.__iter__`` checks this attribute.
        query._cache_region = self
class Cache(object):
    """
    Per-model cache helper, exposed as ``Model.cache`` by CacheableMixin.

    Objects are cached individually under '<tablename>.<pk>[<value>]'
    keys, while listings cache only the list of primary keys under
    '<tablename>.<filter>[all]' keys — so listings can be invalidated
    independently of the objects they reference.
    """

    def __init__(self, model, regions, label):
        self.model = model
        self.regions = regions
        self.label = label
        # Allow a custom pk attribute via ``cache_pk``; default to 'id'.
        self.pk = getattr(model, 'cache_pk', 'id')

    def get(self, pk):
        """
        Equivalent to Model.query.get(pk) but served through the cache.
        """
        return self.model.query.options(self.from_cache(pk=pk)).get(pk)

    def filter(self, order_by='asc', offset=None, limit=None, **kwargs):
        """
        Retrieve all the object ids then pull them independently from cache.
        kwargs accepts one attribute filter, mainly for relationship pulling.
        offset and limit allow pagination, order_by for sorting (asc/desc).

        :raises TypeError: if more than one filter attribute is given, or
            the attribute is not a column of the model.
        """
        query_kwargs = {}
        if kwargs:
            if len(kwargs) > 1:
                raise TypeError('filter accept only one attribute for filtering')
            # BUG FIX: dict.items() is a view on Python 3 and cannot be
            # indexed; next(iter(...)) works on both 2 and 3.
            key, value = next(iter(kwargs.items()))
            if key not in self._columns():
                # BUG FIX: the original '%s ... %s' % self, key interpolated
                # two placeholders with a single value (itself a TypeError).
                raise TypeError(
                    '%s does not have an attribute %s' % (self.model, key))
            query_kwargs[key] = value

        cache_key = self._cache_key(**kwargs)
        pks = self.regions[self.label].get(cache_key)

        if pks is NO_VALUE:
            # BUG FIX: honor a custom ``cache_pk`` instead of hard-coding
            # ``o.id`` (with_entities selects only the pk column anyway).
            pks = [getattr(o, self.pk)
                   for o in self.model.query.filter_by(**kwargs)
                                            .with_entities(getattr(self.model, self.pk))]
            self.regions[self.label].set(cache_key, pks)

        if order_by == 'desc':
            # Reversed copy: never mutate the list held by the cache
            # backend (the memory backend returns the stored object itself,
            # so an in-place reverse() would flip the cached order).
            pks = pks[::-1]

        # BUG FIX: the original sliced ``pks[pks:]`` (using the list itself
        # as a slice index — a TypeError) instead of applying ``offset``.
        if offset is not None:
            pks = pks[offset:]

        if limit is not None:
            pks = pks[:limit]

        keys = [self._cache_key(pk) for pk in pks]
        for pos, obj in enumerate(self.regions[self.label].get_multi(keys)):
            if obj is NO_VALUE:
                # Not cached yet: fall back to a (cache-populating) lookup.
                yield self.get(pks[pos])
            else:
                yield obj[0]

    def flush(self, key):
        """
        flush the given key from dogpile.cache
        """
        self.regions[self.label].delete(key)

    @memoize
    def _columns(self):
        # Every column name except the pk; used for listing invalidation.
        return [c.name for c in self.model.__table__.columns if c.name != self.pk]

    @memoize
    def from_cache(self, cache_key=None, pk=None):
        """
        Build the FromCache option object for the given object.
        """
        if pk:
            cache_key = self._cache_key(pk)
        # if cache_key is None, the mangler will generate an MD5 from the query
        return FromCache(self.label, cache_key)

    @memoize
    def _cache_key(self, pk="all", **kwargs):
        """
        Generate a key as query
        format: '<tablename>.<column>[<value>]'
        'user.id[all]': all users
        'address.user_id=4[all]': all addresses linked to user id 4
        'user.id[4]': user with id=4
        """
        q_filter = "".join("%s=%s" % (k, v) for k, v in kwargs.items()) or self.pk
        return "%s.%s[%s]" % (self.model.__tablename__, q_filter, pk)

    def _flush_all(self, obj):
        # Flush every listing that could contain this object: one key per
        # old and per new value of each non-pk column.
        for column in self._columns():
            added, unchanged, deleted = get_history(obj, column)
            for value in list(deleted) + list(added):
                self.flush(self._cache_key(**{column: value}))
        # flush "all" listing
        self.flush(self._cache_key())
        # flush the object
        self.flush(self._cache_key(getattr(obj, self.pk)))
class CacheableMixin(object):
    """Mixin wiring a ``cache`` helper and automatic cache invalidation
    into a declarative model."""

    @declared_attr
    def cache(cls):
        """
        Add the cache features to the model
        """
        return Cache(cls, cls.cache_regions, cls.cache_label)

    @staticmethod
    def _flush_event(mapper, connection, target):
        """
        Called on object modification to flush cache of dependencies
        """
        target.cache._flush_all(target)

    @classmethod
    def __declare_last__(cls):
        """
        Auto clean the caches, including listings possibly associated with
        this instance, on delete, update and insert.
        """
        for hook in ('before_delete', 'before_update', 'before_insert'):
            event.listen(cls, hook, cls._flush_event)

Flask-SQLAlchemy Caching

The following gist is an extract of the article Flask-SQLAlchemy Caching. Among other features, it provides simple automated query caching and event-driven invalidation of cached relations.

Usage

retrieve one object

# pulling one User object
user = User.query.get(1)
# pulling one User object from cache
user = User.cache.get(1)

retrieve a list of object

# user is the object we pulled earlier (either from cache or not)
# Using the standard query (database hit)
email_addresses = EmailAddress.query.filter(user_id=1)
# pulling the same results from cache
email_addresses = EmailAddress.cache.filter(user_id=1)

Install on your model

from caching import CacheableMixin, query_callable, regions

class User(db.Model, CacheableMixin):
    cache_label = "default" # region's label to use
    cache_regions = regions # regions to store cache
    # Query class handling dogpile caching
    query_class = query_callable(regions)
    
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True)
    email = db.Column(db.String(120), unique=True)
Flask==0.10.1
Flask-SQLAlchemy==1.0
Jinja2==2.7.3
MarkupSafe==0.23
SQLAlchemy==0.9.6
Werkzeug==0.9.6
dogpile.cache==0.5.4
dogpile.core==0.4.1
itsdangerous==0.24
wsgiref==0.1.2
@l-vincent-l
Copy link

This won't work with a column with a name attribute different from the attribute name.
You call get_history with a column name and not the attribute key.
You can inspect the model to get attribute list.

diff --git a/caching.py b/caching.py
index c59bee3..2bd3f56 100644
--- a/caching.py
+++ b/caching.py
@@ -5,7 +5,7 @@ import functools
 import hashlib

 from flask.ext.sqlalchemy import BaseQuery
-from sqlalchemy import event, select
+from sqlalchemy import event, select, inspect
 from sqlalchemy.orm.interfaces import MapperOption
 from sqlalchemy.orm.attributes import get_history
 from sqlalchemy.ext.declarative import declared_attr
@@ -199,7 +199,7 @@ class Cache(object):
             if len(kwargs) > 1:
                 raise TypeError('filter accept only one attribute for filtering')
             key, value = kwargs.items()[0]
-            if key not in self._columns():
+            if key not in self._attrs():
                 raise TypeError('%s does not have an attribute %s' % self, key)
             query_kwargs[key] = value

@@ -236,8 +236,8 @@ class Cache(object):


     @memoize
-    def _columns(self):
-        return [c.name for c in self.model.__table__.columns if c.name != self.pk]
+    def _attrs(self):
+        return [a.key for a in inspect(self.model).attrs if a.key != self.pk]


     @memoize
@@ -266,10 +266,10 @@ class Cache(object):


     def _flush_all(self, obj):
-        for column in self._columns():
-            added, unchanged, deleted = get_history(obj, column)
+        for attr in self._attrs():
+            added, unchanged, deleted = get_history(obj, attr)
             for value in list(deleted) + list(added):
-                self.flush(self._cache_key(**{column: value}))
+                self.flush(self._cache_key(**{attr: value}))
         # flush "all" listing
         self.flush(self._cache_key())
         # flush the object

@eduardoluizgs
Copy link

It works fine! Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment