Last active
May 30, 2020 21:18
-
-
Save jleclanche/07961a08b7741cb804b1f14177d19a1f to your computer and use it in GitHub Desktop.
Example of generating PostgreSQL constraint functions through Django's migration engine (https://code.djangoproject.com/ticket/31622)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generated migration file | |
import django.db.models.constraints | |
import django.db.models.deletion | |
import django.db.models.expressions | |
import financica.contrib.constraints | |
import financica.contrib.pg_functions | |
from django.db import migrations | |
class Migration(migrations.Migration):
    """
    Initial migration: creates the model (arguments elided in this example),
    installs two PL/pgSQL trigger functions, then attaches one deferred
    constraint trigger per function.

    Operations run in list order, so each function exists before the
    trigger that calls it is created.
    """

    initial = True
    dependencies = []
    operations = [
        # Model definition elided in this example.
        migrations.CreateModel(...),
        # Trigger function `check_leg`: raises (ERRCODE 23514) when, for the
        # affected transaction_id, any currency's amounts do not sum to zero.
        # `%(table)s` is replaced with the model's table when the SQL is
        # rendered; `%%` stands for a literal `%` (PL/pgSQL RAISE placeholder).
        migrations.AddConstraint(
            model_name="leg",
            constraint=financica.contrib.pg_functions.PostgresTriggerFunctionDefinition(
                body="DECLARE\n\ttx_id UUID;\n\tnon_zero RECORD;\nBEGIN\n\tIF (TG_OP = 'DELETE') THEN\n\t\ttx_id := OLD.transaction_id;\n\tELSE\n\t\ttx_id := NEW.transaction_id;\n\tEND IF;\n\tSELECT ABS(SUM(amount)) AS total, amount_currency AS currency\n\t\tINTO non_zero\n\t\tFROM %(table)s\n\t\tWHERE transaction_id = tx_id\n\t\tGROUP BY amount_currency\n\t\tHAVING ABS(SUM(amount)) > 0\n\t\tLIMIT 1;\n\tIF FOUND THEN\n\t\tRAISE EXCEPTION\n\t\t\t'Sum of transaction amounts in each currency must be 0. Currency %% has non-zero total %%',\n\t\t\tnon_zero.currency, non_zero.total\n\t\t\tUSING ERRCODE = 23514;\n\tEND IF;\n\tRETURN NEW;\nEND;",
                language="plpgsql",
                name="check_leg",
                returns="trigger",
            ),
        ),
        # Trigger function `check_leg_and_account_currency_match`: on
        # INSERT/UPDATE, raises unless a row of `%(table)s` with
        # uuid = NEW.account_id lists NEW.amount_currency in `currencies`;
        # DELETEs return OLD without checking.
        migrations.AddConstraint(
            model_name="leg",
            constraint=financica.contrib.pg_functions.PostgresTriggerFunctionDefinition(
                body="DECLARE\nBEGIN\n\tIF (TG_OP = 'DELETE') THEN\n\t\tRETURN OLD;\n\tEND IF;\n\tPERFORM * FROM %(table)s WHERE uuid = NEW.account_id AND NEW.amount_currency = ANY(currencies);\n\tIF NOT FOUND THEN\n\t\tRAISE EXCEPTION\n\t\t\t'Destination account does not support currency %%', NEW.amount_currency\n\t\t\tUSING ERRCODE = 23514;\n\tEND IF;\n\tRETURN NEW;\nEND;",
                language="plpgsql",
                name="check_leg_and_account_currency_match",
                returns="trigger",
            ),
        ),
        # Constraint trigger calling check_leg() after INSERT/UPDATE/DELETE.
        # INITIALLY DEFERRED: the check runs at commit, so a multi-leg
        # transaction may be temporarily unbalanced while it is being written.
        migrations.AddConstraint(
            model_name="leg",
            constraint=financica.contrib.constraints.ConstraintTrigger(
                deferrable=django.db.models.constraints.Deferrable["DEFERRED"],
                events=("INSERT", "UPDATE", "DELETE"),
                function=django.db.models.expressions.Func(function="check_leg"),
                name="check_leg_trigger",
            ),
        ),
        # Deferred constraint trigger calling
        # check_leg_and_account_currency_match() on INSERT/UPDATE/DELETE.
        migrations.AddConstraint(
            model_name="leg",
            constraint=financica.contrib.constraints.ConstraintTrigger(
                deferrable=django.db.models.constraints.Deferrable["DEFERRED"],
                events=("INSERT", "UPDATE", "DELETE"),
                function=django.db.models.expressions.Func(
                    function="check_leg_and_account_currency_match"
                ),
                name="check_leg_and_account_currency_match_trigger",
            ),
        ),
    ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from enum import Enum | |
from typing import Any, Dict, List, Literal, Optional, Tuple, Union | |
from django.db.backends.ddl_references import Statement, Table | |
from django.db.models import Func | |
from django.db.models.constraints import BaseConstraint, Deferrable | |
from django.db.models.expressions import BaseExpression | |
from django.db.models.sql import Query | |
class TriggerEvent(Enum):
    """Row operations that may fire a ConstraintTrigger.

    Each member's value is the SQL keyword emitted in the trigger's event list.
    """

    INSERT = "INSERT"
    UPDATE = "UPDATE"
    DELETE = "DELETE"


# Accepted spelling of a single trigger event: either the upper-case SQL
# keyword as a plain string, or the corresponding TriggerEvent member.
_TriggerEventLike = Union[Literal["INSERT", "UPDATE", "DELETE"], TriggerEvent]
class ConstraintTrigger(BaseConstraint):
    """
    A PostgreSQL ``CREATE CONSTRAINT TRIGGER`` expressed as a model constraint.

    Listing an instance in a model's ``Meta.constraints`` lets migrations
    create it (``create_sql``) and drop it (``remove_sql``). The trigger runs
    ``function`` AFTER each affected row for the given ``events``, optionally
    only when ``condition`` holds, and may be deferred via ``deferrable``.
    """

    template = (
        "CREATE CONSTRAINT TRIGGER %(name)s\n"
        "AFTER %(events)s ON %(table)s%(deferrable)s\n"
        "FOR EACH ROW %(condition)s\n"
        "EXECUTE PROCEDURE %(procedure)s"
    )
    delete_template = "DROP TRIGGER %(name)s ON %(table)s"

    def __init__(
        self,
        *,
        name: str,
        events: Union[List[_TriggerEventLike], Tuple[_TriggerEventLike, ...]],
        function: Func,
        condition: Optional[BaseExpression] = None,
        deferrable: Optional[Deferrable] = None,
    ):
        """
        :param name: trigger name; quoted with ``quote_name`` when rendered.
        :param events: one or more of INSERT / UPDATE / DELETE, given as
            TriggerEvent members or case-insensitive strings. Duplicates are
            dropped, keeping first-seen order.
        :param function: expression rendering the procedure call, e.g.
            ``Func(function="check_leg")``.
        :param condition: optional expression rendered into ``WHEN ( ... )``.
        :param deferrable: optional Deferrable mode for the trigger.
        :raises ValueError: if no events are given, an event is not a valid
            trigger event, or ``deferrable`` is not a Deferrable.
        """
        # Normalize to enum values and de-duplicate: PostgreSQL rejects a
        # repeated event ("INSERT OR INSERT") in the trigger definition.
        normalized = tuple(
            dict.fromkeys(self._normalize_event(e) for e in (events or ()))
        )
        if not normalized:
            raise ValueError(
                "ConstraintTrigger events must be a list of at least one TriggerEvent"
            )
        # A non-Deferrable value would make the schema editor's deferrable
        # helper return None, which would then be rendered into the SQL.
        if deferrable is not None and not isinstance(deferrable, Deferrable):
            raise ValueError(
                "ConstraintTrigger.deferrable must be a Deferrable instance."
            )
        self.events = normalized
        self.function = function
        self.condition = condition
        self.deferrable = deferrable
        super().__init__(name)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self.name == other.name
                and set(self.events) == set(other.events)
                and self.function == other.function
                and self.condition == other.condition
                and self.deferrable == other.deferrable
            )
        return super().__eq__(other)

    def __repr__(self) -> str:
        return "<%s: name=%r events=%r function=%r condition=%r deferrable=%r>" % (
            self.__class__.__name__,
            self.name,
            self.events,
            self.function,
            self.condition,
            self.deferrable,
        )

    @staticmethod
    def _normalize_event(event: _TriggerEventLike) -> str:
        """Return the upper-case SQL keyword for *event*, or raise ValueError."""
        if isinstance(event, TriggerEvent):
            return event.value
        try:
            return TriggerEvent(str(event).upper()).value
        except ValueError:
            valid = ", ".join(member.value for member in TriggerEvent)
            raise ValueError(
                "Invalid ConstraintTrigger event %r; expected one of: %s"
                % (event, valid)
            ) from None

    @staticmethod
    def _compile(expression, compiler, schema_editor) -> str:
        """Render *expression* to SQL with its params inlined as quoted literals."""
        sql, params = expression.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_condition_sql(self, compiler, schema_editor, query) -> str:
        # `query` is unused; it is kept so the private signature is unchanged.
        if self.condition is None:
            return ""
        # PostgreSQL's grammar is `WHEN ( condition )`; the parentheses are
        # mandatory, so a bare `WHEN expr` would be a syntax error.
        return "WHEN (%s)" % self._compile(self.condition, compiler, schema_editor)

    def _get_procedure_sql(self, compiler, schema_editor) -> str:
        return self._compile(self.function, compiler, schema_editor)

    def create_sql(self, model, schema_editor) -> Statement:
        """Build the CREATE CONSTRAINT TRIGGER statement on *model*'s table."""
        table = Table(model._meta.db_table, schema_editor.quote_name)
        query = Query(model, alias_cols=False)
        compiler = query.get_compiler(connection=schema_editor.connection)
        return Statement(
            self.template,
            name=schema_editor.quote_name(self.name),
            events=" OR ".join(self.events),
            table=table,
            condition=self._get_condition_sql(compiler, schema_editor, query),
            deferrable=schema_editor._deferrable_constraint_sql(self.deferrable),
            procedure=self._get_procedure_sql(compiler, schema_editor),
        )

    def remove_sql(self, model, schema_editor) -> Statement:
        """Build the DROP TRIGGER statement for *model*'s table."""
        return Statement(
            self.delete_template,
            table=Table(model._meta.db_table, schema_editor.quote_name),
            name=schema_editor.quote_name(self.name),
        )

    def deconstruct(self) -> Tuple[str, Tuple[Any, ...], Dict[str, Any]]:
        """Return (import path, args, kwargs) used to serialize into migrations."""
        path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
        kwargs: Dict[str, Any] = {
            "name": self.name,
            "events": self.events,
            "function": self.function,
        }
        if self.condition is not None:
            kwargs["condition"] = self.condition
        if self.deferrable is not None:
            kwargs["deferrable"] = self.deferrable
        return path, (), kwargs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# File containing the function definitions for the constraints | |
from .pg_functions import PostgresTriggerFunctionDefinition | |
# Trigger function behind `check_leg_trigger`. For the transaction_id of the
# affected row (OLD on DELETE, NEW otherwise), it looks for any currency whose
# leg amounts do not sum to zero and raises ERRCODE 23514 if one is found.
# The reported "total" is ABS(SUM(amount)), so its sign is not shown.
#
# NOTE(review): on UPDATE only NEW.transaction_id is re-checked. If a leg's
# transaction_id can change, the transaction it moved out of is not
# re-validated. Confirm whether that can happen.
#
# The body is spelled with explicit "\t" / "\n" so it matches, byte for byte,
# the body serialized in the existing migration. `%(table)s` is substituted
# when the SQL is rendered; `%%` is a literal `%` for RAISE placeholders.
check_leg = PostgresTriggerFunctionDefinition(
    name="check_leg",
    body=(
        "DECLARE\n"
        "\ttx_id UUID;\n"
        "\tnon_zero RECORD;\n"
        "BEGIN\n"
        "\tIF (TG_OP = 'DELETE') THEN\n"
        "\t\ttx_id := OLD.transaction_id;\n"
        "\tELSE\n"
        "\t\ttx_id := NEW.transaction_id;\n"
        "\tEND IF;\n"
        "\tSELECT ABS(SUM(amount)) AS total, amount_currency AS currency\n"
        "\t\tINTO non_zero\n"
        "\t\tFROM %(table)s\n"
        "\t\tWHERE transaction_id = tx_id\n"
        "\t\tGROUP BY amount_currency\n"
        "\t\tHAVING ABS(SUM(amount)) > 0\n"
        "\t\tLIMIT 1;\n"
        "\tIF FOUND THEN\n"
        "\t\tRAISE EXCEPTION\n"
        "\t\t\t'Sum of transaction amounts in each currency must be 0. Currency %% has non-zero total %%',\n"
        "\t\t\tnon_zero.currency, non_zero.total\n"
        "\t\t\tUSING ERRCODE = 23514;\n"
        "\tEND IF;\n"
        "\tRETURN NEW;\n"
        "END;"
    ),
)
# Trigger function behind `check_leg_and_account_currency_match_trigger`.
# DELETEs return OLD without checking. Otherwise it raises ERRCODE 23514
# unless a row of `%(table)s` with uuid = NEW.account_id has
# NEW.amount_currency among its `currencies` values.
#
# NOTE(review): `%(table)s` is filled with the db_table of the model this
# definition is listed on (Leg.Meta.constraints), yet `uuid` / `currencies`
# look like account columns. Confirm the query targets the intended table.
#
# Body spelled with explicit escapes to match the serialized migration exactly.
check_leg_and_account_currency_match = PostgresTriggerFunctionDefinition(
    name="check_leg_and_account_currency_match",
    body=(
        "DECLARE\n"
        "BEGIN\n"
        "\tIF (TG_OP = 'DELETE') THEN\n"
        "\t\tRETURN OLD;\n"
        "\tEND IF;\n"
        "\tPERFORM * FROM %(table)s WHERE uuid = NEW.account_id AND NEW.amount_currency = ANY(currencies);\n"
        "\tIF NOT FOUND THEN\n"
        "\t\tRAISE EXCEPTION\n"
        "\t\t\t'Destination account does not support currency %%', NEW.amount_currency\n"
        "\t\t\tUSING ERRCODE = 23514;\n"
        "\tEND IF;\n"
        "\tRETURN NEW;\n"
        "END;"
    ),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.db import models
from django.db.models.constraints import Deferrable

from financica.contrib.constraints import ConstraintTrigger, TriggerEvent

from .db_constraints import check_leg, check_leg_and_account_currency_match
class Leg(models.Model):
    ...  # Fields elided in this example.

    class Meta:
        constraints = [
            # HACK: Include function definitions as constraints so they are
            # picked up by the migration engine.
            # Keep the functions ahead of the triggers: the generated
            # AddConstraint operations follow this order, and a trigger can
            # only be created once the function it calls exists.
            check_leg,
            check_leg_and_account_currency_match,
            # Deferred so the balance check runs at commit, after all legs of
            # a transaction have been written.
            ConstraintTrigger(
                name="check_leg_trigger",
                events=[TriggerEvent.INSERT, TriggerEvent.UPDATE, TriggerEvent.DELETE],
                deferrable=Deferrable.DEFERRED,
                function=check_leg.as_func(),
            ),
            ConstraintTrigger(
                name="check_leg_and_account_currency_match_trigger",
                events=[TriggerEvent.INSERT, TriggerEvent.UPDATE, TriggerEvent.DELETE],
                deferrable=Deferrable.DEFERRED,
                function=check_leg_and_account_currency_match.as_func(),
            ),
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.db.backends.ddl_references import Statement, Table | |
from django.db.models import Func | |
def escape_literal_percent(s: str) -> str:
    """
    Return *s* with every ``%`` doubled.

    One later round of %-style interpolation turns the result back into *s*,
    so the text can be embedded in a string that will be formatted again.
    """
    pieces = s.split("%")
    return "%%".join(pieces)
class PostgresFunctionDefinition:
    """
    A PostgreSQL function that can be listed in a model's ``Meta.constraints``.

    It provides ``name``, ``create_sql``, ``remove_sql``, ``deconstruct``,
    ``clone`` and equality so that migrations emit ``CREATE FUNCTION`` /
    ``DROP FUNCTION`` through ``AddConstraint`` / ``RemoveConstraint``.

    ``body`` is placed between ``$$`` quotes, so it must not itself contain
    ``$$``. It may use ``%(table)s``, which is replaced with the quoted
    db_table of the model the definition is attached to. Any literal ``%``
    (e.g. a PL/pgSQL ``RAISE`` placeholder) must be written as ``%%``.

    NOTE(review): there is no ``constraint_sql()``. Django's schema editor
    calls it when it builds a table directly from ``Meta.constraints``
    (rather than via AddConstraint). Confirm that path is never used for
    models carrying these definitions.
    """

    template = (
        "CREATE FUNCTION %(name)s() RETURNS %(returns)s AS\n"
        "$$\n"
        "%(body)s\n"
        "$$ LANGUAGE %(language)s;"
    )
    remove_template = "DROP FUNCTION %(name)s()"

    def __init__(self, name: str, body: str, returns: str, language: str = "plpgsql"):
        """
        :param name: function name; interpolated unquoted into the SQL and
            into the call produced by ``as_func``.
        :param body: function source; stored with surrounding whitespace stripped.
        :param returns: SQL return type, e.g. ``"trigger"``.
        :param language: procedural language name.
        """
        self.name = name
        self.body = body.strip()
        self.returns = returns
        self.language = language

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self.name == other.name
                and self.body == other.body
                and self.returns == other.returns
                and self.language == other.language
            )
        return super().__eq__(other)

    def __repr__(self) -> str:
        return "<%s: name=%r returns=%r language=%r>" % (
            self.__class__.__name__,
            self.name,
            self.returns,
            self.language,
        )

    def create_sql(self, model, schema_editor) -> Statement:
        """Build the CREATE FUNCTION statement with ``%(table)s`` bound to *model*'s table."""
        table = Table(model._meta.db_table, schema_editor.quote_name)
        # First pass: substitute %(table)s. This formatting also turns the
        # author's `%%` escapes back into single `%`.
        rendered_body = str(Statement(self.body, table=table))
        return Statement(
            self.template,
            name=self.name,
            returns=self.returns,
            # Re-escape `%` before embedding ("make sure to escape
            # placeholders"). This assumes the final SQL goes through one more
            # %-interpolation pass when executed. TODO confirm against the
            # Django / DB driver versions in use.
            body=escape_literal_percent(rendered_body),
            language=self.language,
        )

    def remove_sql(self, model, schema_editor) -> Statement:
        """Build the DROP FUNCTION statement (*model* is not used)."""
        return Statement(self.remove_template, name=self.name)

    def clone(self):
        """Return an equal, independent copy built from ``deconstruct()``."""
        _, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)

    def deconstruct(self):
        """Return (import path, args, kwargs) used to serialize into migrations."""
        path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
        return (
            path,
            (),
            {
                "name": self.name,
                "body": self.body,
                "returns": self.returns,
                "language": self.language,
            },
        )

    def as_func(self) -> Func:
        """Return a ``Func`` that renders a no-argument call to this function."""
        return Func(function=self.name)
class PostgresTriggerFunctionDefinition(PostgresFunctionDefinition):
    """A PostgresFunctionDefinition whose return type is always ``trigger``."""

    def __init__(
        self,
        name: str,
        body: str,
        returns: str = "trigger",
        language: str = "plpgsql",
    ):
        """
        :param name: function name.
        :param body: function source (see PostgresFunctionDefinition).
        :param returns: accepted only so that ``deconstruct()`` / ``clone()``
            and serialized migrations (which pass ``returns="trigger"``)
            round-trip. Whatever is passed, the stored value is ``"trigger"``,
            as before. Unlike the previous ``*args`` version, passing it
            positionally no longer raises TypeError.
        :param language: procedural language name.
        """
        super().__init__(name, body, returns="trigger", language=language)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment