examples of trying to cache a "Baked" query. see https://gist.github.com/zzzeek/43d10d34993126f074a5 for updated version.
# first, the weird one. We are trying to form a cache key for a Query by looking at stackframes,
# summing up the filename/line number and working it all out. This actually works in this
# example, including an extra "bake a step" thing that makes it much more complicated but
# allows us to bypass the overhead of expensive Query methods like join() - since, after
# all, this is all about Python function call overhead, right?
# For whatever reason, sys._current_frames() *doesn't give us the same answer each time*
# for the same call stack. Why? Who knows, but this is a very valuable clue that this method
# is way too crazy.
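
# A minimal standalone sketch of that keying idea (not part of the original recipe;
# the helper name stack_key() and the use of sys._getframe() here are illustrative
# assumptions): walk frame.f_back from the caller and concatenate each frame's
# hashed filename and line number.  _key_from_tback() below does the same thing,
# starting from sys._current_frames().
import sys

def stack_key():
    # start at the caller's frame and walk up to the top of the stack
    frame = sys._getframe(1)
    key = ""
    while frame is not None:
        key += "%s%s" % (hash(frame.f_code.co_filename), frame.f_lineno)
        frame = frame.f_back
    return key

# calling stack_key() from two different source lines yields two different keys,
# while the same call site yields the same key each time.
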
from sqlalchemy.orm import query
import sys
import thread
class BakedQuery(query.Query):
    _baked_cache = {}
    _bake_steps = ()

    def _clone(self):
        self = super(BakedQuery, self)._clone()
        tback_key = self._key_from_tback()
        new_cache_key = self._cache_key + tback_key
        self = self._check_bake_steps(tback_key)
        self._cache_key = new_cache_key
        return self

    def _key_from_tback(self):
        """Form a cache key against the current traceback.
        """
        frame = sys._current_frames()[thread.get_ident()]
        key = ""
        while frame.f_back:
            frame = frame.f_back
            filename, lineno = frame.f_code.co_filename, frame.f_lineno
            key += "%s%s" % (
                hash(filename), lineno
            )
            # print frame.f_code.co_filename, frame.f_lineno
            if frame.f_code is self._stack_top:
                break
        return key

    # the "bake step" idea is above and beyond the usual
    # query caching. this lets us skip whole calls to various query
    # methods if we've already called them once, such as query.join().
    def bake_step(self, fn):
        if not self._bake_steps:
            self._bake_steps = []
        tback_key = self._key_from_tback()
        self = super(BakedQuery, self)._clone()
        self._bake_steps.append(fn)
        self._cache_key = self._cache_key + tback_key
        return self

    def _check_bake_steps(self, tback_key):
        been_here_before_key = ('bhb', self._cache_key, tback_key)
        been_here_before = been_here_before_key in self._baked_cache
        if not been_here_before:
            if self._bake_steps:
                self = self._run_bake_steps()
            self._baked_cache[been_here_before_key] = True
        return self

    def _run_bake_steps(self):
        bake_steps = list(self._bake_steps)
        self._bake_steps[:] = []
        for step in bake_steps:
            self = step(self)
        return self

    @classmethod
    def from_query(cls, query, validate=False):
        # if validate=True, the BQ would work such that it compiles
        # every time and asserts that the cached version matches up with
        # what's expected.
        if isinstance(query, BakedQuery):
            # this can likely be made to work but the BakedQuery
            # shouldn't be used this loosely.
            raise Exception("bakedquery shouldn't be called recursively.")

        # Here I'd probably want to make a new query class that makes
        # use of any existing subclassing applied to the query given.
        bq = BakedQuery.__new__(BakedQuery)
        bq.__dict__ = query.__dict__.copy()
        frame = sys._current_frames()[thread.get_ident()]
        bq._stack_top = frame.f_back.f_code
        tback_key = bq._key_from_tback()
        bq._cache_key = tback_key
        return bq

    def _compile_context(self, **kw):
        cache_key = self._cache_key
        if cache_key not in self._baked_cache:
            if self._bake_steps:
                self = self._run_bake_steps()
            #print "NO CACHE!"
            baked_context = super(BakedQuery, self)._compile_context(**kw)
            del baked_context.session
            del baked_context.query
            self._baked_cache[cache_key] = baked_context
        else:
            #print "CACHE!"
            baked_context = self._baked_cache[cache_key]

        context = query.QueryContext.__new__(query.QueryContext)
        context.__dict__.update(baked_context.__dict__)
        context.query = self
        context.session = self.session
        context.attributes = context.attributes.copy()
        return context

    def _execute_and_instances(self, querycontext):
        self = self.execution_options(compiled_cache=self._baked_cache)
        return super(BakedQuery, self)._execute_and_instances(querycontext)
import cProfile
import StringIO
import pstats
import contextlib
@contextlib.contextmanager
def profiled():
    pr = cProfile.Profile()
    pr.enable()
    yield
    pr.disable()

    s = StringIO.StringIO()
    ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
    #ps.print_stats()
    print "total calls: ", ps.total_calls
    print s.getvalue()
if __name__ == '__main__':
    from sqlalchemy import *
    from sqlalchemy.orm import *
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class A(Base):
        __tablename__ = 'a'

        id = Column(Integer, primary_key=True)
        bs = relationship("B")

    class B(Base):
        __tablename__ = 'b'

        id = Column(Integer, primary_key=True)
        a_id = Column(Integer, ForeignKey('a.id'))

    e = create_engine("sqlite://")
    Base.metadata.create_all(e)

    sess = Session(e)
    sess.add_all([
        A(id=1, bs=[B(id=1), B(id=2)]),
        A(id=2, bs=[B(id=3)]),
        A(id=3, bs=[])
    ])
    sess.commit()

    def do_query(use_cache, use_bake_step, bake_order_also, filter_, join, order_by):
        sess = Session(e)

        q = sess.query(A)

        if use_cache:
            q = BakedQuery.from_query(q)

        if filter_:
            q = q.filter(A.id.in_([2, 3]))

        if join:
            if use_cache and use_bake_step:
                # the point of bake_step is to save on the overhead
                # of calling join() at all
                q = q.bake_step(lambda q: q.join(A.bs))
            else:
                q = q.join(A.bs)

        if order_by:
            if use_cache and bake_order_also:
                q = q.bake_step(lambda q: q.order_by(A.id.desc()))
            else:
                q = q.order_by(A.id.desc())

        # some adjustments would be needed in subqueryload,
        # as we didn't actually assemble the BakedQuery class into
        # the session. I like the idea of keeping it out totally
        # for normal queries.
        # q = q.options(subqueryload(A.bs))

        return q

    def assert_(result, filter_, join, order_by):
        ordered_ids = [a.id for a in result]
        ids = set(ordered_ids)

        if filter_:
            assert 1 not in ids
        else:
            assert 1 in ids

        if join:
            assert 3 not in ids
        else:
            assert 3 in ids

        if order_by:
            assert list(reversed(sorted(ids))) == ordered_ids
        else:
            assert sorted(ids) == ordered_ids

    import random

    def run_test(use_cache, use_bake_step, bake_order_also):
        for i in range(1000):
            filter_, join, order_by = random.randint(0, 1),\
                random.randint(0, 1),\
                random.randint(0, 1)
            #print(filter_, join, order_by)
            q = do_query(
                use_cache, use_bake_step, bake_order_also,
                filter_, join, order_by)
            assert_(q.all(), filter_, join, order_by)

    # we are creative with the stack keying so that
    # an operation like this works, e.g. the original calling
    # line is different
    #do_query(True, False, False, True, False, False).all()
    #do_query(True, False, False, True, False, False).all()

    if True: #False:
        print("Run test with no cache")
        with profiled():
            run_test(False, False, False)

        print("Run test with cache, but join() every time")
        with profiled():
            run_test(True, False, False)

        print("Run test with cache + bake the join() step")
        with profiled():
            run_test(True, True, False)

        print("Run test with cache + bake the join() and order_by() steps")
        with profiled():
            run_test(True, True, True)
# second, the easy one! Everything you do with the query is in a lambda; we track each
# lambda based on its filename/position, and we're done. It's simpler,
# and uses 100K fewer function calls than the fastest run of the previous approach.
# it's just EW, EXPLICIT! EW! but you know what, if you're dealing with a fixed-state
# cached query, you are optimizing, and you also need to be very aware that's what you're doing;
# the argument can be made that it *shouldn't* look like normal querying, because it isn't.
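
# A quick illustration of that keying scheme (not part of the original recipe; the
# helper name lambda_key() is an illustrative assumption): a lambda's code object
# carries the filename and first line number of where it was written, so the same
# source line yields the same key every time the enclosing function runs.  Python 2's
# func_code attribute is used here, matching the class below.
def lambda_key():
    fn = lambda q: q
    return (fn.func_code.co_filename, fn.func_code.co_firstlineno)

# two calls produce an identical key, since it's the same source location
assert lambda_key() == lambda_key()
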
from sqlalchemy.orm.query import QueryContext
class BakedQuery(object):
    _bakery = {}

    def __init__(self, fn):
        self._cache_key = ()
        self._update_cache_key(fn)
        self.query = fn()
        self.steps = []

    def _update_cache_key(self, fn):
        self._cache_key += (fn.func_code.co_filename,
                            fn.func_code.co_firstlineno)

    def bake(self, fn):
        self._update_cache_key(fn)
        self.steps.append(fn)

    def _bake(self):
        query = self.query
        for step in self.steps:
            query = step(query)

        context = query._compile_context()
        del context.session
        del context.query
        self._bakery[self._cache_key] = context

    def __iter__(self):
        if self._cache_key not in self._bakery:
            self._bake()

        query = self.query
        query._execution_options = query._execution_options.union(
            {"compiled_cache": self._bakery}
        )

        baked_context = self._bakery[self._cache_key]
        context = QueryContext.__new__(QueryContext)
        context.__dict__.update(baked_context.__dict__)
        context.query = query
        context.session = query.session
        context.attributes = context.attributes.copy()
        context.statement.use_labels = True

        if query._autoflush and not query._populate_existing:
            query.session._autoflush()

        return query._execute_and_instances(context)

    def all(self):
        return list(self)
import cProfile
import StringIO
import pstats
import contextlib
@contextlib.contextmanager
def profiled():
    pr = cProfile.Profile()
    pr.enable()
    yield
    pr.disable()

    s = StringIO.StringIO()
    ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
    #ps.print_stats()
    print "total calls: ", ps.total_calls
    print s.getvalue()
if __name__ == '__main__':
    from sqlalchemy import *
    from sqlalchemy.orm import *
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class A(Base):
        __tablename__ = 'a'

        id = Column(Integer, primary_key=True)
        bs = relationship("B")

    class B(Base):
        __tablename__ = 'b'

        id = Column(Integer, primary_key=True)
        a_id = Column(Integer, ForeignKey('a.id'))

    e = create_engine("sqlite://")
    Base.metadata.create_all(e)

    sess = Session(e)
    sess.add_all([
        A(id=1, bs=[B(id=1), B(id=2)]),
        A(id=2, bs=[B(id=3)]),
        A(id=3, bs=[])
    ])
    sess.commit()

    def do_query_cached(filter_, join, order_by):
        sess = Session(e)

        baked = BakedQuery(lambda: sess.query(A))

        if filter_:
            baked.bake(lambda q: q.filter(A.id.in_([2, 3])))

        if join:
            baked.bake(lambda q: q.join(A.bs))

        if order_by:
            baked.bake(lambda q: q.order_by(A.id.desc()))

        return baked

    def do_query_uncached(filter_, join, order_by):
        sess = Session(e)

        q = sess.query(A)

        if filter_:
            q = q.filter(A.id.in_([2, 3]))

        if join:
            q = q.join(A.bs)

        if order_by:
            q = q.order_by(A.id.desc())

        return q

    def assert_(result, filter_, join, order_by):
        ordered_ids = [a.id for a in result]
        ids = set(ordered_ids)

        if filter_:
            assert 1 not in ids
        else:
            assert 1 in ids

        if join:
            assert 3 not in ids
        else:
            assert 3 in ids

        if order_by:
            assert list(reversed(sorted(ids))) == ordered_ids
        else:
            assert sorted(ids) == ordered_ids

    import random

    def run_test(use_cache):
        for i in range(1000):
            filter_, join, order_by = random.randint(0, 1),\
                random.randint(0, 1),\
                random.randint(0, 1)
            #print(filter_, join, order_by)
            if use_cache:
                q = do_query_cached(filter_, join, order_by)
            else:
                q = do_query_uncached(filter_, join, order_by)
            assert_(q.all(), filter_, join, order_by)

    print("Run test with no cache")
    with profiled():
        run_test(False)

    print("Run test with cache")
    with profiled():
        run_test(True)
# here's a variant that adds a decorator like that of
# https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/BakedQuery
from sqlalchemy.orm.query import QueryContext
class BakedQuery(object):
    _bakery = {}

    def __init__(self, fn, args=None):
        if args:
            self._cache_key = tuple(args)
        else:
            self._cache_key = ()
        self._update_cache_key(fn)
        self.query = fn()
        self.steps = []

    def _update_cache_key(self, fn):
        self._cache_key += (fn.func_code.co_filename,
                            fn.func_code.co_firstlineno)

    @classmethod
    def baked(cls, fn):
        def decorate(*args):
            return BakedQuery(fn, args)
        return decorate

    def bake(self, fn):
        self._update_cache_key(fn)
        self.steps.append(fn)

    def _bake(self):
        query = self.query
        for step in self.steps:
            query = step(query)

        context = query._compile_context()
        del context.session
        del context.query
        self._bakery[self._cache_key] = context

    def __iter__(self):
        if self._cache_key not in self._bakery:
            self._bake()

        query = self.query
        query._execution_options = query._execution_options.union(
            {"compiled_cache": self._bakery}
        )

        baked_context = self._bakery[self._cache_key]
        context = QueryContext.__new__(QueryContext)
        context.__dict__.update(baked_context.__dict__)
        context.query = query
        context.session = query.session
        context.attributes = context.attributes.copy()
        context.statement.use_labels = True

        if query._autoflush and not query._populate_existing:
            query.session._autoflush()

        return query._execute_and_instances(context)

    def all(self):
        return list(self)
if __name__ == '__main__':
    from sqlalchemy import *
    from sqlalchemy.orm import *
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class A(Base):
        __tablename__ = 'a'

        id = Column(Integer, primary_key=True)
        bs = relationship("B")

    class B(Base):
        __tablename__ = 'b'

        id = Column(Integer, primary_key=True)
        a_id = Column(Integer, ForeignKey('a.id'))

    e = create_engine("sqlite://")
    Base.metadata.create_all(e)

    sess = Session(e)
    sess.add_all([
        A(id=1, bs=[B(id=1), B(id=2)]),
        A(id=2, bs=[B(id=3)]),
        A(id=3, bs=[])
    ])
    sess.commit()

    @BakedQuery.baked
    def go():
        return sess.query(A).filter(A.id == 3)

    b1 = go()
    print b1.all()
    print b1.all()
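
    # A small follow-up sketch (not in the original gist; go2 is an illustrative
    # name): a second decorated function gets its own entry in the bakery, keyed
    # on its own filename / line number, so it doesn't collide with go().
    @BakedQuery.baked
    def go2():
        return sess.query(A).filter(A.id.in_([1, 2]))

    b2 = go2()
    print b2.all()

    # a fresh call to go() builds a new BakedQuery, but its cache key is the same
    # (same function code object), so the already-baked context is reused
    print go().all()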