Skip to content

Instantly share code, notes, and snippets.

@petrprikryl
Last active June 26, 2023 06:31
Show Gist options
  • Save petrprikryl/7cd765cd723c7df983de03706bf27d1a to your computer and use it in GitHub Desktop.
Save petrprikryl/7cd765cd723c7df983de03706bf27d1a to your computer and use it in GitHub Desktop.
Django table function
#############################
# Table Functions
#############################
# PostgreSQL-style (dollar-quoted) definition of the `get_user` table function
# backing `TableFunctionUser`. These string literals are not executed by this
# file; presumably they are run against the database separately (TODO confirm).
# Both arguments default to NULL and a NULL argument disables its filter, so
# `get_user()` returns every auth_user row. user_id, one_id, parent_id and
# root_id are all filled with the row's own id, i.e. every row points at itself.
'''
CREATE OR REPLACE FUNCTION get_user(INTEGER DEFAULT NULL, VARCHAR DEFAULT NULL)
RETURNS TABLE(
id INTEGER,
user_id INTEGER,
one_id INTEGER,
username VARCHAR,
parent_id INTEGER,
root_id INTEGER
) AS $$
SELECT id, id, id, username, id, id
FROM auth_user
WHERE (
$1 IS NULL OR id = $1
) AND (
$2 IS NULL OR username = $2
)
$$ LANGUAGE sql;
'''
# Definition of `get_user_2`, backing `TableFunctionUser2`: both arguments are
# mandatory (no DEFAULT) and both filters always apply.
'''
CREATE OR REPLACE FUNCTION get_user_2(INTEGER, VARCHAR)
RETURNS TABLE(
id INTEGER,
user_id INTEGER,
username VARCHAR
) AS $$
SELECT id, id, username
FROM auth_user
WHERE id = $1 AND username = $2
$$ LANGUAGE sql;
'''
#############################
# Models
#############################
class TableFunctionUser(models.Model):
    """Unmanaged model whose rows come from the ``get_user`` SQL table function.

    ``function_args`` lists the function's positional arguments in call order.
    Both are optional, so querying without ``table_function(...)`` calls
    ``get_user()`` with no arguments.
    """
    # order must match the SQL signature get_user(INTEGER, VARCHAR)
    function_args = OrderedDict((
        ('id', TableFunctionArg(required=False)),
        ('username', TableFunctionArg(required=False)),
    ))
    objects = TableFunctionManager()
    # `user` and `one` both point at auth.User, and `parent` and `root` both point
    # at this model. With Django's default reverse names each pair clashes
    # (system checks fields.E304/E305), so one relation of each pair disables
    # its reverse side with related_name='+'. Forward access is unaffected.
    user = models.ForeignKey('auth.User', on_delete=models.DO_NOTHING, related_name='+')
    one = models.OneToOneField('auth.User', on_delete=models.DO_NOTHING)
    username = models.CharField(max_length=150)
    parent = models.ForeignKey('self', on_delete=models.DO_NOTHING)
    root = models.ForeignKey('self', on_delete=models.DO_NOTHING, related_name='+')

    class Meta:
        # the "table" is the SQL function; with managed=False Django does not create or drop it
        db_table = 'get_user'
        managed = False

    def __str__(self):
        return self.username
class TableFunctionUser2(models.Model):
    """Unmanaged model whose rows come from the ``get_user_2`` SQL table function.

    Both SQL arguments are mandatory. ``id`` has no default and must be passed
    via ``table_function(id=...)``. ``username`` falls back to ``'admin'``
    when ``table_function()`` is called without it.
    """
    # positional order must match get_user_2(INTEGER, VARCHAR)
    function_args = OrderedDict((
        ('id', TableFunctionArg()),
        ('username', TableFunctionArg(default='admin')),
    ))
    objects = TableFunctionManager()
    user = models.ForeignKey('auth.User', on_delete=models.DO_NOTHING)
    username = models.CharField(max_length=150)

    class Meta:
        # the "table" is the SQL function; with managed=False Django does not create or drop it
        db_table = 'get_user_2'
        managed = False

    def __str__(self):
        return self.username
#############################
# API
#############################
def test_optional_args(self):
    """With all-optional args the function works both bare and fully parameterised."""
    # bare call: get_user() must return one row per auth_user row
    assert User.objects.count() == TableFunctionUser.objects.count()
    expected = User.objects.get(id=1)
    found = (
        TableFunctionUser.objects.all()
        .table_function(id=1, username='admin')
        .get()
    )
    assert expected == found.user
def test_select_related(self):
    """select_related() through table-function relations passes args at every depth."""
    expected_user = User.objects.get(id=1)

    # explicit params for the base function and for each joined function
    nested = (
        TableFunctionUser.objects.all()
        .table_function(id=1, root__id=1, parent__id=1, parent__parent__id=1)
        .select_related('root', 'parent__parent__user')
        .get()
    )
    assert expected_user == nested.parent.parent.user

    # a joined function's argument may reference a column of the base function
    from_column = (
        TableFunctionUser.objects.all()
        .table_function(id=1, parent__id=models.F('id'))
        .select_related('parent', 'parent__user')
        .get()
    )
    assert expected_user == from_column.parent.user

    # get_user sets root_id to the row's own id, so `root` is the row itself
    self_ref = (
        TableFunctionUser.objects.all()
        .table_function(id=1, parent__id=1, root__id=1)
        .select_related('user', 'one', 'parent', 'root')
        .get()
    )
    assert self_ref == self_ref.root
def test_annotate(self):
    """Annotations may traverse a relation of a table-function model."""
    user = User.objects.get(id=1)
    queryset = TableFunctionUser.objects.all().table_function(id=1)
    annotated = queryset.annotate(_username=models.F('user__username')).get()
    assert user.username == annotated._username
def test_required_arg(self):
    """Querying ``get_user_2`` without its required arguments fails in the database."""
    with pytest.raises(ProgrammingError):
        # `count()` itself runs the query and returns an int. The former `len(...)`
        # wrapper would turn a missing database error into a confusing TypeError.
        TableFunctionUser2.objects.count()
def test_default_arg(self):
    """An omitted argument with a declared default (``username='admin'``) is filled in."""
    expected = User.objects.get(id=1)
    found = TableFunctionUser2.objects.all().table_function(id=1).get()
    assert expected == found.user
def test_override_default_arg(self):
    """An explicitly passed value replaces the declared default."""
    expected = User.objects.get(id=2)
    queryset = TableFunctionUser2.objects.all()
    found = queryset.table_function(id=2, username='AnonymousUser').get()
    assert expected == found.user
#############################
# Django extending
#############################
import re
from collections import OrderedDict
from typing import Any, Dict, Type, Optional
from typing import List
from django.db.models import ForeignObject
from django.db.models import QuerySet, Manager, NOT_PROVIDED, F, Model
from django.db.models.constants import LOOKUP_SEP
from django.db.models.sql import Query
from django.db.models.sql.datastructures import BaseTable, Join
from django.db.models.sql.where import WhereNode
class TableFunctionArg:
    """Declaration of one positional argument of a SQL table function.

    ``required``: when the caller omits the argument and there is no
    ``default``, a ValueError is raised if this is true, otherwise the
    argument is simply left out of the call.
    ``default``: value used when the caller omits the argument;
    ``NOT_PROVIDED`` means "no default".
    """

    def __init__(self, required: bool = True, default=NOT_PROVIDED):
        self.default = default
        self.required = required  # type: bool
class TableFunction(BaseTable):
    """Base ``FROM`` entry that calls a SQL table function instead of naming a table.

    Renders as ``"func"(arg, ...) alias``. Plain values become ``%s``
    placeholders. Objects with ``as_sql`` (Django expressions) are compiled
    inline, as in ``TableFunctionJoin``.
    """

    def __init__(self, table_name: str, alias: Optional[str], table_function_params: List[Any]):
        super().__init__(table_name, alias)
        self.table_function_params = table_function_params  # type: List[Any]

    def as_sql(self, compiler, connection):
        alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias)
        base_sql = compiler.quote_name_unless_alias(self.table_name)
        placeholders = []  # type: List[str]
        params = []  # type: List[Any]
        for param in self.table_function_params:
            if hasattr(param, 'as_sql'):
                param_sql, param_params = compiler.compile(param)
            else:
                param_sql, param_params = '%s', [param]
            placeholders.append(param_sql)
            params.extend(param_params)
        return '{}({}){}'.format(base_sql, ', '.join(placeholders), alias_str), params

    def relabeled_clone(self, change_map):
        # BaseTable.relabeled_clone calls ``__class__(table_name, alias)``, which
        # fails here (missing params). Keep the function arguments on the clone.
        return self.__class__(
            self.table_name,
            change_map.get(self.table_alias, self.table_alias),
            self.table_function_params,
        )
class TableFunctionJoin(Join):
    """``Join`` whose right-hand side can be a SQL table function call.

    With ``table_function_params`` set to ``None`` it renders exactly like
    Django's ``Join``. With a list (possibly empty) the joined table is rendered
    as ``"func"(arg, ...) alias`` and Django's compiled ON condition is reused.
    """

    # Django's Join.as_sql renders "<join_type> <table><alias> ON (<condition>)".
    # Require whitespace before ON so "on" inside identifiers (e.g. "person",
    # "config") is never mistaken for the keyword; capture the parenthesised body.
    _ON_CLAUSE_RE = re.compile(r'\sON\s*\((?P<on_clause_sql>.*)\)\s*$', re.IGNORECASE | re.DOTALL)

    def __init__(self, table_name, parent_alias, table_alias, join_type,
                 join_field, nullable, filtered_relation=None,
                 table_function_params: Optional[List[Any]] = None):
        super().__init__(table_name, parent_alias, table_alias, join_type,
                         join_field, nullable, filtered_relation)
        # None means "plain table join"; a list means "call the function with these args"
        self.table_function_params = table_function_params  # type: Optional[List[Any]]

    def as_sql(self, compiler, connection):
        sql, params = super().as_sql(compiler, connection)
        if self.table_function_params is None:
            return sql, params  # normal table join
        # Reuse the ON condition Django compiled (join columns, extra restrictions,
        # filtered relations) and only replace the table reference with a call.
        match = self._ON_CLAUSE_RE.search(sql)
        if match is None:
            raise ValueError('Unable to extract the ON clause from join SQL: {}'.format(sql))
        on_clause_sql = match.group('on_clause_sql')
        placeholders = []  # type: List[str]
        function_params = []  # type: List[Any]
        for param in self.table_function_params:
            if hasattr(param, 'as_sql'):
                param_sql, param_params = compiler.compile(param)
            else:
                param_sql, param_params = '%s', [param]
            placeholders.append(param_sql)
            function_params.extend(param_params)
        sql = '{} {}({}) {} ON ({})'.format(
            self.join_type,
            compiler.quote_name_unless_alias(self.table_name),
            ', '.join(placeholders),
            # quote the alias the same way column references to it are quoted
            compiler.quote_name_unless_alias(self.table_alias),
            on_clause_sql,
        )
        # function arguments precede the ON condition in the SQL text
        return sql, function_params + list(params)

    def relabeled_clone(self, change_map):
        # Join.relabeled_clone (also used by promote()/demote()) rebuilds the
        # join without our extra argument, which would silently turn it back
        # into a plain table join. Carry the params over, relabelling any
        # resolved expressions that reference aliases.
        clone = super().relabeled_clone(change_map)
        if self.table_function_params is not None:
            clone.table_function_params = [
                param.relabeled_clone(change_map) if hasattr(param, 'relabeled_clone') else param
                for param in self.table_function_params
            ]
        return clone
class TableFunctionParams:
    """Arguments for one table-function call site inside a query.

    ``level``: number of relations between the base ``FROM`` and this call
    (0 for the base function itself).
    ``join_field``: relation used to reach the call (``None`` at level 0).
    ``params``: argument name -> value, already in call order.
    """

    def __init__(self, level: int, join_field: ForeignObject, params: 'OrderedDict[str, Any]'):
        self.params = params  # type: OrderedDict[str, Any]
        self.join_field = join_field  # type: ForeignObject
        self.level = level  # type: int
class TableFunctionQuery(Query):
    """SQL query that can select from, and join to, SQL table functions.

    A model opts in by defining ``function_args``: an ordered mapping of
    argument name -> ``TableFunctionArg`` in the SQL function's positional
    order. Values given to :meth:`table_function` are stored per call site in
    ``self.table_function_params``. They are used when the base ``FROM`` entry
    (:meth:`get_initial_alias`) and the joins (:meth:`setup_joins`) are built.
    """

    def __init__(self, model, where=WhereNode):
        super().__init__(model, where)
        self.table_function_params = []  # type: List[TableFunctionParams]

    def get_initial_alias(self):
        """Return the base alias, creating the ``FROM`` entry on first use."""
        if self.alias_map:
            alias = self.base_table
            self.ref_alias(alias)
        elif hasattr(self.model, 'function_args'):
            base_call = next((p for p in self.table_function_params if p.level == 0), None)
            # Without user-supplied params the function is called with no
            # arguments, which only works when all of them are optional in SQL.
            params = [] if base_call is None else list(base_call.params.values())  # type: List[Any]
            alias = self.join(TableFunction(self.get_meta().db_table, None, params))
        else:
            alias = self.join(BaseTable(self.get_meta().db_table, None))
        return alias

    def _join_depth(self, alias) -> int:
        """Return how many joins separate ``alias`` from the base table."""
        depth = 0
        join = self.alias_map.get(alias)
        while isinstance(join, Join):
            depth += 1
            join = self.alias_map.get(join.parent_alias)
        return depth

    def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
                    reuse_with_filtered_relation=False):
        """Build joins as Django does. Then replace every join to a table-function
        model with a ``TableFunctionJoin`` carrying that call site's arguments."""
        join_info = super().setup_joins(names, opts, alias, can_reuse, allow_many, reuse_with_filtered_relation)
        for join_alias in join_info.joins:
            join = self.alias_map[join_alias]
            if not isinstance(join, Join):
                continue  # base table; a `FROM func(...)` is built in `get_initial_alias`
            if not hasattr(join.join_field.related_model, 'function_args'):
                continue  # normal table
            # The depth from the base table matches the number of parts of the
            # lookup given to `table_function`. It does not depend on where this
            # call started: select_related() calls setup_joins once per level,
            # starting from an already-joined alias.
            level = self._join_depth(join_alias)
            call = next(
                (p for p in self.table_function_params
                 if p.level == level and p.join_field == join.join_field),
                None,
            )
            # no params from the user: call the function without arguments,
            # which only works when all of them are optional in SQL
            params = [] if call is None else list(call.params.values())
            # resolve expressions such as F('id') against this query
            resolved_params = [
                param.resolve_expression(self) if hasattr(param, 'resolve_expression') else param
                for param in params
            ]
            self.alias_map[join_alias] = TableFunctionJoin(
                join.table_name, join.parent_alias, join.table_alias, join.join_type, join.join_field,
                join.nullable, join.filtered_relation, resolved_params
            )
        return join_info

    def table_function(self, **table_function_params: Dict[str, Any]):
        """Store user-supplied arguments so the table functions can be joined.

        Keys use lookup syntax: ``id=1`` targets the base function and
        ``parent__id=2`` targets the function joined through ``parent``.
        Previously stored arguments are replaced.

        :raises ValueError: a required argument is missing or an argument is unknown.
        """
        calls = []  # type: List[TableFunctionParams]
        for table_lookup, param_dict in self._table_function_params_to_groups(table_function_params).items():
            if table_lookup:
                level = len(table_lookup.split(LOOKUP_SEP))
                _, field_parts, _ = self.solve_lookup_type(table_lookup)
                path, final_field, _, _ = self.names_to_path(
                    field_parts, self.get_meta(), allow_many=False, fail_on_missing=True
                )
                join_field = path[-1].join_field
                model = final_field.related_model
            else:
                level, join_field, model = 0, None, self.model
            calls.append(
                TableFunctionParams(
                    level=level, join_field=join_field,
                    params=self._reorder_table_function_params(model, param_dict)
                )
            )
        # TODO: merge with previously stored params instead of replacing them?
        self.table_function_params = calls

    def _table_function_params_to_groups(self, table_function_params: Dict[str, Any]) -> Dict[str, Any]:
        """Group lookup-style params by the relation path they address.

        ``{'id': 1, 'parent__id': 2, 'parent__code': 3, 'parent__parent__id': 4, 'root__id': 5}``
        becomes::

            {
                '': {'id': 1},
                'parent': {'id': 2, 'code': 3},
                'parent__parent': {'id': 4},
                'root': {'id': 5},
            }
        """
        param_groups = {}  # type: Dict[str, Dict[str, Any]]
        for lookup, value in table_function_params.items():
            *prefix_parts, field = lookup.split(LOOKUP_SEP)
            param_groups.setdefault(LOOKUP_SEP.join(prefix_parts), {})[field] = value
        return param_groups

    def _reorder_table_function_params(
            self, model: Type[Model], table_function_params: Dict[str, Any]
    ) -> 'OrderedDict[str, Any]':
        """Order params like ``model.function_args`` and fill in defaults.

        Optional args that have no default and were not given are left out.
        Note that this shifts every later argument one position left in the
        positional SQL call.

        :raises ValueError: a required arg is missing or an unknown arg was given.
        """
        ordered_function_params = OrderedDict()
        for key, arg in model.function_args.items():
            if key in table_function_params:
                ordered_function_params[key] = table_function_params[key]
            elif arg.default is not NOT_PROVIDED:  # identity check against the sentinel
                ordered_function_params[key] = arg.default
            elif arg.required:
                raise ValueError('Required function arg `{}` not specified'.format(key))
        remaining = set(table_function_params) - set(ordered_function_params)
        if remaining:
            # sorted() makes the reported name deterministic (set.pop() is arbitrary)
            raise ValueError('Function arg `{}` not found'.format(sorted(remaining)[0]))
        return ordered_function_params
class TableFunctionQuerySet(QuerySet):
    """QuerySet backed by ``TableFunctionQuery``.

    Use :meth:`table_function` to pass arguments to the SQL functions involved
    in the query.
    """

    def __init__(self, model=None, query=None, using=None, hints=None):
        # Give QuerySet the table-function aware query directly instead of
        # letting it build a plain sql.Query that would be replaced right away.
        super().__init__(model, query or TableFunctionQuery(model), using, hints)

    def table_function(self, **table_function_params: Dict[str, Any]) -> 'TableFunctionQuerySet':
        """Return a new queryset whose table functions receive these arguments.

        Like other QuerySet methods this clones instead of mutating ``self``,
        so a shared or already evaluated queryset is not silently changed.
        Keys use lookup syntax, e.g. ``id=1`` or ``parent__id=F('id')``.
        """
        clone = self._chain()
        clone.query.table_function(**table_function_params)
        return clone
class TableFunctionManager(Manager.from_queryset(TableFunctionQuerySet)):
    """Manager for table-function models.

    Built with ``Manager.from_queryset`` so ``table_function(...)`` is also
    available directly on the manager (``Model.objects.table_function(...)``),
    not only after ``.all()``.
    """

    def get_queryset(self) -> TableFunctionQuerySet:
        return TableFunctionQuerySet(model=self.model, using=self._db, hints=self._hints)
@fzzylogic
Copy link

This is great, Django should have this built in ^^. Linked to it from a Stackoverflow answer, would be great if you'd attach license info. Thanks for sharing!

@petrprikryl
Copy link
Author

It's free to use.

By the way, there is a problem in this implementation. If you try to delete a user (or any object that a TF model references through a ForeignKey), Django tries to collect all related objects and fails, because it uses the base manager, which does not pass the TF args. I have fixed it in my private repo.

Now I am considering creating a repo (library) for TF support in Django, but I am still not sure 😄

@fzzylogic
Copy link

fzzylogic commented Mar 23, 2021

Thanks for the info, and thanks for sharing! I've used it as a way of having models against TVF's, without having to fall back to RawQuerySet, so only using the 'Django Extending' part. Since so many DB's support TVF's, it makes sense for Django to do so too, so the repo you suggested sounds like a great idea. Even if not used often, sometimes TVF's are the best option.

@niccolomineo
Copy link

It's free to use.

Btw. there is problem in this implementation. If you try to delete user (any ForeignKey to TF model) then Django tries to find all related objects and it fails because Django will use base manager without FT args. I have fixed it in my private repo.

Now I am considering creating repo (lib) for TF Django support. But still not sure 😄

Hi, could you please share the fix?

@petrprikryl
Copy link
Author

petrprikryl commented Oct 22, 2021

class BaseManager(Manager):
    """Manager whose querysets are always empty (``.none()``).

    Meant to serve as the model's base manager so that queries Django issues
    through it never call the table function without its arguments, for
    example when collecting related objects on delete.
    """

    def get_queryset(self):
        queryset = super().get_queryset()
        return queryset.none()


class BaseTableModel(Model):
    """Abstract base for table-function models.

    ``objects`` supports ``table_function(...)``. ``base_objects`` is
    registered as the base manager and always returns empty querysets.
    """
    objects = TableFunctionManager()
    # returns `.none()` querysets; wired up via `base_manager_name` below
    base_objects = BaseManager()

    class Meta:
        # Queries Django makes through the base manager (e.g. collecting related
        # objects on delete) would otherwise call the function without args.
        base_manager_name = "base_objects"
        abstract = True

@niccolomineo
Copy link

Thanks, I really appreciate that.

@GandzyTM
Copy link

Hi all!

Has anyone encountered this problem when fetching data from a table function after calling MyModel.objects.all().table_function(id=1)?

django.db.utils.ProgrammingError: invalid reference to FROM-clause entry for table "udfconfigcompliance"
LINE 1: SELECT "pg"."udfconfigcompliance"."id", "pg"."udfconfigcompl...
               ^
HINT:  There is an entry for table "udfconfigcompliance", but it cannot be referenced from this part of the query.

But if I remove .get(), I get <class 'collections.OrderedDict'>.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment