Last active
December 28, 2015 08:26
-
-
Save tatterdemalion/1a6534005748d46b59f7 to your computer and use it in GitHub Desktop.
Combines Django's `only()`, `select_related()` and `prefetch_related()` into a single manager method.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.db import models | |
class FetchManager(models.Manager):
    """
    Combine ``only``, ``select_related`` and ``prefetch_related`` into a
    single ``fetch`` method.

    Relation paths that do not cross a to-many relation (e.g. a forward
    ForeignKey) are joined with ``select_related``. Paths that cross a
    to-many relation (forward or reverse ManyToManyField, reverse
    ForeignKey) are handed to ``prefetch_related``.

    Example models:
    ---------------

        class Topping(models.Model):
            name = models.CharField(max_length=30)
            objects = FetchManager()

        class Chef(models.Model):
            name = models.CharField(max_length=50)
            unwanted_field = models.TextField()
            objects = FetchManager()

        class Pizza(models.Model):
            name = models.CharField(max_length=50)
            chef = models.ForeignKey('Chef', on_delete=models.CASCADE)
            toppings = models.ManyToManyField(Topping)
            objects = FetchManager()

    Populate some data:
    -------------------

        Chef(
            name='John Doe',
            unwanted_field="some big chunk of text").save()
        Topping(name='cheddar').save()
        Topping(name='mozzarella').save()
        Topping(name='bacon').save()
        # Model.save() returns None, so use create() to keep a handle.
        p = Pizza.objects.create(name='Yummy', chef=Chef.objects.get())
        p.toppings.add(*Topping.objects.all())

    Using FetchManager:
    -------------------

        In [0]: Pizza.objects.fetch('name', 'toppings__name', 'chef__name')
        Out[0]: [<Pizza: Pizza object>]

    SQL:
    ----

        SELECT "pizza"."id",
               "pizza"."name",
               "pizza"."chef_id",
               "chef"."id",
               "chef"."name"
        FROM "pizza"
        INNER JOIN "chef" ON
        ( "pizza"."chef_id" = "chef"."id" )

        SELECT ("pizza_toppings"."pizza_id") AS "_prefetch_related_val_pizza_id",
               "topping"."id",
               "topping"."name"
        FROM "topping"
        INNER JOIN "pizza_toppings" ON
        ( "topping"."id" = "pizza_toppings"."topping_id" )
        WHERE "pizza_toppings"."pizza_id" IN (1)
    """

    def fetch(self, *args, **kwargs):
        """
        Return a queryset that loads only ``args`` plus the needed relations.

        Every argument is a field path as accepted by ``QuerySet.only()``,
        e.g. ``'name'``, ``'chef__name'`` or ``'toppings__name'``. The part of
        a path before its last ``__`` is a relation path. Each distinct
        relation path is joined with ``select_related()``, or fetched with
        ``prefetch_related()`` when it crosses a to-many relation.

        :param args: field paths, passed unchanged to ``only()``.
        :param kwargs: accepted for backward compatibility; ignored.
        :returns: a ``QuerySet`` of ``self.model``.
        """
        select = []
        prefetch = []
        for names in self._relation_paths(args):
            crosses_to_many, accessor_path = self._resolve_relation_path(
                self.model, names)
            if crosses_to_many:
                # TODO: make use of Prefetch objects so the prefetched
                # querysets are restricted with only() as well.
                prefetch.append(accessor_path)
            else:
                # select_related() takes query names, same as only().
                select.append('__'.join(names))

        queryset = self.get_queryset()
        # select_related() with no arguments follows *every* non-null
        # ForeignKey, so only call it when there is something to select.
        if select:
            queryset = queryset.select_related(*select)
        if prefetch:
            queryset = queryset.prefetch_related(*prefetch)
        return queryset.only(*args)

    @staticmethod
    def _relation_paths(field_paths):
        """
        Return the distinct relation paths of ``field_paths`` as name lists.

        ``'chef__name'`` contributes ``['chef']``. A path without ``__`` such
        as ``'name'`` contributes nothing. Duplicates are dropped while
        first-seen order is kept.
        """
        seen = set()
        relations = []
        for path in field_paths:
            relation, sep, _ = path.rpartition('__')
            if sep and relation not in seen:
                seen.add(relation)
                relations.append(relation.split('__'))
        return relations

    @staticmethod
    def _resolve_relation_path(model, names):
        """
        Walk the relation ``names`` starting at ``model``.

        :returns: ``(crosses_to_many, accessor_path)``.

            * ``crosses_to_many`` is True if any hop is a forward or reverse
              many-to-many relation or a reverse foreign key.
            * ``accessor_path`` is the ``__``-joined path with reverse
              relations rewritten to their attribute (accessor) names, which
              is what ``prefetch_related()`` expects (e.g. ``pizza_set``
              instead of the query name ``pizza``).

        A name that is unknown, or is not a relation on the current model, is
        kept verbatim and stops the walk. Django's own validation in
        ``select_related()``/``prefetch_related()``/``only()`` then reports it.
        """
        # Local import keeps this block self-contained (Django >= 1.8).
        from django.core.exceptions import FieldDoesNotExist

        crosses_to_many = False
        accessors = []
        current = model
        for name in names:
            field = None
            if current is not None:
                try:
                    field = current._meta.get_field(name)
                except FieldDoesNotExist:
                    field = None  # reported later by Django's query validation
            if field is None or not field.is_relation:
                accessors.append(name)
                current = None
                continue
            if field.many_to_many or field.one_to_many:
                crosses_to_many = True
            if field.auto_created and not field.concrete:
                # Reverse relation: its query name and attribute name differ.
                accessors.append(field.get_accessor_name())
            else:
                accessors.append(name)
            current = field.related_model
        return crosses_to_many, '__'.join(accessors)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment