Last active
July 16, 2020 19:52
-
-
Save JohnSpeno/5ac4d49b5c16d9193c8a5ff0d20bad7d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import django
from django.db import transaction
from django.db.models import Count, Func, IntegerField, OuterRef, Subquery
from django.contrib.postgres.aggregates import ArrayAgg
from django.contrib.postgres.fields import ArrayField
from sql_util.utils import SubqueryAggregate

# This is a standalone script, so Django must be set up before any model
# module is imported; that is why setup() runs ahead of the `pizza.models`
# import below.
# NOTE(review): assumes DJANGO_SETTINGS_MODULE is set in the environment of
# whoever runs this script — confirm.
django.setup()

from pizza.models import Topping, Pizza

# Reference copy of the models used by this script. It is a bare string
# expression after the imports (not a module docstring) and is never used.
"""
Our lovely models:

class Topping(models.Model):
    name = models.CharField(max_length=30)

    def __str__(self):
        return self.name

class Pizza(models.Model):
    name = models.CharField(max_length=50)
    toppings = models.ManyToManyField(Topping, related_name="pizzas")

    def __str__(self):
        return "%s (%s)" % (
            self.name,
            ", ".join(topping.name for topping in self.toppings.all()),
        )
"""
def set_up_pizzas():
    """
    Ensure the DB is populated with Toppings and Pizzas.

    Creates (if missing) the toppings cheese, mushrooms, onions and peppers,
    and the pizzas "everything", "mushroom" and "onion". Toppings are only
    attached to a pizza when that pizza is newly created, so re-running this
    function never modifies pizzas that already exist.

    Returns a tuple of (list of wanted toppings, wanted_pizza).
    In this case, the toppings we want are cheese and onions,
    and the pizza wanted is the "onion" variety.
    """
    # Everything happens in one transaction: because toppings are only added
    # when a pizza is first created, a failure between creating a pizza and
    # adding its toppings would otherwise leave a topping-less pizza that no
    # later run would ever repair.
    with transaction.atomic():
        cheese, mushrooms, onions, peppers = (
            Topping.objects.get_or_create(name=name)[0]
            for name in ("cheese", "mushrooms", "onions", "peppers")
        )
        _get_or_create_pizza("everything", [cheese, mushrooms, onions, peppers])
        _get_or_create_pizza("mushroom", [cheese, mushrooms])
        onion_pizza = _get_or_create_pizza("onion", [cheese, onions])
    return [cheese, onions], onion_pizza


def _get_or_create_pizza(name, toppings):
    """Return the Pizza called *name*, attaching *toppings* only if it is new."""
    pizza, created = Pizza.objects.get_or_create(name=name)
    if created:
        # add() accepts several objects, so all toppings are attached in one
        # call instead of one call per topping.
        pizza.toppings.add(*toppings)
    return pizza
def pizzas_by_counting_and_filtering(wanted_toppings):
    """
    Return a QuerySet of pizzas whose toppings are exactly ``wanted_toppings``.

    This solution uses a Count() annotation with filter(): a pizza matches
    when it has exactly as many toppings as were asked for AND it has each of
    them. It works when the wanted toppings are in any order, and it uses no
    PostgreSQL-specific features.
    Each wanted topping adds a JOIN to the SQL query.

    ``wanted_toppings`` may be any iterable of Topping instances; duplicate
    entries are ignored.
    """
    # Drop duplicates (keeping order). Otherwise len() over-counts: e.g.
    # [cheese, cheese] would match every two-topping pizza that has cheese.
    wanted = list(dict.fromkeys(wanted_toppings))
    available_pizzas = Pizza.objects.annotate(
        num_toppings=Count("toppings"),
    ).filter(
        num_toppings=len(wanted),
    )
    for topping in wanted:
        available_pizzas = available_pizzas.filter(toppings=topping)
    return available_pizzas
class Array(Func):
    """
    PostgreSQL array constructor: ``Array(a, b)`` renders as ``ARRAY[a, b]``.

    Used in this module with plain topping ids, e.g.
    ``Array(cheese.id, onions.id)``, as the right-hand side of a filter on an
    ArrayAgg annotation. The output type defaults to an integer array; pass
    ``output_field=`` to build an array of some other type.
    """
    function = 'ARRAY'
    template = '%(function)s[%(expressions)s]'
    # Declared explicitly: otherwise Django derives the output type from the
    # element expressions, which describes a single element rather than the
    # array (or cannot be resolved at all), breaking use in annotate().
    output_field = ArrayField(IntegerField())
def pizzas_by_sql_utils_subquery_aggregate(wanted_toppings):
    """
    Return a QuerySet of pizzas whose aggregated topping ids equal the wanted ids.

    This solution uses the django-sql-utils module's SubqueryAggregate()
    to simplify the writing of subqueries with aggregations.
    As written, requires Django 3.0.3 or later and only works with PostgreSQL.
    Issue was fixed in 521308e5 - https://code.djangoproject.com/ticket/30715

    The arrays are compared element by element, so a pizza only matches when
    its topping ids come out in the same order as ``wanted_toppings``. No
    ordering is requested from ArrayAgg here, so that order is whatever the
    database produces.

    ``wanted_toppings`` is a non-empty sequence of Topping instances (any
    number of them). Raises ValueError if it is empty.
    """
    wanted_ids = [topping.id for topping in wanted_toppings]
    if not wanted_ids:
        # An empty ARRAY[] literal has no element type in PostgreSQL.
        raise ValueError("wanted_toppings must contain at least one topping")
    pizzas = Pizza.objects.annotate(
        its_toppings=SubqueryAggregate('toppings__id', aggregate=ArrayAgg),
    )
    return pizzas.filter(its_toppings=Array(*wanted_ids))
def pizzas_by_subquery_arrayagg(wanted_toppings):
    """
    Return a QuerySet of pizzas whose set of toppings is exactly ``wanted_toppings``.

    This solution uses a Subquery() with an ArrayAgg(). I think it is the
    Django 2.x equivalent of `pizzas_by_arrayagg()` which requires Django 3.x.
    It needs Django 2.2+ (for ArrayAgg's ``ordering``) and only works with
    PostgreSQL.

    Both the aggregated topping ids and the wanted ids are sorted, so the
    order of ``wanted_toppings`` does not matter; duplicates are ignored.
    Raises ValueError if ``wanted_toppings`` is empty.
    """
    wanted_ids = sorted({topping.id for topping in wanted_toppings})
    if not wanted_ids:
        # An empty ARRAY[] literal has no element type in PostgreSQL.
        raise ValueError("wanted_toppings must contain at least one topping")
    toppings_per_pizza = (
        Topping.objects.filter(pizzas=OuterRef('id'))
        # Clear any default ordering so it cannot leak into the GROUP BY.
        .order_by()
        # values() before annotate() groups the aggregate per pizza.
        .values('pizzas')
        # ORDER BY inside the aggregate: without it the array order is up to
        # the database, and the equality filter below would depend on it.
        .annotate(its_toppings=ArrayAgg('id', ordering='id'))
        .values('its_toppings')
    )
    pizzas = Pizza.objects.annotate(its_toppings=Subquery(toppings_per_pizza))
    return pizzas.filter(its_toppings=Array(*wanted_ids))
def pizzas_by_arrayagg(wanted_toppings):
    """
    Return a QuerySet of pizzas whose set of toppings is exactly ``wanted_toppings``.

    This solution uses an ArrayAgg annotation with a single filter.
    As written, requires Django 3.0.3 or later and only works with PostgreSQL.
    Issue was fixed in 521308e5 - https://code.djangoproject.com/ticket/30715

    Both sides of the comparison are sorted by topping id, so the order of
    ``wanted_toppings`` does not matter; duplicate toppings are ignored.
    """
    wanted_ids = sorted({topping.id for topping in wanted_toppings})
    # ORDER BY inside the aggregate: without it the array order is up to the
    # database, which made the equality filter below nondeterministic.
    pizzas = Pizza.objects.annotate(
        its_toppings=ArrayAgg("toppings", ordering="toppings"),
    )
    return pizzas.filter(its_toppings=wanted_ids)
def find_wanted_pizzas():
    """
    Populate the demo data, then run each technique and print what it finds.

    The first two techniques always run. The last two are only run on Django
    3.0.3 or later; on older versions a message explains why they were skipped
    (see https://code.djangoproject.com/ticket/30715).
    """
    wanted_toppings, wanted_pizza = set_up_pizzas()
    print(f"The pizza we want is: {wanted_pizza}")
    print()

    def report(technique_name, technique):
        # QuerySets are lazy; formatting the result into the f-string is what
        # actually runs the query.
        found = technique(wanted_toppings)
        print(f"{technique_name} returned: {found}")
        print()

    report("pizzas_by_counting_and_filtering", pizzas_by_counting_and_filtering)
    report("pizzas_by_subquery_arrayagg", pizzas_by_subquery_arrayagg)

    if django.VERSION >= (3, 0, 3):
        # Both remaining techniques rely on the fix released in Django 3.0.3.
        report(
            "pizzas_by_sql_utils_subquery_aggregate",
            pizzas_by_sql_utils_subquery_aggregate,
        )
        report("pizzas_by_arrayagg", pizzas_by_arrayagg)
    else:
        print(
            "pizzas_by_sql_utils_subquery_aggregate and "
            "pizzas_by_arrayagg require django 3.0.3 or later."
        )
        print(f"they were not run because we are running django {django.get_version()}")


if __name__ == "__main__":
    find_wanted_pizzas()
I tracked down the issue that prevented some of these techniques from working in Django 2. It is fixed in Django 3.0.3, but the fix did not qualify for a backport to Django 2.2 (alas): https://code.djangoproject.com/ticket/30715
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I wanted to answer the question: is there a way to filter ManyToManyFields by an exact set? For instance, `.filter(m2m=foo)` will find all objects with a relation to `foo`, but what if I want objects related to only `foo`, or to only `foo` and `bar`? I found a few options, some of which I couldn't get to work in either Django 2.2.x or Django 3.0.x until I posted about it here: https://forum.djangoproject.com/t/a-specific-manytomany-query-crash/3457