Skip to content

Instantly share code, notes, and snippets.

@zzzeek
Last active July 12, 2016 15:44
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zzzeek/43d10d34993126f074a5 to your computer and use it in GitHub Desktop.
Save zzzeek/43d10d34993126f074a5 to your computer and use it in GitHub Desktop.
# here's a variant that adds a decorator like that of
# https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/BakedQuery
from sqlalchemy.orm.query import QueryContext
class BakedQuery(object):
    """An object that can produce a 'baked' Query, that is one where
    its ultimately generated SQL string is cached based on how the query
    has been constructed.

    The cache key is built from the positional arguments passed to the
    query-building function, plus the source location (filename, first
    line number) of that function and of every step added via ``bake()``.
    Values captured by closures are NOT part of the key, so anything that
    varies between calls must either be a positional argument or be
    supplied through ``params()``.
    """

    # Shared by all instances: cache key -> compiled QueryContext, with its
    # session/query detached. __iter__ also passes this same dict as the
    # "compiled_cache" execution option. Entries are never removed by this
    # class.
    _bakery = {}

    def __init__(self, fn, args=()):
        """Build the base query by calling ``fn(*args)``.

        :param fn: callable returning a Query; its source location becomes
            part of the cache key.
        :param args: positional arguments for ``fn``; they become the leading
            elements of the cache key and therefore must be hashable.
        """
        self._cache_key = tuple(args)
        self.query = fn(*args)
        self._params = {}
        self._update_cache_key(fn)
        self.steps = []

    def _update_cache_key(self, fn):
        """Append ``fn``'s defining filename and first line number to the key."""
        # __code__ exists on Python 2.6+ and Python 3; func_code is Python 2 only.
        code = fn.__code__
        self._cache_key += (code.co_filename, code.co_firstlineno)

    @classmethod
    def baked(cls, fn):
        """Decorator: calling the decorated function returns a ``cls`` instance.

        The wrapper accepts positional arguments only; they are forwarded to
        ``fn`` and form the leading part of the cache key.
        """
        import functools

        @functools.wraps(fn)
        def decorate(*args):
            # Use cls (not a hardcoded BakedQuery) so subclasses are honored.
            return cls(fn, args)
        return decorate

    def bake(self, fn):
        """Add a step ``fn(query) -> query`` used when the query is compiled.

        The step's source location is appended to the cache key. Steps are
        only invoked on a cache miss (from ``_bake``). Returns ``self`` so
        calls can be chained.
        """
        self._update_cache_key(fn)
        self.steps.append(fn)
        return self

    def _bake(self):
        """Apply all steps, compile the result and store its context in the bakery."""
        query = self.query
        for step in self.steps:
            query = step(query)
        context = query._compile_context()
        # Drop the per-use Query and Session from the cached context;
        # __iter__ attaches the current ones to a copy on every execution.
        del context.session
        del context.query
        self._bakery[self._cache_key] = context

    def params(self, **kw):
        """Merge bind-parameter values used at execution time; returns ``self``."""
        self._params.update(kw)
        return self

    def __iter__(self):
        """Execute using the cached compiled context and return the result iterator.

        On a cache miss the context is compiled first via ``_bake``. Execution
        uses the base query built in ``__init__`` (steps are not re-applied
        here); only the cached statement reflects the steps.
        """
        if self._cache_key not in self._bakery:
            self._bake()

        query = self.query
        # Pass the shared bakery dict as the "compiled_cache" execution option.
        query._execution_options = query._execution_options.union(
            {"compiled_cache": self._bakery}
        )

        baked_context = self._bakery[self._cache_key]
        # Shallow-copy the cached context without running QueryContext.__init__.
        context = QueryContext.__new__(QueryContext)
        context.__dict__.update(baked_context.__dict__)
        context.query = query
        context.session = query.session
        # Private copy so per-execution attributes don't leak into the cache.
        context.attributes = context.attributes.copy()
        # Note: this sets the flag on the statement object shared with the cache.
        context.statement.use_labels = True

        # Autoflush first unless disabled on the query or populate_existing is set.
        if query._autoflush and not query._populate_existing:
            query.session._autoflush()
        return query.params(self._params)._execute_and_instances(context)

    def all(self):
        """Execute and return all results as a list."""
        return list(self)
import cProfile
import StringIO
import pstats
import contextlib
@contextlib.contextmanager
def profiled(print_stats=False):
    """Profile the body of a ``with`` block using cProfile.

    On normal exit, writes ``total calls:  <n>`` to stdout, then the captured
    pstats listing (sorted by cumulative time) followed by a newline. The
    listing is only generated when *print_stats* is true; by default just the
    call count and a blank line are written.

    The profiler is always disabled when the block exits, even if the block
    raises; in that case the exception propagates and no report is written.

    Yields the ``cProfile.Profile`` instance in use.
    """
    import sys
    try:
        from StringIO import StringIO  # Python 2
    except ImportError:
        from io import StringIO  # Python 3

    pr = cProfile.Profile()
    pr.enable()
    try:
        yield pr
    finally:
        # Otherwise an exception in the block would leave the profiler running.
        pr.disable()

    s = StringIO()
    ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
    if print_stats:
        ps.print_stats()
    sys.stdout.write("total calls:  %d\n" % ps.total_calls)
    sys.stdout.write(s.getvalue() + "\n")
if __name__ == '__main__':
    # Demo / benchmark: run 1000 randomly-shaped queries without and then with
    # BakedQuery caching, printing the profiler call count for each run.
    import random

    from sqlalchemy import Column, ForeignKey, Integer, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship

    Base = declarative_base()

    class A(Base):
        __tablename__ = 'a'
        id = Column(Integer, primary_key=True)
        bs = relationship("B")

    class B(Base):
        __tablename__ = 'b'
        id = Column(Integer, primary_key=True)
        a_id = Column(Integer, ForeignKey('a.id'))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    sess = Session(engine)
    # Fixture: A(1) has two Bs, A(2) has one, A(3) has none.
    sess.add_all([
        A(id=1, bs=[B(id=1), B(id=2)]),
        A(id=2, bs=[B(id=3)]),
        A(id=3, bs=[]),
    ])
    sess.commit()

    def make_query_uncached(filter_, join, order_by):
        """Build a Query for A, optionally filtered, joined to B and ordered."""
        q = sess.query(A)
        if filter_:
            q = q.filter(A.id.in_([2, 3]))
        if join:
            q = q.join(A.bs)
        if order_by:
            q = q.order_by(A.id.desc())
        return q

    @BakedQuery.baked
    def make_query_cached(filter_, join, order_by):
        """Same query as make_query_uncached, produced through a BakedQuery
        whose cache key includes the (filter_, join, order_by) flags."""
        return make_query_uncached(filter_, join, order_by)

    def assert_(result, filter_, join, order_by):
        """Check the ids in *result* against the flags used to build the query."""
        ordered_ids = [a.id for a in result]
        ids = set(ordered_ids)
        # filter_ restricts the rows to ids 2 and 3.
        if filter_:
            assert 1 not in ids
        else:
            assert 1 in ids
        # The join to B excludes A(3), which has no Bs.
        if join:
            assert 3 not in ids
        else:
            assert 3 in ids
        assert sorted(ids, reverse=bool(order_by)) == ordered_ids

    def run_test(use_cache):
        """Run 1000 random flag combinations through one query builder."""
        make_query = make_query_cached if use_cache else make_query_uncached
        for _ in range(1000):
            filter_ = random.randint(0, 1)
            join = random.randint(0, 1)
            order_by = random.randint(0, 1)
            q = make_query(filter_, join, order_by)
            assert_(q.all(), filter_, join, order_by)

    print("Run test with no cache")
    with profiled():
        run_test(False)

    print("Run test with cache")
    with profiled():
        run_test(True)

    sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment