Skip to content

Instantly share code, notes, and snippets.

@everilae
Last active August 29, 2015 14:03
Show Gist options
  • Save everilae/3d746e090a3084324316 to your computer and use it in GitHub Desktop.
Save everilae/3d746e090a3084324316 to your computer and use it in GitHub Desktop.
An attempt at aggregate FILTER clauses for SQLAlchemy
import itertools
from sqlalchemy import util, and_, case
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.elements import ClauseList, _clone
from sqlalchemy.sql.functions import FunctionElement
class AggregateFilter(ColumnElement):
    """Represent an aggregate ``FILTER (WHERE ...)`` clause.

    This is a special construct applied to aggregate functions which
    restricts the rows the aggregate is computed over. Native FILTER
    syntax is only supported by certain database backends.
    """

    __visit_name__ = 'aggregatefilter'

    # Stays None when no criteria are given; otherwise a ClauseList of them.
    criterion = None

    def __init__(self, func, *criterion):
        """Produce an :class:`.AggregateFilter` object against a function.

        Used against aggregate functions, for database backends that
        support the aggregate "FILTER" clause.

        E.g.::

            AggregateFilter(func.count(1), MyClass.name == 'some name')

        Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')".

        :param func: a :class:`.FunctionElement` construct, typically
          generated by :data:`~.expression.func`.
        :param criterion: one or more column elements or strings that
          will be used as the FILTER clause of the aggregate construct.
          A single list or tuple of such elements is also accepted.

        When ``FunctionElement.filter`` is monkeypatched to ``filter_``
        (see the end of this module), this is also reachable as
        ``func.xxx(...).filter(...)``.
        """
        self.func = func
        # Support the documented "list of such" form. Only plain lists and
        # tuples are unpacked, so a single and_()/or_() expression passed as
        # the criterion is kept intact rather than being split apart.
        if len(criterion) == 1 and isinstance(criterion[0], (list, tuple)):
            criterion = tuple(criterion[0])
        if criterion:
            self.criterion = ClauseList(*util.to_list(criterion))

    @util.memoized_property
    def type(self):
        """The result type, taken from the wrapped function."""
        return self.func.type

    def get_children(self, **kwargs):
        """Return the wrapped function and, if present, the criterion."""
        return [c for c in
                (self.func, self.criterion)
                if c is not None]

    def _copy_internals(self, clone=_clone, **kw):
        """Replace the function and criterion with clones made by ``clone``."""
        self.func = clone(self.func, **kw)
        if self.criterion is not None:
            self.criterion = clone(self.criterion, **kw)

    @property
    def _from_objects(self):
        """FROM objects referenced by the function and the criterion."""
        return list(itertools.chain(
            *[c._from_objects for c in
              (self.func, self.criterion)
              if c is not None]
        ))
# FIXME: this skips the normal function compilation, possibly causing all kinds
# of weird or unexpected behaviour. Seems to work for simple count() and sum()
# cases. Multi-argument functions are likely rendered incorrectly, because the
# whole clause_expr is used as the single CASE result.
@compiles(AggregateFilter)
def visit_aggregatefilter(aggfilter, compiler, **kwargs):
    """Render an AggregateFilter generically as ``fn(CASE WHEN ... THEN args END)``.

    This emulates FILTER on backends without native support. The CASE has
    no ELSE branch, so rows failing the criterion feed NULL to the aggregate.
    Multiple criteria are combined with AND.
    """
    if aggfilter.criterion is None:
        # Nothing to filter on: render the plain function instead of
        # failing on ``and_(*None)``.
        return compiler.process(aggfilter.func, **kwargs)
    fn_name = ".".join(
        list(aggfilter.func.packagenames) + [aggfilter.func.name]
    )
    emulated = case([
        (
            and_(*aggfilter.criterion),
            aggfilter.func.clause_expr
        )
    ])
    # Forward compiler options (e.g. literal_binds) so nested elements
    # compile with the same settings as the enclosing statement.
    return "%s(%s)" % (fn_name, compiler.process(emulated, **kwargs))
# Native FILTER rendering for PostgreSQL >= 9.4. Uncomment the decorator below
# to enable it.
#@compiles(AggregateFilter, "postgresql")
def pg_visit_aggregatefilter(aggfilter, compiler, **kwargs):
    """Render an AggregateFilter as ``fn(...) FILTER (WHERE ...)``.

    Multiple criteria are combined with AND. If no criterion was given, the
    plain function is rendered.
    """
    if aggfilter.criterion is None:
        # Avoid ``and_(*None)``; there is no FILTER clause to emit.
        return compiler.process(aggfilter.func, **kwargs)
    # Forward compiler options (e.g. literal_binds) to the nested elements.
    return "%s FILTER (WHERE %s)" % (
        compiler.process(aggfilter.func, **kwargs),
        compiler.process(and_(*aggfilter.criterion), **kwargs),
    )
def filter_(self, *criterion):
    """Wrap this function in an aggregate FILTER clause.

    Intended to be installed as a method on :class:`.FunctionElement`, for
    aggregate functions on backends that support the "FILTER" clause.
    Returns an :class:`AggregateFilter` built from ``self`` and ``criterion``.
    """
    wrapped = AggregateFilter(self, *criterion)
    return wrapped
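# Monkeypatch FunctionElement so that func.<name>(...).filter(...) is available;
# uncomment to enable. NOTE: SQLAlchemy >= 1.0 ships a native
# FunctionElement.filter(), which this assignment would replace.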
#FunctionElement.filter = filter_
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment