Gist zzzeek/978d78ad70e97751dc5e — examples of trying to cache a "Baked" query.
See https://gist.github.com/zzzeek/43d10d34993126f074a5 for the updated version.
Three standalone Python 2 scripts follow, separated by file breaks.
# first, the weird one. We are trying to form a cache key for a Query by looking at stackframes, | |
# summing up the filename/line number and working it all out. This actually works in this | |
# example, including an extra "bake a step" thing that makes it much more complicated but | |
# allows us to bypass the overhead of expensive Query methods like join() - since after | |
# all this is all about Python function call overhead right? | |
# for whatever reason, the sys._current_frames() *doesnt give us the same answer each time*, | |
# for the same call stack. Why? Who knows, but this is a very valuable clue that this method | |
# is way too crazy. | |
from sqlalchemy.orm import query | |
import sys | |
import thread | |
class BakedQuery(query.Query):
    """Query subclass that caches compiled QueryContext objects.
    The cache key is derived from the Python call stack at the point each
    generative method is invoked (see _key_from_tback), so repeated
    invocations from the same source locations share one compiled context.
    NOTE(review): the header comment above says sys._current_frames() does
    not always give the same answer for the same call stack, so this keying
    scheme is acknowledged by the author as too fragile.
    """
    # class-level cache shared by all instances: cache key -> compiled
    # QueryContext, plus ('bhb', ...) "been here before" marker entries
    _baked_cache = {}
    # deferred callables to apply lazily; immutable () until first bake_step()
    _bake_steps = ()
    def _clone(self):
        # called by Query's generative methods; extend the cache key with
        # the current call-site so each generative call is keyed separately
        self = super(BakedQuery, self)._clone()
        tback_key = self._key_from_tback()
        new_cache_key = self._cache_key + tback_key
        # run any pending bake steps the first time this call site is seen
        self = self._check_bake_steps(tback_key)
        self._cache_key = new_cache_key
        return self
    def _key_from_tback(self):
        """Form a cache key against the current traceback.

        Walks frames from the caller outward, accumulating
        hash(filename) + lineno per frame, stopping at the frame whose
        code object was recorded as _stack_top by from_query().
        """
        frame = sys._current_frames()[thread.get_ident()]
        key = ""
        while frame.f_back:
            frame = frame.f_back
            filename, lineno = frame.f_code.co_filename, frame.f_lineno
            key += "%s%s" % (
                hash(filename), lineno
            )
            # print frame.f_code.co_filename, frame.f_lineno
            # stop once we reach the frame that created this BakedQuery
            if frame.f_code is self._stack_top:
                break
        return key
    # the "bake step" idea is above and beyond the usual
    # query caching.  this lets us skip whole calls to various query
    # methods if we've already called them once, such as query.join().
    def bake_step(self, fn):
        """Defer *fn* (Query -> Query) to run only on a cache miss."""
        if not self._bake_steps:
            self._bake_steps = []
        tback_key = self._key_from_tback()
        self = super(BakedQuery, self)._clone()
        self._bake_steps.append(fn)
        self._cache_key = self._cache_key + tback_key
        return self
    def _check_bake_steps(self, tback_key):
        # run pending bake steps only the first time this particular
        # (cache key, call site) combination is encountered
        been_here_before_key = ('bhb', self._cache_key, tback_key)
        been_here_before = been_here_before_key in self._baked_cache
        if not been_here_before:
            if self._bake_steps:
                self = self._run_bake_steps()
            self._baked_cache[been_here_before_key] = True
        return self
    def _run_bake_steps(self):
        # apply and clear all deferred steps; each step returns a new query
        bake_steps = list(self._bake_steps)
        self._bake_steps[:] = []
        for step in bake_steps:
            self = step(self)
        return self
    @classmethod
    def from_query(cls, query, validate=False):
        """Wrap a plain Query as a BakedQuery, keyed from the caller's frame.

        ``validate`` is currently unused; see the inline note for its intent.
        """
        # if validate=True, the BQ would work such that it compiles
        # every time and asserts that the cached version matches up with
        # what's expected.
        if isinstance(query, BakedQuery):
            # this can likely be made to work but the BakedQuery
            # shouldn't be used this loosely.
            raise Exception("bakedquery shouldn't be called recursively.")
        # Here I'd probably want to make a new query class that makes
        # use of any existing subclassing applied to the query given.
        bq = BakedQuery.__new__(BakedQuery)
        bq.__dict__ = query.__dict__.copy()
        # remember the caller's code object so _key_from_tback knows
        # where to stop walking the stack
        frame = sys._current_frames()[thread.get_ident()]
        bq._stack_top = frame.f_back.f_code
        tback_key = bq._key_from_tback()
        bq._cache_key = tback_key
        return bq
    def _compile_context(self, **kw):
        # on cache miss, compile normally and store the context with the
        # per-execution attributes (session/query) stripped off
        cache_key = self._cache_key
        if cache_key not in self._baked_cache:
            if self._bake_steps:
                self = self._run_bake_steps()
            #print "NO CACHE!"
            baked_context = super(BakedQuery, self)._compile_context(**kw)
            del baked_context.session
            del baked_context.query
            self._baked_cache[cache_key] = baked_context
        else:
            #print "CACHE!"
            baked_context = self._baked_cache[cache_key]
        # hand back a shallow copy bound to *this* query and session
        context = query.QueryContext.__new__(query.QueryContext)
        context.__dict__.update(baked_context.__dict__)
        context.query = self
        context.session = self.session
        context.attributes = context.attributes.copy()
        return context
    def _execute_and_instances(self, querycontext):
        # also share the cache as a compiled_cache for statement execution
        self = self.execution_options(compiled_cache=self._baked_cache)
        return super(BakedQuery, self)._execute_and_instances(querycontext)
import cProfile | |
import StringIO | |
import pstats | |
import contextlib | |
@contextlib.contextmanager | |
def profiled(): | |
pr = cProfile.Profile() | |
pr.enable() | |
yield | |
pr.disable() | |
s = StringIO.StringIO() | |
ps = pstats.Stats(pr, stream=s).sort_stats('cumulative') | |
#ps.print_stats() | |
print "total calls: ", ps.total_calls | |
print s.getvalue() | |
if __name__ == '__main__':
    from sqlalchemy import *
    from sqlalchemy.orm import *
    from sqlalchemy.ext.declarative import declarative_base
    Base = declarative_base()
    # one-to-many fixture: A(1) has two Bs, A(2) has one, A(3) has none
    class A(Base):
        __tablename__ = 'a'
        id = Column(Integer, primary_key=True)
        bs = relationship("B")
    class B(Base):
        __tablename__ = 'b'
        id = Column(Integer, primary_key=True)
        a_id = Column(Integer, ForeignKey('a.id'))
    e = create_engine("sqlite://")
    Base.metadata.create_all(e)
    sess = Session(e)
    sess.add_all([
        A(id=1, bs=[B(id=1), B(id=2)]),
        A(id=2, bs=[B(id=3)]),
        A(id=3, bs=[])
    ])
    sess.commit()
    def do_query(use_cache, use_bake_step, bake_order_also, filter_, join, order_by):
        """Build a query for A per the option flags, optionally baked."""
        sess = Session(e)
        q = sess.query(A)
        if use_cache:
            q = BakedQuery.from_query(q)
        if filter_:
            q = q.filter(A.id.in_([2, 3]))
        if join:
            if use_cache and use_bake_step:
                # the point of bake_step is to save on the overhead
                # of calling join() at all
                q = q.bake_step(lambda q: q.join(A.bs))
            else:
                q = q.join(A.bs)
        if order_by:
            if use_cache and bake_order_also:
                q = q.bake_step(lambda q: q.order_by(A.id.desc()))
            else:
                q = q.order_by(A.id.desc())
        # some adjustments would be needed in subqueryload,
        # as we didn't actually assemble the BakedQuery class into
        # the session.  I like the idea of keeping it out totally
        # for normal queries.
        # q = q.options(subqueryload(A.bs))
        return q
    def assert_(result, filter_, join, order_by):
        """Check the result rows agree with the flags used to build them."""
        ordered_ids = [a.id for a in result]
        ids = set(ordered_ids)
        if filter_:
            # the filter excludes id 1
            assert 1 not in ids
        else:
            assert 1 in ids
        if join:
            # inner join to A.bs drops A(3), which has no bs
            assert 3 not in ids
        else:
            assert 3 in ids
        if order_by:
            assert list(reversed(sorted(ids))) == ordered_ids
        else:
            assert sorted(ids) == ordered_ids
    import random
    def run_test(use_cache, use_bake_step, bake_order_also):
        # exercise all flag combinations many times in random order
        for i in range(1000):
            filter_, join, order_by = random.randint(0, 1),\
                random.randint(0, 1),\
                random.randint(0, 1)
            #print(filter_, join, order_by)
            q = do_query(
                use_cache, use_bake_step, bake_order_also,
                filter_, join, order_by)
            assert_(q.all(), filter_, join, order_by)
    # we are creative with the stack keying so that
    # an operation like this works, e.g. the original calling
    # line is different
    #do_query(True, False, False, True, False, False).all()
    #do_query(True, False, False, True, False, False).all()
    if True: #False:
        print("Run test with no cache")
        with profiled():
            run_test(False, False, False)
        print("Run test with cache, but join() every time")
        with profiled():
            run_test(True, False, False)
        print("Run test with cache + bake the join() step")
        with profiled():
            run_test(True, True, False)
        print("Run test with cache + bake the join() and order_by() steps")
        with profiled():
            run_test(True, True, True)
— Second file: the simpler, explicit approach — every query operation lives in
a lambda, keyed by each lambda's filename/line position. —
# second, the easy one! Everything you do with the query is in a lambda, | |
# we track it based on the filename/position of each lambda, we're done. It's simpler, | |
# and uses 100K fewer function calls than the fastest example with the other one. | |
# its just EW EXPLICIT ! EW! but you know what, if you're dealing with a fixed-state | |
# cached query, you are optimizing and you also need to be very aware that's what you're doing; | |
# the argument can be made that it *shouldn't* look like normal querying because it isnt. | |
from sqlalchemy.orm.query import QueryContext | |
class BakedQuery(object):
    """Cache a query's compiled form, keyed by the source location of each
    lambda used to construct it.
    All construction happens inside lambdas; the cache key accumulates the
    (filename, firstlineno) of each one, so the same construction code
    shares a single compiled QueryContext across calls.
    """
    # class-level cache shared by all instances:
    # cache key -> compiled QueryContext (session/query stripped)
    _bakery = {}
    def __init__(self, fn):
        # fn: zero-argument lambda returning the starting Query
        self._cache_key = ()
        self._update_cache_key(fn)
        self.query = fn()
        self.steps = []
    def _update_cache_key(self, fn):
        # key on where the lambda is defined, not what it does
        # (fn.func_code is Python 2; __code__ in Python 3)
        self._cache_key += (fn.func_code.co_filename,
                            fn.func_code.co_firstlineno)
    def bake(self, fn):
        """Add a build step; *fn* receives the Query and returns a new one."""
        self._update_cache_key(fn)
        self.steps.append(fn)
    def _bake(self):
        # run all steps, compile, and store the context with the
        # per-execution attributes (session/query) stripped off
        query = self.query
        for step in self.steps:
            query = step(query)
        context = query._compile_context()
        del context.session
        del context.query
        self._bakery[self._cache_key] = context
    def __iter__(self):
        """Execute the (possibly cached) query and iterate its results."""
        if self._cache_key not in self._bakery:
            self._bake()
        query = self.query
        # share the bakery as a compiled_cache for statement execution too
        query._execution_options = query._execution_options.union(
            {"compiled_cache": self._bakery}
        )
        # bind a shallow copy of the cached context to this query/session
        baked_context = self._bakery[self._cache_key]
        context = QueryContext.__new__(QueryContext)
        context.__dict__.update(baked_context.__dict__)
        context.query = query
        context.session = query.session
        context.attributes = context.attributes.copy()
        context.statement.use_labels = True
        # replicate Query's normal autoflush behavior before executing
        if query._autoflush and not query._populate_existing:
            query.session._autoflush()
        return query._execute_and_instances(context)
    def all(self):
        """Return results as a list, like Query.all()."""
        return list(self)
import cProfile | |
import StringIO | |
import pstats | |
import contextlib | |
@contextlib.contextmanager | |
def profiled(): | |
pr = cProfile.Profile() | |
pr.enable() | |
yield | |
pr.disable() | |
s = StringIO.StringIO() | |
ps = pstats.Stats(pr, stream=s).sort_stats('cumulative') | |
#ps.print_stats() | |
print "total calls: ", ps.total_calls | |
print s.getvalue() | |
if __name__ == '__main__':
    from sqlalchemy import *
    from sqlalchemy.orm import *
    from sqlalchemy.ext.declarative import declarative_base
    Base = declarative_base()
    # one-to-many fixture: A(1) has two Bs, A(2) has one, A(3) has none
    class A(Base):
        __tablename__ = 'a'
        id = Column(Integer, primary_key=True)
        bs = relationship("B")
    class B(Base):
        __tablename__ = 'b'
        id = Column(Integer, primary_key=True)
        a_id = Column(Integer, ForeignKey('a.id'))
    e = create_engine("sqlite://")
    Base.metadata.create_all(e)
    sess = Session(e)
    sess.add_all([
        A(id=1, bs=[B(id=1), B(id=2)]),
        A(id=2, bs=[B(id=3)]),
        A(id=3, bs=[])
    ])
    sess.commit()
    def do_query_cached(filter_, join, order_by):
        """Build the query via BakedQuery, one lambda per construction step."""
        sess = Session(e)
        baked = BakedQuery(lambda: sess.query(A))
        if filter_:
            baked.bake(lambda q: q.filter(A.id.in_([2, 3])))
        if join:
            baked.bake(lambda q: q.join(A.bs))
        if order_by:
            baked.bake(lambda q: q.order_by(A.id.desc()))
        return baked
    def do_query_uncached(filter_, join, order_by):
        """Build the same query as do_query_cached using plain Query calls."""
        sess = Session(e)
        q = sess.query(A)
        if filter_:
            q = q.filter(A.id.in_([2, 3]))
        if join:
            q = q.join(A.bs)
        if order_by:
            q = q.order_by(A.id.desc())
        return q
    def assert_(result, filter_, join, order_by):
        """Check the result rows agree with the flags used to build them."""
        ordered_ids = [a.id for a in result]
        ids = set(ordered_ids)
        if filter_:
            # the filter excludes id 1
            assert 1 not in ids
        else:
            assert 1 in ids
        if join:
            # inner join to A.bs drops A(3), which has no bs
            assert 3 not in ids
        else:
            assert 3 in ids
        if order_by:
            assert list(reversed(sorted(ids))) == ordered_ids
        else:
            assert sorted(ids) == ordered_ids
    import random
    def run_test(use_cache):
        # exercise all flag combinations many times in random order
        for i in range(1000):
            filter_, join, order_by = random.randint(0, 1),\
                random.randint(0, 1),\
                random.randint(0, 1)
            #print(filter_, join, order_by)
            if use_cache:
                q = do_query_cached(filter_, join, order_by)
            else:
                q = do_query_uncached(filter_, join, order_by)
            assert_(q.all(), filter_, join, order_by)
    print("Run test with no cache")
    with profiled():
        run_test(False)
    print("Run test with cache")
    with profiled():
        run_test(True)
— Third file: the lambda-based approach again, plus a decorator in the style
of the BakedQuery usage recipe. —
# here's a variant that adds a decorator like that of | |
# https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/BakedQuery | |
from sqlalchemy.orm.query import QueryContext | |
class BakedQuery(object):
    """Lambda-keyed baked query, plus a ``baked`` decorator in the style of
    the BakedQuery usage recipe.
    """
    # class-level cache shared by all instances:
    # cache key -> compiled QueryContext (session/query stripped)
    _bakery = {}
    def __init__(self, fn, args=None):
        # args, if given, become the leading part of the cache key so a
        # decorated factory is keyed per argument set as well as location
        if args:
            self._cache_key = tuple(args)
        else:
            self._cache_key = ()
        self._update_cache_key(fn)
        self.query = fn()
        self.steps = []
    def _update_cache_key(self, fn):
        # key on where the lambda is defined, not what it does
        # (fn.func_code is Python 2; __code__ in Python 3)
        self._cache_key += (fn.func_code.co_filename,
                            fn.func_code.co_firstlineno)
    @classmethod
    def baked(cls, fn):
        """Decorator: turn a query factory into one returning a BakedQuery."""
        def decorate(*args):
            return BakedQuery(fn, args)
        return decorate
    def bake(self, fn):
        """Add a build step; *fn* receives the Query and returns a new one."""
        self._update_cache_key(fn)
        self.steps.append(fn)
    def _bake(self):
        # run all steps, compile, and store the context with the
        # per-execution attributes (session/query) stripped off
        query = self.query
        for step in self.steps:
            query = step(query)
        context = query._compile_context()
        del context.session
        del context.query
        self._bakery[self._cache_key] = context
    def __iter__(self):
        """Execute the (possibly cached) query and iterate its results."""
        if self._cache_key not in self._bakery:
            self._bake()
        query = self.query
        # share the bakery as a compiled_cache for statement execution too
        query._execution_options = query._execution_options.union(
            {"compiled_cache": self._bakery}
        )
        # bind a shallow copy of the cached context to this query/session
        baked_context = self._bakery[self._cache_key]
        context = QueryContext.__new__(QueryContext)
        context.__dict__.update(baked_context.__dict__)
        context.query = query
        context.session = query.session
        context.attributes = context.attributes.copy()
        context.statement.use_labels = True
        # replicate Query's normal autoflush behavior before executing
        if query._autoflush and not query._populate_existing:
            query.session._autoflush()
        return query._execute_and_instances(context)
    def all(self):
        """Return results as a list, like Query.all()."""
        return list(self)
if __name__ == '__main__':
    from sqlalchemy import *
    from sqlalchemy.orm import *
    from sqlalchemy.ext.declarative import declarative_base
    Base = declarative_base()
    # one-to-many fixture: A(1) has two Bs, A(2) has one, A(3) has none
    class A(Base):
        __tablename__ = 'a'
        id = Column(Integer, primary_key=True)
        bs = relationship("B")
    class B(Base):
        __tablename__ = 'b'
        id = Column(Integer, primary_key=True)
        a_id = Column(Integer, ForeignKey('a.id'))
    e = create_engine("sqlite://")
    Base.metadata.create_all(e)
    sess = Session(e)
    sess.add_all([
        A(id=1, bs=[B(id=1), B(id=2)]),
        A(id=2, bs=[B(id=3)]),
        A(id=3, bs=[])
    ])
    sess.commit()
    # decorated factory: each call to go() builds a BakedQuery keyed on the
    # factory's source location (plus any args, of which there are none)
    @BakedQuery.baked
    def go():
        return sess.query(A).filter(A.id == 3)
    b1 = go()
    # the second .all() reuses the compiled context cached by the first
    print b1.all()
    print b1.all()
(End of gist.)