Skip to content

Instantly share code, notes, and snippets.

@ewhauser
Created August 23, 2022 17:10
Show Gist options
  • Save ewhauser/139923cb853500d406cc9b2c98763dce to your computer and use it in GitHub Desktop.
Save ewhauser/139923cb853500d406cc9b2c98763dce to your computer and use it in GitHub Desktop.
import inspect
import itertools
from python.nplusone.core import signals
from sqlalchemy.engine import ScalarResult
from sqlalchemy.orm import attributes, loading, query, strategies
def to_key(instance):
    """Build a string key for ``instance`` from its model name and primary key.

    The result is the class name of ``instance`` followed by the formatted
    value of every primary-key property (as reported by
    ``get_primary_keys``), joined with ``":"``.
    """
    model = type(instance)
    parts = [model.__name__]
    for prop in get_primary_keys(model):
        # Read the raw value from ``__dict__`` rather than via attribute
        # access, so the instrumented ``__get__`` is not triggered
        # (avoids recursion).
        parts.append(format(instance.__dict__.get(prop.key)))
    return ":".join(parts)
def get_primary_keys(model):
    """Return the mapped property for each primary-key column of ``model``.

    ``model`` must expose a ``__mapper__``; each column in the mapper's
    ``primary_key`` is translated through ``get_property_by_column`` and the
    results are returned as a list in the same order.
    """
    mapper = model.__mapper__
    properties = []
    for column in mapper.primary_key:
        properties.append(mapper.get_property_by_column(column))
    return properties
def parse_load(args, kwargs, context, ret):
    """Return the keys of the rows in ``ret`` that carry a ``__table__``.

    Rows without a ``__table__`` attribute are skipped. ``args``, ``kwargs``
    and ``context`` are accepted but not used.
    """
    keys = []
    for row in ret:
        if not hasattr(row, "__table__"):
            continue
        keys.append(to_key(row))
    return keys
def parse_lazy_load(args, kwargs, context):
    """Describe a lazy load as ``(model class, instance key, relationship key)``.

    ``args`` must hold exactly three items: the loader, the instance state
    and a third value that is ignored. ``kwargs`` and ``context`` are
    accepted but not used.
    """
    loader, state, _ = args
    instance = state.object
    return (
        instance.__class__,
        to_key(instance),
        loader.parent_property.key,
    )
def parse_attribute_get(args, kwargs, context):
    """Describe an attribute access as ``(class, attribute key, [instance key])``.

    The first two items of ``args`` are the attribute and the instance it was
    read from. Returns ``None`` when the instance is ``None``. ``kwargs`` and
    ``context`` are accepted but not used.
    """
    attr, instance = args[:2]
    if instance is not None:
        return attr.class_, attr.key, [to_key(instance)]
    return None
# Replace `LazyLoader._load_for_state` with a `signalify`-wrapped version bound
# to the `lazy_load` signal; `parse_lazy_load` is supplied as its parser.
# NOTE(review): `signalify` is defined elsewhere; presumably the wrapper emits
# the signal on each call and then delegates to the original; confirm.
strategies.LazyLoader._load_for_state = signals.signalify(
signals.lazy_load,
strategies.LazyLoader._load_for_state,
parser=parse_lazy_load,
)
def parse_populate(args, kwargs, context):
    """Describe an eager load reported by the ``_populate_full`` wrapper.

    Returns ``(model class, relationship key, [instance key], query id)``:
    the instance comes from the state at ``args[2]``, the relationship key
    from ``context["key"]``, and the last item is ``id()`` of ``args[0]``.
    ``kwargs`` is accepted but not used.
    """
    query_context = args[0]
    instance = args[2].object
    return (
        instance.__class__,
        context["key"],
        [to_key(instance)],
        id(query_context),
    )
# Emit `eager_load` on populating from `joinedload` or `subqueryload`
original_populate_full = loading._populate_full


def _populate_full(*args, **kwargs):
    """Call the original ``_populate_full`` and then report eager loads.

    After delegating, the call arguments are mapped onto the original
    function's parameter names. For every ``"eager"`` populator whose key has
    a truthy value in ``dict_``, an ``eager_load`` signal is sent with that
    key. The original's return value is passed through unchanged.
    """
    result = original_populate_full(*args, **kwargs)
    # Resolve positional/keyword arguments to parameter names so that
    # ``populators`` and ``dict_`` can be looked up regardless of call style.
    bound = inspect.getcallargs(original_populate_full, *args, **kwargs)
    for key, _ in bound["populators"].get("eager", []):
        if not bound["dict_"].get(key):
            # Nothing was populated for this relationship; do not report it.
            continue
        signals.eager_load.send(
            signals.get_worker(),
            args=args,
            kwargs=kwargs,
            context={"key": key},
            parser=parse_populate,
        )
    return result


loading._populate_full = _populate_full
# Replace `InstrumentedAttribute.__get__` with a `signalify`-wrapped version
# bound to the `touch` signal; `parse_attribute_get` is supplied as its parser.
# NOTE(review): `signalify` is defined elsewhere; presumably the wrapper emits
# the signal on each attribute read and then delegates to the original; confirm.
attributes.InstrumentedAttribute.__get__ = signals.signalify(
signals.touch,
attributes.InstrumentedAttribute.__get__,
parser=parse_attribute_get,
)
def is_single(res):
    """Return ``True`` when ``res`` is limited to at most one row.

    ``res`` is inspected for a ``_limit_clause`` attribute whose ``value`` is
    the literal LIMIT. LIMIT is a row *count*, so the query yields at most
    one row exactly when that value is ``1``; the OFFSET only moves the
    window and does not change how many rows can come back. (The previous
    ``limit - offset == 1`` test treated LIMIT as a stop index, so
    ``limit(5).offset(4)`` was wrongly reported as single and
    ``limit(1).offset(3)`` as not single.)

    Returns ``False`` when ``res`` has no ``_limit_clause``, when the clause
    is ``None``, or when the clause carries no literal ``value`` (for
    example an arbitrary SQL expression), since the row count is then
    unknown.
    """
    limit_clause = getattr(res, "_limit_clause", None)
    # ``getattr`` with a default avoids an AttributeError for clauses that are
    # not literal bind parameters and therefore have no ``value``.
    limit = getattr(limit_clause, "value", None)
    return limit is not None and limit == 1
original_query_iter = query.Query._iter


def query_iter(self):
    """Run the original ``Query._iter`` and report the rows it loads.

    The result's row iterator is split with ``itertools.tee``: one branch is
    put back on the result for the caller, the other is drained into a list
    and sent with the ``load`` signal (or ``ignore_load`` when ``is_single``
    says the query is a single-row query). The result object itself is
    returned.
    """
    res = original_query_iter(self)
    if is_single(self):
        signal = signals.ignore_load
    else:
        signal = signals.load
    # For a ScalarResult the row iterator lives on the wrapped
    # ``_real_result``; otherwise it is on the result itself.
    holder = res._real_result if isinstance(res, ScalarResult) else res
    live, snapshot = itertools.tee(holder.iterator)
    holder.iterator = live
    # ``list(snapshot)`` consumes every row now; ``tee`` buffers them so the
    # caller can still iterate ``live`` afterwards.
    signal.send(
        signals.get_worker(),
        args=(self,),
        ret=list(snapshot),
        parser=parse_load,
    )
    return res


query.Query._iter = query_iter
def parse_get(args, kwargs, context, ret):
    """Return ``[key]`` for a returned object that has a ``__table__``.

    Yields an empty list when ``ret`` has no ``__table__`` attribute.
    ``args``, ``kwargs`` and ``context`` are accepted but not used.
    """
    if not hasattr(ret, "__table__"):
        return []
    return [to_key(ret)]
# Ignore records loaded during `one`
for method in ("one_or_none", "one"):
    try:
        original = getattr(query.Query, method)
    except AttributeError:
        # This `Query` class does not define the method; nothing to wrap.
        continue
    else:
        # Rebind the method to a `signalify`-wrapped version tied to the
        # `ignore_load` signal, with `parse_get` as the parser.
        decorated = signals.signalify(signals.ignore_load, original, parse_get)
        setattr(query.Query, method, decorated)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment