Last active
May 20, 2021 09:16
-
-
Save antonagestam/c6e0647f5cf396b18927e4b26da9149c to your computer and use it in GitHub Desktop.
Introspective choice constraints for Django models.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Note: This doesn't properly yield a delete+create migration when choices change.
# -> This can probably be remedied by requiring the model "path" (<app>.<model_name>) to be
#    passed as an argument to __init__. Not ideal ...
from typing import Callable, Sequence | |
from django.db import models | |
from django.db.backends.base.schema import BaseDatabaseSchemaEditor | |
from django.db.models.constraints import BaseConstraint | |
from django.db.models.sql.query import Query | |
class LazyCheckConstraint(BaseConstraint):
    """
    A check constraint whose condition is computed lazily from the model.

    Instead of storing a fixed ``Q`` object, the constraint stores ``check_fn``.
    ``check_fn`` is called with the model field named ``field_name`` whenever
    the check is needed: SQL generation, or instance validation.

    The constraint name is ``"%(class)s_<field_name>_choice_chk"``. The
    ``%(class)s`` placeholder is filled with the model's ``model_name`` by
    ``get_name()``.

    Note: ``__eq__`` and ``deconstruct`` only consider ``field_name`` and
    ``check_fn``. Changes to whatever ``check_fn`` derives its condition from
    (e.g. a field's choices) therefore do not make two instances compare
    unequal.
    """

    def __init__(
        self, *, field_name: str, check_fn: Callable[[models.Field], models.Q]
    ):
        super().__init__(name=f"%(class)s_{field_name}_choice_chk")
        self.field_name = field_name
        self.check_fn = check_fn

    def get_check(self, model: models.Model) -> models.Q:
        """Return the ``Q`` condition for ``model`` by calling ``check_fn``."""
        return self.check_fn(model._meta.get_field(self.field_name))

    def get_name(self, model: models.Model) -> str:
        """Return the constraint name with ``%(class)s`` set to the model name."""
        return self.name % {"class": model._meta.model_name}

    def _get_check_sql(
        self, model: models.Model, schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Compile the lazily built check into a SQL condition string."""
        # alias_cols=False keeps column references unqualified by table alias.
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.get_check(model))
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        # Inline every param as a quoted literal so the returned string
        # contains no placeholders.
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def constraint_sql(
        self, model: models.Model, schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Return the inline CHECK clause used when creating the table."""
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._check_sql(self.get_name(model), check)

    def create_sql(
        self, model: models.Model, schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Return the statement adding this constraint to an existing table."""
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._create_check_sql(model, self.get_name(model), check)

    def remove_sql(
        self, model: models.Model, schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Return the statement dropping this constraint."""
        return schema_editor._delete_check_sql(model, self.get_name(model))

    def validate(self, model, instance, exclude=None, using=None):
        """
        Validate ``instance`` against the lazily built check.

        Django >= 4.1 calls ``validate()`` on each constraint from
        ``Model.validate_constraints()``, which ``full_clean()`` uses.
        ``BaseConstraint.validate()`` itself raises ``NotImplementedError``.

        This method mirrors ``CheckConstraint.validate()``:

        - It evaluates the condition against the instance's field values.
        - It raises ``ValidationError`` when the condition is false.
        - A ``FieldError`` (e.g. the referenced field is in ``exclude``) skips
          the check instead of failing.
        """
        # Local imports: these names are only needed on this validation path.
        from django.core.exceptions import FieldError, ValidationError
        from django.db import DEFAULT_DB_ALIAS

        if using is None:
            using = DEFAULT_DB_ALIAS
        against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
        try:
            if not self.get_check(model).check(against, using=using):
                raise ValidationError(self.get_violation_error_message())
        except FieldError:
            pass

    def __repr__(self):
        return (
            f"<{self.__class__.__name__}: "
            f"field_name={self.field_name!r} "
            f"check_fn={self.check_fn!r} "
            f"name={self.name!r}>"
        )

    def __eq__(self, other):
        # Equality is defined by the inputs to the lazy check, not by the
        # check's current value (see the class docstring note).
        if isinstance(other, LazyCheckConstraint):
            return (
                self.field_name == other.field_name and self.check_fn == other.check_fn
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple:
        """
        Deconstruct into ``(path, args, kwargs)`` for migrations.

        The parent's kwargs are replaced wholesale because ``__init__`` only
        accepts ``field_name`` and ``check_fn``; the name is derived from them.
        """
        path, args, _ = super().deconstruct()
        kwargs = {"field_name": self.field_name, "check_fn": self.check_fn}
        return path, args, kwargs
# This needs to be exposed and accessible for migrations to import.
def build_choices_constraint(field: models.Field) -> models.Q:
    """
    Build a constraint check from the choices of a field.

    Returns ``Q(<field.name>__in=<values>)``. The values are read from
    ``field.flatchoices``, so grouped choices contribute their member values
    rather than their group labels.

    Raises:
        TypeError: if the field has no (or empty) choices.
    """
    if not field.choices:
        raise TypeError("Can't create choice constraint for field without choices")
    values = tuple(value for value, _ in field.flatchoices)
    return models.Q(**{f"{field.name}__in": values})
def choice_constraint(field_name: str) -> LazyCheckConstraint:
    """
    Create a database-level choice constraint from the choices of a model field.

    The returned constraint builds its condition with
    ``build_choices_constraint`` whenever its SQL is generated. It therefore
    reflects the field's choices at that time.

    Usage::

        class B(models.TextChoices):
            p = "p"
            q = "q"

        class A(models.Model):
            f = models.CharField(max_length=1, choices=B.choices)

            class Meta:
                constraints = [choice_constraint("f")]
    """
    return LazyCheckConstraint(field_name=field_name, check_fn=build_choices_constraint)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment