Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
"""Modifies the App Engine datastore to support local caching of entities.

This is achieved by monkeypatching google.appengine.ext.db to recognise model
classes that should be cached and store them locally for the duration of a
single page request.

Note that only datastore gets (and anything that relies on them, such as
ReferenceProperty fetches) are cached; queries will neither return cached
entities nor update the cache. Gets issued inside a transaction bypass the
cache, and puts issued inside a transaction evict the written keys instead of
caching them, since the transaction may still be rolled back.

To use, wrap your WSGI application in an instance of CacheSession, and
modify any models that you expect to be fetched more than once per request to
extend CachedModel instead of db.Model.
"""
import logging

from google.appengine.api import datastore
from google.appengine.ext import db
def splitByCond(l, func):
    """Partition an iterable into (matching, non-matching) lists.

    Args:
      l: any iterable; it is consumed once.
      func: predicate called exactly once per item.

    Returns:
      A tuple (a, b). a holds the items for which func(item) is true; b holds
      the rest. Both keep the items' original relative order, which
      joinByCond relies on to reassemble the sequence.
    """
    a = []
    b = []
    for item in l:
        if func(item):
            a.append(item)
        else:
            b.append(item)
    return (a, b)
def joinByCond(l, a, b, func):
    """Merge two sequences back into the order described by l.

    This is the inverse of splitByCond. For each item of l, the next element
    is taken from a when func(item) is true, otherwise from b.

    Args:
      l: the sequence that drives the merge order.
      a: values to use, in order, for items where func(item) is true.
      b: values to use, in order, for items where func(item) is false.
      func: the predicate that was used to split l.

    Returns:
      A new list with one element per item of l.

    Raises:
      StopIteration: if a or b runs out before l does.
    """
    ia = iter(a)
    ib = iter(b)
    return [next(ia) if func(item) else next(ib) for item in l]
class CachedModel(db.Model):
    """Any class that implements this will automatically have entities cached.

    The instance-level put() and delete() are routed through this module's
    cache-aware put() and delete() functions, so writes made through the
    instance keep the session cache consistent.
    """

    def put(self):
        # `put` resolves to the module-level function, not this method.
        return put(self)

    def delete(self):
        # Actually delete the entity, evicting it from the session cache,
        # via the module-level delete(); previously nothing was deleted.
        delete(self)
        # Drop the instance's private _entity reference, as this method
        # always did. Presumably this mirrors db.Model.delete; confirm
        # against the SDK in use.
        self._entity = None
# Grab the unpatched db functions before the monkeypatch further down replaces
# them; the cache-aware wrappers delegate to these for real datastore work.
db_get, db_put, db_delete = db.get, db.put, db.delete

# The CacheSession whose wrapped application is currently running, or None
# when no session is active.
_current_session = None
def get(keys):
    """Cache-aware replacement for db.get.

    Outside a CacheSession, or inside a transaction, the call is forwarded
    unchanged to the original db.get. Otherwise, keys already present in the
    session cache are answered from it. The remaining keys are fetched with a
    single db.get call, and any fetched CachedModel instances are added to the
    cache.

    Args:
      keys: a single key or a list of keys, in any form accepted by
        datastore.NormalizeAndTypeCheckKeys.

    Returns:
      A list of results aligned with `keys` when a list was passed, otherwise
      the single result for the single key.
    """
    if not _current_session or datastore._CurrentTransactionKey():
        return db_get(keys)

    keys, multiple = datastore.NormalizeAndTypeCheckKeys(keys)
    cache = _current_session.cache

    # Split into cached and uncached keys, preserving request order.
    cond_func = lambda x: x in cache
    cached_keys, uncached_keys = splitByCond(keys, cond_func)
    cached_models = [cache[x] for x in cached_keys]

    # Fetch only what the cache could not supply.
    if uncached_keys:
        fetched_models = db_get(uncached_keys)
    else:
        # The whole call was answered from the cache; db_get is not called.
        _current_session.cached_gets += 1
        fetched_models = []

    # Update stats.
    _current_session.hit_count += len(cached_models)
    _current_session.miss_count += sum(int(isinstance(x, CachedModel))
                                       for x in fetched_models)
    _current_session.total_gets += 1

    # Build the result BEFORE updating the cache. cond_func must give the
    # same answers it gave to splitByCond, or freshly fetched keys would be
    # routed to the (already exhausted) cached_models list.
    ret = list(joinByCond(keys, cached_models, fetched_models, cond_func))

    # Only CachedModel instances are cached.
    cache.update((x.key(), x) for x in fetched_models
                 if isinstance(x, CachedModel))

    if multiple:
        return ret
    return ret[0]
def put(models):
    """Cache-aware replacement for db.put.

    The models are always written with the original db.put. Inside a
    CacheSession, the outcome depends on whether a transaction is active:
    outside a transaction, CachedModel instances are stored in the session
    cache; inside one, the written keys are evicted instead, because the
    write may still be rolled back.

    Args:
      models: a db.Model instance or a list of them.

    Returns:
      The key, or the list of keys when a list was passed, returned by the
      original db.put.
    """
    if not _current_session:
        return db_put(models)

    models, multiple = datastore.NormalizeAndTypeCheck(models, db.Model)
    keys = db_put(models)
    cache = _current_session.cache

    if not datastore._CurrentTransactionKey():
        cache.update((k, v) for k, v in zip(keys, models)
                     if isinstance(v, CachedModel))
    else:
        # In transactions, delete from the cache, since we don't know if it'll
        # be committed or rolled back.
        for k in keys:
            cache.pop(k, None)

    if multiple:
        return keys
    return keys[0]
def delete(models):
    """Cache-aware replacement for db.delete.

    Always deletes through the original db.delete. Inside a CacheSession, the
    affected keys are also evicted from the session cache.

    Args:
      models: a model instance, Key, key string, or a list of them.
    """
    if not _current_session:
        # Previously this returned without deleting anything.
        db_delete(models)
        return

    models_or_keys, _ = datastore.NormalizeAndTypeCheck(
        models, (db.Model, db.Key, basestring))
    # Delete the normalized list so that `models` is only consumed once.
    db_delete(models_or_keys)

    cache = _current_session.cache
    for model_or_key in models_or_keys:
        if isinstance(model_or_key, db.Model):
            k = model_or_key.key()
        elif isinstance(model_or_key, basestring):
            k = db.Key(model_or_key)
        else:
            k = model_or_key
        cache.pop(k, None)
# Install the cache-aware wrappers over google.appengine.ext.db, so that code
# calling db.get / db.put / db.delete goes through the session cache.
db.get, db.put, db.delete = get, put, delete
class CacheSession(object):
    """WSGI middleware that enables entity caching for each request it serves.

    While the wrapped application runs, this instance is the module's
    _current_session, so the patched db.get/db.put/db.delete use its cache.
    The cache is emptied at the start of every request; the statistics
    counters are cumulative over the lifetime of the instance.

    Caveat: only datastore calls made while the wrapped application is
    executing are cached. If the application returns a lazy iterable that
    touches the datastore while it is being consumed, those calls happen
    after the session has been deactivated and bypass the cache.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        self.cache = {}
        self.hit_count = 0    # entities served from the cache
        self.miss_count = 0   # CachedModel entities that had to be fetched
        self.total_gets = 0   # get() calls handled while a session was active
        self.cached_gets = 0  # get() calls answered entirely from the cache

    def __call__(self, environ, start_response):
        global _current_session
        # Entities are cached for a single request only: never let entities
        # from a previous request (possibly stale by now) be served.
        self.cache.clear()
        _current_session = self
        try:
            return self.wrapped(environ, start_response)
        finally:
            # Deactivate even when the wrapped application raises.
            _current_session = None
            self._logStats()

    def _logStats(self):
        """Log how many entity fetches and get() calls the cache has saved."""
        fetches = self.hit_count + self.miss_count
        # Guard against division by zero when nothing has been fetched yet.
        hit_rate = self.hit_count / float(fetches) if fetches else 0.0
        get_hit_rate = (self.cached_gets / float(self.total_gets)
                        if self.total_gets else 0.0)
        logging.info(
            "datastore_cache saved %d/%d entity fetches (%.0f%% hit rate), "
            "%d/%d requests (%.0f%% hit rate)",
            self.hit_count, fetches, hit_rate * 100,
            self.cached_gets, self.total_gets, get_hit_rate * 100)

    def getStats(self):
        """Return the cumulative cache counters as a dict."""
        return {
            'hits': self.hit_count,
            'misses': self.miss_count,
            'gets': self.total_gets,
            'cached_gets': self.cached_gets,
        }
import os
import datastore_cache
import unittest
from google.appengine.api import apiproxy_stub_map
from google.appengine.api import datastore_file_stub
from google.appengine.ext import db
class Foo(datastore_cache.CachedModel):
    """Test model whose entities go through the session cache (a CachedModel)."""
    one = db.IntegerProperty()
class Bar(db.Model):
    """Uncached test model holding a reference to a Foo.

    Used to check that uncached kinds stay out of the cache and that
    ReferenceProperty lookups of Foo are served from it.
    """
    ref = db.ReferenceProperty(Foo)
class DatastoreCacheTest(unittest.TestCase):
    """Exercises datastore_cache against the datastore file stub."""

    def setUp(self):
        # Register a fresh datastore stub for every test.
        os.environ['APPLICATION_ID'] = 'test'
        apiproxy_stub_map.apiproxy = apiproxy_stub_map.APIProxyStubMap()
        datastore = datastore_file_stub.DatastoreFileStub('test', None, None)
        apiproxy_stub_map.apiproxy.RegisterStub('datastore_v3', datastore)

    def tearDown(self):
        # testCache installs a session by hand. Never let it leak into the
        # other tests, which assert that no session is active.
        datastore_cache._current_session = None

    def testOutsideSession(self):
        """Without a session the patched functions behave like plain db calls."""
        self.assertEqual(datastore_cache._current_session, None)
        # Single put
        foo = Foo(one=1)
        foo_id = foo.put()
        # Multiple put
        foo.one = 2
        bar = Bar(ref=foo)
        foo_id, bar_id = db.put([foo, bar])
        # Single get
        self.assertEqual(Foo.get(foo_id).one, 2)
        # Multiple get
        self.assertEqual(len(db.get([foo_id, bar_id])), 2)
        # Single delete
        baz = Foo(one=3)
        baz_id = baz.put()
        baz.delete()
        self.assertEqual(db.get(baz_id), None)
        # Multiple delete
        db.delete([foo, bar])
        self.assertEqual(db.get([foo_id, bar_id]), [None, None])

    def testCacheSession(self):
        """CacheSession is active only while the wrapped app is running."""
        class MyException(Exception):
            pass

        def app(environ, start_response):
            self.assertNotEqual(datastore_cache._current_session, None)
            return 'app'

        def failApp(environ, start_response):
            self.assertNotEqual(datastore_cache._current_session, None)
            raise MyException()

        self.assertEqual(datastore_cache._current_session, None)
        session = datastore_cache.CacheSession(app)
        self.assertEqual(session(None, None), 'app')
        self.assertEqual(datastore_cache._current_session, None)
        session = datastore_cache.CacheSession(failApp)
        self.assertRaises(MyException, session, None, None)
        self.assertEqual(datastore_cache._current_session, None)

    def testCache(self):
        """Puts, gets, deletes and reference lookups keep cache and stats right."""
        session = datastore_cache.CacheSession(lambda x, y: [])
        datastore_cache._current_session = session
        # Single and multiple put of cached entities
        foo = Foo(one=1)
        foo.put()
        self.assertTrue(foo.key() in session.cache)
        foo2 = Foo(one=2)
        db.put([foo, foo2])
        self.assertTrue(foo2.key() in session.cache)
        # Put of uncached entities
        bar = Bar(ref=foo)
        bar.put()
        self.assertFalse(bar.key() in session.cache)
        bar2 = Bar(ref=foo2)
        db.put([bar, bar2])
        self.assertFalse(bar2.key() in session.cache)
        self.assertEqual(session.getStats(),
                         {'hits': 0, 'misses': 0, 'gets': 0, 'cached_gets': 0})
        # Cache hit
        session.cache[foo.key()].test = 'test'
        self.assertEqual(db.get(foo.key()).test, 'test')
        self.assertEqual(session.getStats(),
                         {'hits': 1, 'misses': 0, 'gets': 1, 'cached_gets': 1})
        # Cache miss
        del session.cache[foo2.key()]
        got = db.get([foo.key(), foo2.key()])
        self.assertEqual(got[0], foo)
        self.assertNotEqual(got[1], foo2)
        gotkeys = [x.key() for x in got]
        self.assertEqual(gotkeys, [foo.key(), foo2.key()])
        self.assertTrue(foo2.key() in session.cache)
        self.assertEqual(session.getStats(),
                         {'hits': 2, 'misses': 1, 'gets': 2, 'cached_gets': 1})
        # Deletion
        foo2_key = foo2.key()
        foo2.delete()
        self.assertFalse(foo2_key in session.cache)
        # Reference property lookup
        bar = Bar.get(bar.key())
        self.assertEqual(bar.ref, foo)
        self.assertEqual(session.getStats(),
                         {'hits': 3, 'misses': 1, 'gets': 4, 'cached_gets': 2})
if __name__ == '__main__':
    # Run the test suite when this file is executed directly.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment