Skip to content

Instantly share code, notes, and snippets.

@jleclanche
Last active May 30, 2020 21:18
Show Gist options
  • Save jleclanche/07961a08b7741cb804b1f14177d19a1f to your computer and use it in GitHub Desktop.
Save jleclanche/07961a08b7741cb804b1f14177d19a1f to your computer and use it in GitHub Desktop.
Example of generating PostgreSQL trigger functions and constraint triggers through Django's migrations engine (https://code.djangoproject.com/ticket/31622)
# Generated migration file
import django.db.models.constraints
import django.db.models.deletion
import django.db.models.expressions
import financica.contrib.constraints
import financica.contrib.pg_functions
from django.db import migrations
class Migration(migrations.Migration):
    # Generated migration: creates the models, then registers two plpgsql
    # trigger functions and the two constraint triggers that call them, all
    # attached to the "leg" model through AddConstraint operations.
    initial = True
    dependencies = []
    operations = [
        migrations.CreateModel(...),
        # Function definition "check_leg" (body uses a %(table)s placeholder
        # and doubled %% for literal percent signs).
        migrations.AddConstraint(
            model_name="leg",
            constraint=financica.contrib.pg_functions.PostgresTriggerFunctionDefinition(
                body="DECLARE\n\ttx_id UUID;\n\tnon_zero RECORD;\nBEGIN\n\tIF (TG_OP = 'DELETE') THEN\n\t\ttx_id := OLD.transaction_id;\n\tELSE\n\t\ttx_id := NEW.transaction_id;\n\tEND IF;\n\tSELECT ABS(SUM(amount)) AS total, amount_currency AS currency\n\t\tINTO non_zero\n\t\tFROM %(table)s\n\t\tWHERE transaction_id = tx_id\n\t\tGROUP BY amount_currency\n\t\tHAVING ABS(SUM(amount)) > 0\n\t\tLIMIT 1;\n\tIF FOUND THEN\n\t\tRAISE EXCEPTION\n\t\t\t'Sum of transaction amounts in each currency must be 0. Currency %% has non-zero total %%',\n\t\t\tnon_zero.currency, non_zero.total\n\t\t\tUSING ERRCODE = 23514;\n\tEND IF;\n\tRETURN NEW;\nEND;",
                language="plpgsql",
                name="check_leg",
                returns="trigger",
            ),
        ),
        # Function definition "check_leg_and_account_currency_match".
        migrations.AddConstraint(
            model_name="leg",
            constraint=financica.contrib.pg_functions.PostgresTriggerFunctionDefinition(
                body="DECLARE\nBEGIN\n\tIF (TG_OP = 'DELETE') THEN\n\t\tRETURN OLD;\n\tEND IF;\n\tPERFORM * FROM %(table)s WHERE uuid = NEW.account_id AND NEW.amount_currency = ANY(currencies);\n\tIF NOT FOUND THEN\n\t\tRAISE EXCEPTION\n\t\t\t'Destination account does not support currency %%', NEW.amount_currency\n\t\t\tUSING ERRCODE = 23514;\n\tEND IF;\n\tRETURN NEW;\nEND;",
                language="plpgsql",
                name="check_leg_and_account_currency_match",
                returns="trigger",
            ),
        ),
        # Constraint trigger calling check_leg on INSERT/UPDATE/DELETE.
        # It is listed after the function definitions above, which it
        # references by name.
        migrations.AddConstraint(
            model_name="leg",
            constraint=financica.contrib.constraints.ConstraintTrigger(
                deferrable=django.db.models.constraints.Deferrable["DEFERRED"],
                events=("INSERT", "UPDATE", "DELETE"),
                function=django.db.models.expressions.Func(function="check_leg"),
                name="check_leg_trigger",
            ),
        ),
        # Constraint trigger calling check_leg_and_account_currency_match
        # on INSERT/UPDATE/DELETE.
        migrations.AddConstraint(
            model_name="leg",
            constraint=financica.contrib.constraints.ConstraintTrigger(
                deferrable=django.db.models.constraints.Deferrable["DEFERRED"],
                events=("INSERT", "UPDATE", "DELETE"),
                function=django.db.models.expressions.Func(
                    function="check_leg_and_account_currency_match"
                ),
                name="check_leg_and_account_currency_match_trigger",
            ),
        ),
    ]
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from django.db.backends.ddl_references import Statement, Table
from django.db.models import Func
from django.db.models.constraints import BaseConstraint, Deferrable
from django.db.models.expressions import BaseExpression
from django.db.models.sql import Query
class TriggerEvent(Enum):
    """Row events a trigger can be attached to.

    Each member's value is the keyword spelled exactly as it is written
    into the generated trigger statement.
    """

    INSERT = "INSERT"
    UPDATE = "UPDATE"
    DELETE = "DELETE"
# Accepted spellings of a trigger event in type annotations: either a
# TriggerEvent member or one of the literal strings matching its values.
_TriggerEventLike = Union[Literal["INSERT", "UPDATE", "DELETE"], TriggerEvent]
class ConstraintTrigger(BaseConstraint):
    """Model constraint rendered as a PostgreSQL ``CREATE CONSTRAINT TRIGGER``.

    The trigger fires ``AFTER`` the configured events, ``FOR EACH ROW``, on
    the table of the model the constraint is attached to, and executes
    ``function``.  ``create_sql`` / ``remove_sql`` / ``deconstruct`` provide
    the hooks the migration engine calls on constraints.
    """

    template = """
CREATE CONSTRAINT TRIGGER %(name)s
AFTER %(events)s ON %(table)s%(deferrable)s
FOR EACH ROW %(condition)s
EXECUTE PROCEDURE %(procedure)s
""".strip()
    delete_template = "DROP TRIGGER %(name)s ON %(table)s"

    def __init__(
        self,
        *,
        name: str,
        events: Union[List[_TriggerEventLike], Tuple[_TriggerEventLike, ...]],
        function: Func,
        condition: Optional[BaseExpression] = None,
        deferrable: Optional[Deferrable] = None,
    ):
        """Store the trigger configuration.

        Args:
            name: Trigger name (quoted when rendered).
            events: Non-empty sequence of ``TriggerEvent`` members or their
                string spellings; strings are upper-cased.
            function: Expression rendered after ``EXECUTE PROCEDURE``.
            condition: Optional expression rendered as the ``WHEN`` clause.
            deferrable: Optional ``Deferrable`` member controlling the
                deferrable clause appended after the table name.

        Raises:
            ValueError: If ``events`` is empty, or ``deferrable`` is neither
                ``None`` nor a ``Deferrable`` member.
        """
        if not events:
            raise ValueError(
                "ConstraintTrigger events must be a list of at least one TriggerEvent"
            )
        # Fail early: any other value would end up interpolated verbatim
        # into the DDL statement by create_sql().
        if deferrable is not None and not isinstance(deferrable, Deferrable):
            raise ValueError(
                "ConstraintTrigger.deferrable must be a Deferrable instance."
            )
        self.events = tuple(
            e.value if isinstance(e, TriggerEvent) else str(e).upper() for e in events
        )
        self.function = function
        self.condition = condition
        self.deferrable = deferrable
        super().__init__(name)

    def __eq__(self, other):
        """Compare by configuration; the order of ``events`` is ignored."""
        if isinstance(other, self.__class__):
            return (
                self.name == other.name
                and set(self.events) == set(other.events)
                and self.function == other.function
                and self.condition == other.condition
                and self.deferrable == other.deferrable
            )
        return super().__eq__(other)

    def _get_condition_sql(self, compiler, schema_editor, query) -> str:
        """Return the ``WHEN (...)`` clause, or ``""`` when no condition is set.

        ``query`` is accepted for signature compatibility and is not used.
        """
        if self.condition is None:
            return ""
        sql, params = self.condition.as_sql(compiler, schema_editor.connection)
        # DDL cannot take bound parameters, so inline them as quoted literals.
        condition_sql = sql % tuple(schema_editor.quote_value(p) for p in params)
        # CREATE TRIGGER requires the condition to be parenthesized:
        # "WHEN ( condition )".  Redundant parentheses are harmless.
        return "WHEN (%s)" % condition_sql

    def _get_procedure_sql(self, compiler, schema_editor) -> str:
        """Return the SQL of ``function`` with its parameters inlined."""
        sql, params = self.function.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def create_sql(self, model, schema_editor) -> Statement:
        """Build the ``CREATE CONSTRAINT TRIGGER`` statement for ``model``."""
        table = Table(model._meta.db_table, schema_editor.quote_name)
        query = Query(model, alias_cols=False)
        compiler = query.get_compiler(connection=schema_editor.connection)
        condition = self._get_condition_sql(compiler, schema_editor, query)
        return Statement(
            self.template,
            name=schema_editor.quote_name(self.name),
            events=" OR ".join(self.events),
            table=table,
            condition=condition,
            deferrable=schema_editor._deferrable_constraint_sql(self.deferrable),
            procedure=self._get_procedure_sql(compiler, schema_editor),
        )

    def remove_sql(self, model, schema_editor) -> Statement:
        """Build the ``DROP TRIGGER`` statement for ``model``."""
        return Statement(
            self.delete_template,
            table=Table(model._meta.db_table, schema_editor.quote_name),
            name=schema_editor.quote_name(self.name),
        )

    def deconstruct(self) -> Tuple[str, Tuple[Any, ...], Dict[str, Any]]:
        """Return ``(import path, args, kwargs)`` that recreate this object.

        ``condition`` and ``deferrable`` are only included when set.
        """
        path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
        kwargs = {
            "name": self.name,
            "events": self.events,
            "function": self.function,
        }
        # Use the same "is None" test as _get_condition_sql() so a condition
        # that is rendered into SQL is never dropped from the migration.
        if self.condition is not None:
            kwargs["condition"] = self.condition
        if self.deferrable is not None:
            kwargs["deferrable"] = self.deferrable
        return path, (), kwargs
# File containing the function definitions for the constraints
from .pg_functions import PostgresTriggerFunctionDefinition
# Trigger function "check_leg".
# The plpgsql body takes the transaction id from OLD on DELETE and from NEW
# otherwise, sums `amount` per `amount_currency` over the rows of %(table)s
# with that transaction id, and raises an exception with ERRCODE 23514
# (check_violation in PostgreSQL's error-code table) if any currency has a
# non-zero total.
# %(table)s is a placeholder filled in when the SQL is generated; literal
# percent signs in the RAISE message are written as %%.
check_leg = PostgresTriggerFunctionDefinition(
name="check_leg",
body="""
DECLARE
tx_id UUID;
non_zero RECORD;
BEGIN
IF (TG_OP = 'DELETE') THEN
tx_id := OLD.transaction_id;
ELSE
tx_id := NEW.transaction_id;
END IF;
SELECT ABS(SUM(amount)) AS total, amount_currency AS currency
INTO non_zero
FROM %(table)s
WHERE transaction_id = tx_id
GROUP BY amount_currency
HAVING ABS(SUM(amount)) > 0
LIMIT 1;
IF FOUND THEN
RAISE EXCEPTION
'Sum of transaction amounts in each currency must be 0. Currency %% has non-zero total %%',
non_zero.currency, non_zero.total
USING ERRCODE = 23514;
END IF;
RETURN NEW;
END;
""",
)
# Trigger function "check_leg_and_account_currency_match".
# The plpgsql body returns OLD immediately on DELETE; otherwise it looks for
# a row in %(table)s whose `uuid` equals NEW.account_id and whose
# `currencies` contains NEW.amount_currency, and raises an exception with
# ERRCODE 23514 when no such row is found.
# NOTE(review): %(table)s is replaced with the db_table of the model this
# definition is attached to as a constraint, but the query looks like it is
# meant to read the account table -- confirm which table is intended.
check_leg_and_account_currency_match = PostgresTriggerFunctionDefinition(
name="check_leg_and_account_currency_match",
body="""
DECLARE
BEGIN
IF (TG_OP = 'DELETE') THEN
RETURN OLD;
END IF;
PERFORM * FROM %(table)s WHERE uuid = NEW.account_id AND NEW.amount_currency = ANY(currencies);
IF NOT FOUND THEN
RAISE EXCEPTION
'Destination account does not support currency %%', NEW.amount_currency
USING ERRCODE = 23514;
END IF;
RETURN NEW;
END;
""",
)
from django.db import models
from .db_constraints import check_leg, check_leg_and_account_currency_match
class Leg(models.Model):
    # Example model: fields are elided; only the constraint wiring is shown.
    # NOTE(review): ConstraintTrigger, TriggerEvent and Deferrable are used
    # below but are not imported in this excerpt -- confirm the imports.
    ...

    class Meta:
        constraints = [
            # HACK: Include function definitions as constraints so they are
            # picked up by the migration engine.
            check_leg,
            check_leg_and_account_currency_match,
            # Each trigger references its function by name via as_func();
            # the function definitions are listed first.
            ConstraintTrigger(
                name="check_leg_trigger",
                events=[TriggerEvent.INSERT, TriggerEvent.UPDATE, TriggerEvent.DELETE],
                deferrable=Deferrable.DEFERRED,
                function=check_leg.as_func(),
            ),
            ConstraintTrigger(
                name="check_leg_and_account_currency_match_trigger",
                events=[TriggerEvent.INSERT, TriggerEvent.UPDATE, TriggerEvent.DELETE],
                deferrable=Deferrable.DEFERRED,
                function=check_leg_and_account_currency_match.as_func(),
            ),
        ]
from django.db.backends.ddl_references import Statement, Table
from django.db.models import Func
def escape_literal_percent(s: str) -> str:
    """Return *s* with every ``%`` doubled.

    The result survives one round of ``%``-style string formatting with
    its percent signs intact.
    """
    pieces = s.split("%")
    return "%%".join(pieces)
class PostgresFunctionDefinition:
    """Definition of a zero-argument SQL function the migration engine can manage.

    Exposes the hooks used on constraint-like objects (``create_sql``,
    ``remove_sql``, ``clone``, ``deconstruct``) so that a function can be
    listed alongside model constraints.
    """

    template = """
CREATE FUNCTION %(name)s() RETURNS %(returns)s AS
$$
%(body)s
$$ LANGUAGE %(language)s;
""".strip()
    remove_template = "DROP FUNCTION %(name)s()"

    def __init__(self, name: str, body: str, returns: str, language: str = "plpgsql"):
        """Store the function's name, body, return type and language.

        Leading and trailing whitespace is stripped from ``body``.
        """
        self.name = name
        self.returns = returns
        self.language = language
        self.body = body.strip()

    def __eq__(self, other):
        """Two definitions are equal when all four stored fields match."""
        if not isinstance(other, self.__class__):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.body == other.body
            and self.returns == other.returns
            and self.language == other.language
        )

    def create_sql(self, model, schema_editor) -> Statement:
        """Build the ``CREATE FUNCTION`` statement.

        The body's ``%(table)s`` placeholder is filled with ``model``'s
        table first; the rendered text is then re-escaped because the outer
        statement is itself ``%``-formatted.
        """
        model_table = Table(model._meta.db_table, schema_editor.quote_name)
        rendered_body = str(Statement(self.body, table=model_table))
        return Statement(
            self.template,
            name=self.name,
            returns=self.returns,
            body=escape_literal_percent(rendered_body),
            language=self.language,
        )

    def remove_sql(self, model, schema_editor) -> Statement:
        """Build the ``DROP FUNCTION`` statement."""
        return Statement(self.remove_template, name=self.name)

    def clone(self):
        """Return a new instance rebuilt from ``deconstruct()``."""
        _path, positional, keywords = self.deconstruct()
        return self.__class__(*positional, **keywords)

    def deconstruct(self):
        """Return ``(import path, args, kwargs)`` that recreate this object."""
        klass = self.__class__
        keywords = {
            "name": self.name,
            "body": self.body,
            "returns": self.returns,
            "language": self.language,
        }
        return f"{klass.__module__}.{klass.__name__}", (), keywords

    def as_func(self) -> Func:
        """Return a ``Func`` expression that calls this function by name."""
        return Func(function=self.name)
class PostgresTriggerFunctionDefinition(PostgresFunctionDefinition):
    """Function definition whose return type is always ``trigger``."""

    def __init__(self, *args, **kwargs):
        """Forward to the base initializer with ``returns`` forced to ``"trigger"``.

        Any ``returns`` keyword supplied by the caller is overridden.
        """
        forced = dict(kwargs, returns="trigger")
        super().__init__(*args, **forced)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment