-
-
Save petehellyer/e66674c0e5ca15c000219b5e8af9310d to your computer and use it in GitHub Desktop.
Django table function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
############################# | |
# Table Functions | |
############################# | |
''' | |
CREATE OR REPLACE FUNCTION get_user(INTEGER DEFAULT NULL, VARCHAR DEFAULT NULL) | |
RETURNS TABLE( | |
id INTEGER, | |
user_id INTEGER, | |
one_id INTEGER, | |
username VARCHAR, | |
parent_id INTEGER, | |
root_id INTEGER | |
) AS $$ | |
SELECT id, id, id, username, id, id | |
FROM auth_user | |
WHERE ( | |
$1 IS NULL OR id = $1 | |
) AND ( | |
$2 IS NULL OR username = $2 | |
) | |
$$ LANGUAGE sql; | |
''' | |
''' | |
CREATE OR REPLACE FUNCTION get_user_2(INTEGER, VARCHAR) | |
RETURNS TABLE( | |
id INTEGER, | |
user_id INTEGER, | |
username VARCHAR | |
) AS $$ | |
SELECT id, id, username | |
FROM auth_user | |
WHERE id = $1 AND username = $2 | |
$$ LANGUAGE sql; | |
''' | |
############################# | |
# Models | |
############################# | |
class TableFunctionUser(models.Model): | |
function_args = OrderedDict(( | |
('id', TableFunctionArg(required=False)), | |
('username', TableFunctionArg(required=False)), | |
)) | |
objects = TableFunctionManager() | |
user = models.ForeignKey('auth.User', on_delete=models.DO_NOTHING) | |
one = models.OneToOneField('auth.User', on_delete=models.DO_NOTHING) | |
username = models.CharField(max_length=150) | |
parent = models.ForeignKey('self', on_delete=models.DO_NOTHING) | |
root = models.ForeignKey('self', on_delete=models.DO_NOTHING) | |
class Meta: | |
db_table = 'get_user' | |
managed = False | |
def __str__(self): | |
return self.username | |
class TableFunctionUser2(models.Model): | |
function_args = OrderedDict(( | |
('id', TableFunctionArg()), | |
('username', TableFunctionArg(default='admin')), | |
)) | |
objects = TableFunctionManager() | |
user = models.ForeignKey('auth.User', on_delete=models.DO_NOTHING) | |
username = models.CharField(max_length=150) | |
class Meta: | |
db_table = 'get_user_2' | |
managed = False | |
def __str__(self): | |
return self.username | |
############################# | |
# API | |
############################# | |
def test_optional_args(self): | |
user_count = User.objects.count() | |
function_user_count = TableFunctionUser.objects.count() | |
assert user_count == function_user_count | |
user = User.objects.get(id=1) | |
fuser = TableFunctionUser.objects.all().table_function(id=1, username='admin').get() | |
assert user == fuser.user | |
def test_select_related(self): | |
user = User.objects.get(id=1) | |
fuser = TableFunctionUser.objects.all()\ | |
.table_function(id=1, root__id=1, parent__id=1, parent__parent__id=1)\ | |
.select_related('root', 'parent__parent__user')\ | |
.get() | |
assert user == fuser.parent.parent.user | |
fuser = TableFunctionUser.objects.all()\ | |
.table_function(id=1, parent__id=models.F('id'))\ | |
.select_related('parent', 'parent__user')\ | |
.get() | |
assert user == fuser.parent.user | |
fuser = TableFunctionUser.objects.all().table_function(id=1, parent__id=1, root__id=1)\ | |
.select_related('user', 'one', 'parent', 'root')\ | |
.get() | |
assert fuser == fuser.root | |
def test_annotate(self): | |
user = User.objects.get(id=1) | |
fuser = TableFunctionUser.objects\ | |
.all()\ | |
.table_function(id=1)\ | |
.annotate(_username=models.F('user__username'))\ | |
.get() | |
assert user.username == fuser._username | |
def test_required_arg(self): | |
with pytest.raises(ProgrammingError): | |
len(TableFunctionUser2.objects.count()) | |
def test_default_arg(self): | |
user = User.objects.get(id=1) | |
fuser = TableFunctionUser2.objects.all().table_function(id=1).get() | |
assert user == fuser.user | |
def test_override_default_arg(self): | |
user = User.objects.get(id=2) | |
fuser = TableFunctionUser2.objects.all().table_function(id=2, username='AnonymousUser').get() | |
assert user == fuser.user | |
############################# | |
# Django extending | |
############################# | |
import re | |
from collections import OrderedDict | |
from typing import Any, Dict, Type, Optional | |
from typing import List | |
from django.db.models import ForeignObject | |
from django.db.models import QuerySet, Manager, NOT_PROVIDED, F, Model | |
from django.db.models.constants import LOOKUP_SEP | |
from django.db.models.sql import Query | |
from django.db.models.sql.datastructures import BaseTable, Join | |
from django.db.models.sql.where import WhereNode | |
class TableFunctionArg: | |
def __init__(self, required: bool = True, default=NOT_PROVIDED): | |
self.required = required # type: bool | |
self.default = default | |
class TableFunction(BaseTable): | |
def __init__(self, table_name: str, alias: Optional[str], table_function_params: List[Any]): | |
super().__init__(table_name, alias) | |
self.table_function_params = table_function_params # type: List[Any] | |
def as_sql(self, compiler, connection): | |
alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias) | |
base_sql = compiler.quote_name_unless_alias(self.table_name) | |
return '{}({}){}'.format( | |
base_sql, | |
', '.join(['%s' for _ in range(len(self.table_function_params))]), | |
alias_str | |
), self.table_function_params | |
class TableFunctionJoin(Join): | |
def __init__(self, table_name, parent_alias, table_alias, join_type, | |
join_field, nullable, filtered_relation=None, table_function_params: List[Any] = None): | |
super().__init__(table_name, parent_alias, table_alias, join_type, | |
join_field, nullable, filtered_relation) | |
self.table_function_params = table_function_params # type: List[Any] | |
def as_sql(self, compiler, connection): | |
sql, params = super().as_sql(compiler, connection) | |
if self.table_function_params is None: | |
return sql, params # normal table join | |
# extract `on_clause_sql` from ancestor's complex compiled query logic | |
# to be able pass function instead of normal table into sql easily | |
result = re.match('.+?join.+?on(?P<on_clause_sql>.+)', sql, re.IGNORECASE | re.DOTALL) | |
on_clause_sql = result.group('on_clause_sql') | |
table_function_placeholders = [] | |
table_function_params = [] | |
for param in self.table_function_params: | |
if hasattr(param, 'as_sql'): | |
param_sql, param_params = param.as_sql(compiler, connection) | |
else: | |
param_sql = '%s' | |
param_params = [param] | |
table_function_placeholders.append(param_sql) | |
table_function_params += param_params | |
sql = '{} {}({}) {} ON ({})'.format( | |
self.join_type, | |
compiler.quote_name_unless_alias(self.table_name), | |
', '.join(table_function_placeholders), | |
self.table_alias, | |
on_clause_sql | |
) | |
return sql, table_function_params + params | |
class TableFunctionParams: | |
def __init__(self, level: int, join_field: ForeignObject, params: 'OrderedDict[str, Any]'): | |
self.level = level # type: int | |
self.join_field = join_field # type: ForeignObject | |
self.params = params # type: OrderedDict[str, Any] | |
class TableFunctionQuery(Query): | |
def __init__(self, model, where=WhereNode): | |
super().__init__(model, where) | |
self.table_function_params = [] # type: List[TableFunctionParams] | |
def get_initial_alias(self): | |
if self.alias_map: | |
alias = self.base_table | |
self.ref_alias(alias) | |
else: | |
if hasattr(self.model, 'function_args'): | |
try: | |
params = list( | |
next(filter(lambda x: x.level == 0, self.table_function_params)).params.values() | |
) # type: List[Any] | |
except StopIteration: | |
# no parameters were passed from user | |
# so try to call the function without parameters | |
# in case that they are optional | |
params = [] | |
alias = self.join(TableFunction(self.get_meta().db_table, None, params)) | |
else: | |
alias = self.join(BaseTable(self.get_meta().db_table, None)) | |
return alias | |
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True, | |
reuse_with_filtered_relation=False): | |
join_info = super().setup_joins(names, opts, alias, can_reuse, allow_many, reuse_with_filtered_relation) | |
level = 0 | |
for alias in join_info.joins: | |
join = self.alias_map[alias] | |
if isinstance(join, TableFunction): | |
continue # skip the `FROM func(...)`, it is handled in `get_initial_alias` | |
if not hasattr(join.join_field.related_model, 'function_args'): | |
continue # skip normal tables | |
level += 1 | |
try: | |
params = list(next(filter( | |
lambda x: x.level == level and x.join_field == join.join_field, | |
self.table_function_params | |
)).params.values()) # type: List[Any] | |
except StopIteration: | |
# no parameters were passed from user | |
# so try to call the function without parameters | |
# in case that they are optional | |
params = [] | |
resolved_params = [] | |
for param in params: | |
if isinstance(param, F): | |
resolved_param = param.resolve_expression(self) | |
else: | |
resolved_param = param | |
resolved_params.append(resolved_param) | |
self.alias_map[alias] = TableFunctionJoin( | |
join.table_name, join.parent_alias, join.table_alias, join.join_type, join.join_field, | |
join.nullable, join.filtered_relation, resolved_params | |
) | |
return join_info | |
def table_function(self, **table_function_params: Dict[str, Any]): | |
""" | |
Take user's passed params and store them in `self.table_function_params` | |
to be prepared for joining. | |
""" | |
_table_function_params = [] | |
for table_lookup, param_dict in self._table_function_params_to_groups(table_function_params).items(): | |
if not table_lookup: | |
level = 0 | |
join_field = None | |
model = self.model | |
else: | |
level = len(table_lookup.split(LOOKUP_SEP)) | |
lookup_parts, field_parts, _ = self.solve_lookup_type(table_lookup) | |
path, final_field, targets, rest = self.names_to_path( | |
field_parts, self.get_meta(), allow_many=False, fail_on_missing=True | |
) | |
join_field = path[-1].join_field | |
model = final_field.related_model | |
_table_function_params.append( | |
TableFunctionParams( | |
level=level, join_field=join_field, | |
params=self._reorder_table_function_params(model, param_dict) | |
) | |
) | |
# TODO: merge with existing? | |
self.table_function_params = _table_function_params | |
def _table_function_params_to_groups(self, table_function_params: Dict[str, Any]) -> Dict[str, Any]: | |
""" | |
Transfer user specified lookups into groups | |
to have all parameters for each table function prepared for joining. | |
{id: 1, parent__id: 2, parent__code=3, parent__parent__id=4, root__id=5} | |
=> | |
{ | |
'': {'id': 1}, | |
'parent': {'id': 2, 'code': 3}, | |
'parent__parent': {'id': 4}, | |
'root': {'id: 5} | |
} | |
""" | |
param_groups = {} | |
for lookup, val in table_function_params.items(): | |
parts = lookup.split(LOOKUP_SEP) | |
prefix = LOOKUP_SEP.join(parts[:-1]) | |
field = parts[-1] | |
if prefix not in param_groups: | |
param_groups[prefix] = {} | |
param_groups[prefix][field] = val | |
return param_groups | |
def _reorder_table_function_params( | |
self, model: Type[Model], table_function_params: Dict[str, Any] | |
) -> 'OrderedDict[str, Any]': | |
""" | |
Make sure that parameters will be passed into function in correct order. | |
Also check required and set defaults. | |
""" | |
ordered_function_params = OrderedDict() | |
for key, arg in getattr(model, 'function_args').items(): | |
if key in table_function_params: | |
ordered_function_params[key] = table_function_params[key] | |
elif arg.default != NOT_PROVIDED: | |
ordered_function_params[key] = arg.default | |
elif arg.required: | |
raise ValueError('Required function arg `{}` not specified'.format(key)) | |
remaining = set(table_function_params.keys()) - set(ordered_function_params.keys()) | |
if remaining: | |
raise ValueError('Function arg `{}` not found'.format(remaining.pop())) | |
return ordered_function_params | |
class TableFunctionQuerySet(QuerySet): | |
def __init__(self, model=None, query=None, using=None, hints=None): | |
super().__init__(model, query, using, hints) | |
self.query = query or TableFunctionQuery(self.model) | |
def table_function(self, **table_function_params: Dict[str, Any]) -> 'TableFunctionQuerySet': | |
self.query.table_function(**table_function_params) | |
return self | |
class TableFunctionManager(Manager): | |
def get_queryset(self) -> TableFunctionQuerySet: | |
return TableFunctionQuerySet(model=self.model, using=self._db, hints=self._hints) | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment