Skip to content

Instantly share code, notes, and snippets.

@antonagestam
Last active May 20, 2021 09:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save antonagestam/c6e0647f5cf396b18927e4b26da9149c to your computer and use it in GitHub Desktop.
Save antonagestam/c6e0647f5cf396b18927e4b26da9149c to your computer and use it in GitHub Desktop.
Introspective choice constraints for Django models.
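# Note: changing a field's choices does not yield a delete+create of the constraint in
# migrations, because LazyCheckConstraint instances compare equal whenever field_name and
# check_fn match, regardless of the choices themselves.
# -> This can probably be remedied by requiring the model "path" (<app>.<model_name>) to be
# passed as an argument to __init__, though that is not ideal.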
from typing import Callable, Sequence, Type

from django.db import models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import Statement
from django.db.models.constraints import BaseConstraint
from django.db.models.sql.query import Query
class LazyCheckConstraint(BaseConstraint):
    """
    A check constraint whose condition is derived from a model field at SQL time.

    Instead of storing a fixed ``Q`` object, this constraint stores a field name
    and a callable. Whenever SQL is generated, the callable is invoked with the
    resolved field (``model._meta.get_field(field_name)``) and must return the
    ``Q`` object to enforce, so the check can be built from the field's own
    attributes (e.g. its ``choices``).

    The constraint name is the template ``"%(class)s_<field_name>_choice_chk"``;
    :meth:`get_name` fills ``%(class)s`` with the model name.
    """

    def __init__(
        self, *, field_name: str, check_fn: Callable[[models.Field], models.Q]
    ) -> None:
        """
        :param field_name: Name of the model field the check is built from.
        :param check_fn: Callable that receives that field and returns the ``Q``
            object to enforce. It is passed through :meth:`deconstruct`, so
            migrations need to be able to import it (a module-level function).
        """
        super().__init__(name=f"%(class)s_{field_name}_choice_chk")
        self.field_name = field_name
        self.check_fn = check_fn

    def get_check(self, model: Type[models.Model]) -> models.Q:
        """Return the ``Q`` object produced by ``check_fn`` for *model*'s field."""
        return self.check_fn(model._meta.get_field(self.field_name))

    def get_name(self, model: Type[models.Model]) -> str:
        """Return the constraint name with ``%(class)s`` replaced by the model name."""
        # With a mapping on the right-hand side, ``%`` returns the string unchanged
        # if it contains no placeholder, so this is safe on an already-filled name.
        return self.name % {"class": model._meta.model_name}

    def _get_check_sql(
        self, model: Type[models.Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Compile the check for *model* into a SQL fragment with values inlined."""
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.get_check(model))
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        # The fragment is embedded in DDL, so each parameter is quoted into the SQL
        # text by the schema editor rather than passed as a bound parameter.
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def constraint_sql(
        self, model: Type[models.Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Return the inline ``CHECK`` clause used when creating the table."""
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._check_sql(self.get_name(model), check)

    def create_sql(
        self, model: Type[models.Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement:
        """Return the statement that adds this check to an existing table."""
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._create_check_sql(model, self.get_name(model), check)

    def remove_sql(
        self, model: Type[models.Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement:
        """Return the statement that drops this check from the table."""
        return schema_editor._delete_check_sql(model, self.get_name(model))

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__name__}: "
            f"field_name={self.field_name!r} "
            f"check_fn={self.check_fn!r} "
            f"name={self.name!r}>"
        )

    def __eq__(self, other):
        """
        Compare by ``field_name`` and ``check_fn`` only.

        The computed check (e.g. the field's current choices) is not compared, so
        two instances stay equal when only the field's choices change.
        """
        # NOTE: defining __eq__ without __hash__ makes instances unhashable.
        if isinstance(other, LazyCheckConstraint):
            return (
                self.field_name == other.field_name and self.check_fn == other.check_fn
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple:
        """
        Return ``(path, args, kwargs)`` for migration serialization.

        The base class's kwargs are discarded in favour of this class's own
        constructor arguments; ``__init__`` does not accept ``name`` because the
        name is rebuilt from ``field_name``.
        """
        path, args, _ = super().deconstruct()
        return path, args, {"field_name": self.field_name, "check_fn": self.check_fn}
# This needs to be exposed and accessible for migrations to import: it is passed as
# ``check_fn`` to LazyCheckConstraint, whose deconstruct() hands it to migrations.
def build_choices_constraint(field: models.Field) -> models.Q:
    """
    Build a constraint check from the choices of a field.

    The allowed values come from ``field.flatchoices``, so grouped choices
    (``("Group", [(value, label), ...])``) contribute their inner values rather
    than their group labels.

    :param field: The model field whose choices define the allowed values.
    :returns: ``Q(<field.name>__in=(<value>, ...))``.
    :raises TypeError: If ``field.choices`` is not a sequence (e.g. ``None``) or
        the field has no choice values.
    """
    if not isinstance(field.choices, Sequence):
        raise TypeError("Can't create choice constraint for field without choices")
    # Iterating ``field.choices`` directly would treat a group's label as a value.
    values = tuple(value for value, _ in field.flatchoices)
    if not values:
        # Django cannot compile an empty ``__in`` lookup to SQL; fail early instead.
        raise TypeError("Can't create choice constraint for field without choices")
    return models.Q(**{f"{field.name}__in": values})
def choice_constraint(field_name: str) -> LazyCheckConstraint:
    """
    Create a database-level choice constraint from the choices of a model field.

    The returned constraint restricts the column to the values in the field's
    choices. The values are read from the field when the constraint's SQL is
    generated, not when this function is called.

    :param field_name: Name of the field (on the model whose ``Meta.constraints``
        includes the result) whose choices define the allowed values.
    :returns: A :class:`LazyCheckConstraint` built with
        :func:`build_choices_constraint`.

    Usage::

        class B(models.TextChoices):
            p = "p"
            q = "q"

        class A(models.Model):
            f = models.CharField(max_length=1, choices=B.choices)

            class Meta:
                constraints = [choice_constraint("f")]
    """
    return LazyCheckConstraint(field_name=field_name, check_fn=build_choices_constraint)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment