Last active
July 12, 2016 15:44
-
-
Save zzzeek/43d10d34993126f074a5 to your computer and use it in GitHub Desktop.
baked query decorator + builder, see https://bitbucket.org/zzzeek/sqlalchemy/issue/3054/new-idea-for-bake
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Here's a variant that adds a decorator like the one in the recipe at
# https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/BakedQuery
from sqlalchemy.orm.query import QueryContext | |
class BakedQuery(object):
    """An object that can produce a 'baked' Query, that is one where
    its ultimately generated SQL string is cached based on how the query
    has been constructed.

    The cache key is a tuple made of the positional arguments given to the
    initial query-building callable, followed by the
    ``(filename, first line number)`` of that callable and of every step
    added with :meth:`bake`.  Consequently the arguments must be hashable,
    and two callables defined on the same line of the same file produce
    identical key components.

    """

    # Shared by all instances.  Stores the baked contexts keyed by the
    # cache-key tuple, and is also handed to the query as its
    # "compiled_cache" execution option in __iter__().
    _bakery = {}

    def __init__(self, fn, args=()):
        """Build the initial query by calling ``fn(*args)``.

        :param fn: callable returning the starting query object.
        :param args: positional arguments for ``fn``; they also form the
         leading part of the cache key.
        """
        args = tuple(args)
        self._cache_key = args
        self.query = fn(*args)
        self._params = {}
        self._update_cache_key(fn)
        self.steps = []

    def _update_cache_key(self, fn):
        """Extend the cache key with the source location of ``fn``."""
        # ``__code__`` exists on Python 2.6+ as well as Python 3, whereas
        # the older ``func_code`` spelling is Python 2 only.
        code = fn.__code__
        self._cache_key += (code.co_filename, code.co_firstlineno)

    @classmethod
    def baked(cls, fn):
        """Decorator: calling the decorated function returns a baked query
        object built from ``fn`` and the positional arguments of the call.
        """
        def decorate(*args):
            # Use ``cls`` rather than a hard-coded class name so that the
            # decorator also works when accessed through a subclass.
            return cls(fn, args)
        return decorate

    def bake(self, fn):
        """Register ``fn`` as an additional query-building step.

        ``fn`` is called with the query and must return a query; it is
        only invoked when the cached context has to be built.  Returns
        ``self`` so that calls can be chained.
        """
        self._update_cache_key(fn)
        self.steps.append(fn)
        return self

    def _bake(self):
        """Apply all steps, compile the result and cache its context."""
        query = self.query
        for step in self.steps:
            query = step(query)
        context = query._compile_context()
        # Remove the per-execution references before caching; __iter__()
        # assigns fresh ones on each run.
        del context.session
        del context.query
        self._bakery[self._cache_key] = context

    def params(self, **kw):
        """Accumulate bind parameter values for execution; returns ``self``."""
        self._params.update(kw)
        return self

    def __iter__(self):
        """Execute the query using the cached context, baking it first if
        no context is stored yet under this object's cache key.
        """
        if self._cache_key not in self._bakery:
            self._bake()

        # Note that the steps are not re-applied here: they run only
        # inside _bake(), and their effect is carried by the cached
        # context.
        query = self.query
        query._execution_options = query._execution_options.union(
            {"compiled_cache": self._bakery}
        )

        baked_context = self._bakery[self._cache_key]

        # Make a shallow copy of the cached context without running
        # QueryContext.__init__, then attach the per-execution state that
        # _bake() removed.
        context = QueryContext.__new__(QueryContext)
        context.__dict__.update(baked_context.__dict__)
        context.query = query
        context.session = query.session
        context.attributes = context.attributes.copy()
        context.statement.use_labels = True

        if query._autoflush and not query._populate_existing:
            query.session._autoflush()
        return query.params(self._params)._execute_and_instances(context)

    def all(self):
        """Execute the query and return the results as a list."""
        return list(self)
import cProfile | |
import StringIO | |
import pstats | |
import contextlib | |
@contextlib.contextmanager
def profiled():
    """Context manager that profiles the enclosed block with cProfile.

    On normal exit it prints the total number of profiled calls followed
    by the contents of the stats stream (empty unless the
    ``print_stats()`` line below is re-enabled).  If the enclosed block
    raises, the profiler is still disabled and the exception propagates
    without any report being printed.
    """
    # Imported locally: the class lives in the ``StringIO`` module on
    # Python 2 and in ``io`` on Python 3.
    try:
        from StringIO import StringIO
    except ImportError:
        from io import StringIO

    pr = cProfile.Profile()
    pr.enable()
    try:
        yield
    finally:
        # Always switch the profiler off, even when the block raised.
        pr.disable()

    s = StringIO()
    ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
    # ps.print_stats()
    # Single-argument print() calls emit the same text on Python 2 and 3.
    print("total calls:  %s" % ps.total_calls)
    print(s.getvalue())
if __name__ == '__main__':
    # Demo / benchmark: run the same randomized set of queries with and
    # without the baked-query cache, profiling each run.
    from sqlalchemy import *
    from sqlalchemy.orm import *
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class A(Base):
        __tablename__ = 'a'

        id = Column(Integer, primary_key=True)
        bs = relationship("B")

    class B(Base):
        __tablename__ = 'b'

        id = Column(Integer, primary_key=True)
        a_id = Column(Integer, ForeignKey('a.id'))

    # In-memory SQLite database.
    e = create_engine("sqlite://")
    Base.metadata.create_all(e)

    sess = Session(e)
    # Fixture: A(1) has two B rows, A(2) has one, A(3) has none.
    sess.add_all([
        A(id=1, bs=[B(id=1), B(id=2)]),
        A(id=2, bs=[B(id=3)]),
        A(id=3, bs=[]),
    ])
    sess.commit()

    #@BakedQuery.baked
    #def go():
    #    return sess.query(A).filter(A.id == 3)

    def make_query_uncached(filter_, join, order_by):
        """Build a plain query on A, applying each optional clause whose
        flag is true.
        """
        query = sess.query(A)
        if filter_:
            query = query.filter(A.id.in_([2, 3]))
        if join:
            query = query.join(A.bs)
        if order_by:
            query = query.order_by(A.id.desc())
        return query

    @BakedQuery.baked
    def make_query_cached(filter_, join, order_by):
        """Same query as make_query_uncached(), wrapped as a baked query
        keyed on the three flags.
        """
        return make_query_uncached(filter_, join, order_by)

    import random

    def run_test(use_cache):
        """Run 1000 queries with random flag combinations and check every
        result.
        """
        build = make_query_cached if use_cache else make_query_uncached
        for _ in range(1000):
            # Three independent 0/1 flags: filter_, join, order_by.
            flags = [random.randint(0, 1) for _ in range(3)]
            assert_(build(*flags).all(), *flags)

    def assert_(result, filter_, join, order_by):
        """Check that the returned A rows match the requested flags."""
        ordered_ids = [row.id for row in result]
        ids = set(ordered_ids)

        # The filter keeps only ids 2 and 3, so id 1 is present exactly
        # when no filter was applied.
        assert (1 in ids) == (not filter_)
        # A(3) has no B rows, so the join drops it.
        assert (3 in ids) == (not join)
        # Descending order when order_by was requested, ascending otherwise.
        assert sorted(ids, reverse=bool(order_by)) == ordered_ids

    print("Run test with no cache")
    with profiled():
        run_test(False)

    print("Run test with cache")
    with profiled():
        run_test(True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment