Skip to content

Instantly share code, notes, and snippets.

@ecederstrand
Last active September 17, 2015 20:54
Show Gist options
  • Save ecederstrand/6748a3496acdc95e40ef to your computer and use it in GitHub Desktop.
Save ecederstrand/6748a3496acdc95e40ef to your computer and use it in GitHub Desktop.
Prefetch objects using the original filters rather than the IN clause used by prefetch_related()
from collections import defaultdict
import logging
from django.db.models import Q
from django.db.models.fields.related import ForeignRelatedObjectsDescriptor, ReverseSingleRelatedObjectDescriptor, \
ReverseManyRelatedObjectsDescriptor
# Module-level logger named after this module; the helpers below write their debug output to it.
log = logging.getLogger(__name__)
def to_tree(prefetch_fields):
    """Build a nested dict from prefetch paths given in 'foo__bar' form.

    Each '__'-separated segment becomes a key whose value is the dict of the
    segments that follow it; a segment with nothing after it maps to ``{}``.
    Paths that share a prefix share the corresponding sub-dicts.
    """
    root = {}
    for path in prefetch_fields:
        node = root
        for segment in path.split('__'):
            # setdefault creates the child on first sight and returns it either way,
            # so we always descend one level per segment.
            node = node.setdefault(segment, {})
    return root
def flatten(tree):
    """Turn a nested dict of prefetch fields back into a list of 'foo__bar' paths.

    Every key is listed, immediately followed by all paths below it (prefixed
    with the key and '__'), in the dict's own iteration order.
    """
    paths = []
    for key in tree:
        paths.append(key)
        for descendant in flatten(tree[key]):
            paths.append('%s__%s' % (key, descendant))
    return paths
def prefix_args(related_field, args=(), kwargs=None):
    """Prefix filter arguments so they can be applied through ``related_field``.

    Every lookup ``foo=1`` becomes ``<related_field>__foo=1``, which lets filters
    written for the original model be used on a model related to it.

    :param related_field: name of the relation leading back to the original model.
    :param args: iterable of ``Q`` objects; they may contain nested ``Q`` objects.
    :param kwargs: dict of keyword lookups, or ``None`` for no keyword lookups.
    :return: ``(new_args, new_kwargs)`` - a list of prefixed ``Q`` clones and a dict
        of prefixed keyword lookups. The inputs are not modified.
    :raises AssertionError: if an element of ``args`` is not a ``Q`` object.
    """
    def prefix_q(node):
        # Work on a clone so the caller's Q object is left untouched. Children are
        # either (lookup, value) pairs or nested Q objects, which are prefixed
        # recursively.
        clone = node.clone()
        children = []
        for child in node.children:
            if isinstance(child, Q):
                children.append(prefix_q(child))
            else:
                lookup, value = child
                children.append(('%s__%s' % (related_field, lookup), value))
        clone.children = children
        return clone

    new_args = []
    for arg in args:
        assert isinstance(arg, Q)
        # The prefixed clone must be collected here, otherwise positional filters are lost.
        new_args.append(prefix_q(arg))
    # kwargs defaults to None, so fall back to an empty dict before iterating.
    new_kwargs = {'%s__%s' % (related_field, k): v for k, v in (kwargs or {}).items()}
    log.debug('Prefetch args/kwargs for %s: %s %s', related_field, new_args, new_kwargs)
    return new_args, new_kwargs
def prefetch(model, prefetch_fields, filter_args=(), filter_kwargs=None, exclude_args=(), exclude_kwargs=None):
    """Fetch ``model`` rows and attach related objects selected with the same filters.

    Prefetches related items using the same filters as the original model, instead of
    the IN clause that ``QuerySet.prefetch_related()`` uses: each related model is
    queried with the caller's filter/exclude arguments re-targeted through the relation
    (see ``prefix_args``), and the results are stored in every row's
    ``_prefetched_objects_cache``.

    TODO: For now, only supports recursion when the lower levels are FK relations, not
    M2M relations.

    :param model: model class; rows are read through ``model.objects``.
    :param prefetch_fields: iterable of relation names in 'foo__bar' form.
    :param filter_args: ``Q`` objects passed to ``filter()``.
    :param filter_kwargs: keyword lookups passed to ``filter()``; ``None`` means none.
    :param exclude_args: ``Q`` objects passed to ``exclude()``.
    :param exclude_kwargs: keyword lookups passed to ``exclude()``; ``None`` means none.
    :return: the filtered queryset (already iterated here) whose rows carry the
        prefetched objects.
    :raises AssertionError: if a top-level prefetch field is not a supported relation
        descriptor.
    """
    if not filter_kwargs:
        filter_kwargs = {}
    if not exclude_kwargs:
        exclude_kwargs = {}
    log.debug('Prefetching %s on %s', prefetch_fields, model.__name__)
    res = model.objects.filter(*filter_args, **filter_kwargs).exclude(*exclude_args, **exclude_kwargs)
    for r in res:
        r._prefetched_objects_cache = {}

    # prefetch name -> {key: related object(s)}. Reverse-FK and M2M entries are keyed by
    # the pk of the 'model' row; forward-FK entries are keyed by the related object's pk.
    related = {}
    # Forward-FK prefetch name -> attribute on the 'model' row holding the FK value.
    fk_attnames = {}
    for f, select_fields in to_tree(prefetch_fields).items():
        if f in related:
            # Already filled in as a side effect of an earlier relation (see below).
            continue
        log.debug('Getting related %s on %s', f, model.__name__)
        field = getattr(model, f)
        select_fields = flatten(select_fields)
        if isinstance(field, ForeignRelatedObjectsDescriptor):
            # Through models on 'model', aka. foorel_set, and M2M relations via reverse FK
            reverse_field = field.related.field.name
            fargs, fkwargs = prefix_args(reverse_field, filter_args, filter_kwargs)
            xargs, xkwargs = prefix_args(reverse_field, exclude_args, exclude_kwargs)
            id_field = field.related.field.attname
            related[f] = defaultdict(set)
            # Get the FK fields on the through model. The comprehension variables below
            # are deliberately not called 'f': on Python 2 a list comprehension leaks its
            # variable and would clobber the outer loop variable.
            related_field_names = [
                local_field.name
                for local_field in field.related.related_model._meta.local_fields
                if local_field.rel
            ]
            # select_related() the fields we're supposed to prefetch anyway, and that we haven't fetched already
            extra_select_fields = [
                name for name in related_field_names
                if name + 's' in prefetch_fields and name + 's' not in related
            ]
            extra_prefetch_names = [name + 's' for name in extra_select_fields]
            for extra_prefetch_name in extra_prefetch_names:
                related[extra_prefetch_name] = defaultdict(set)
            select_fields += extra_select_fields
            log.debug('Getting select_related %s for %s', select_fields, field.related.related_model)
            for o in field.related.related_model.objects\
                    .select_related(*select_fields)\
                    .filter(*fargs, **fkwargs)\
                    .exclude(*xargs, **xkwargs):
                related_id = getattr(o, id_field)
                related[f][related_id].add(o)
                for extra_select_field, extra_prefetch_name in zip(extra_select_fields, extra_prefetch_names):
                    related[extra_prefetch_name][related_id].add(getattr(o, extra_select_field))
        elif isinstance(field, ReverseManyRelatedObjectsDescriptor):
            # M2M relations via through model
            reverse_field = field.field.m2m_field_name()
            fargs, fkwargs = prefix_args(reverse_field, filter_args, filter_kwargs)
            xargs, xkwargs = prefix_args(reverse_field, exclude_args, exclude_kwargs)
            # NOTE(review): assumes the through model's FK to 'model' is named after the
            # lower-cased model name - confirm for custom through models.
            reverse_id_field = model.__name__.lower() + '_id'
            select_field = field.field.m2m_reverse_field_name()
            select_fields += [select_field]
            through_model = field.through
            through_field = field.through.__name__.lower() + '_set'
            related[f] = defaultdict(set)
            related[through_field] = defaultdict(set)
            log.debug('Getting select_related %s for %s', select_fields, through_model)
            for o in through_model.objects\
                    .select_related(*select_fields)\
                    .filter(*fargs, **fkwargs)\
                    .exclude(*xargs, **xkwargs):
                related_id = getattr(o, reverse_id_field)
                related[through_field][related_id].add(o)
                related[f][related_id].add(getattr(o, select_field))
        elif isinstance(field, ReverseSingleRelatedObjectDescriptor):
            # Prefetch FK relations on 'model'
            # related_query_name() is the name usable in lookups from the related model
            # back to 'model'; rel.related_name is None unless it was set explicitly.
            reverse_field = field.field.related_query_name()
            fargs, fkwargs = prefix_args(reverse_field, filter_args, filter_kwargs)
            xargs, xkwargs = prefix_args(reverse_field, exclude_args, exclude_kwargs)
            log.debug('Getting select_related %s for %s', select_fields, field.field.rel.related_model)
            fk_attnames[f] = field.field.attname
            related[f] = {
                o.pk: o for o in field.field.rel.related_model.objects
                .select_related(*select_fields)
                .filter(*fargs, **fkwargs)
                .exclude(*xargs, **xkwargs)
            }
        else:
            # Raised explicitly: an 'assert' statement would be stripped under -O and
            # the unsupported field silently skipped.
            raise AssertionError('Unsupported prefetch field %s' % field)

    for r in res:
        for name, objects_by_key in related.items():
            # Prefetched through models end with '_set' but go into _prefetched_objects_cache without the '_set'
            cache_key = name[:-4] if name.endswith('_set') else name
            if name in fk_attnames:
                # Forward FK: the dict is keyed by the related object's pk, so look it up
                # with this row's FK value (not the row's own pk). A NULL or unmatched FK
                # yields None.
                # NOTE(review): presumably the FK attribute itself reads a separate
                # per-field cache rather than this dict - confirm whether it should be
                # populated as well.
                r._prefetched_objects_cache[cache_key] = objects_by_key.get(getattr(r, fk_attnames[name]))
            else:
                r._prefetched_objects_cache[cache_key] = objects_by_key[r.pk]
    return res
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment