Created
December 24, 2012 23:05
-
-
Save henriquebastos/4370992 to your computer and use it in GitHub Desktop.
Simple implementation of Conditional Aggregates in Django
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
''' | |
Implements conditional aggregates. | |
This code was based on the work of others found on the internet: | |
1. http://web.archive.org/web/20101115170804/http://www.voteruniverse.com/Members/jlantz/blog/conditional-aggregates-in-django | |
2. https://code.djangoproject.com/ticket/11305 | |
3. https://groups.google.com/forum/?fromgroups=#!topic/django-users/cjzloTUwmS0 | |
4. https://groups.google.com/forum/?fromgroups=#!topic/django-users/vVprMpsAnPo | |
''' | |
from django.db.models.aggregates import Aggregate as DjangoAggregate | |
from django.db.models.sql.aggregates import Aggregate as DjangoSqlAggregate | |
class SqlAggregate(DjangoSqlAggregate): | |
conditional_template = '%(function)s(CASE WHEN %(condition)s THEN %(field)s ELSE null END)' | |
def __init__(self, col, source=None, is_summary=False, condition=None, **extra): | |
super(SqlAggregate, self).__init__(col, source, is_summary, **extra) | |
self.condition = condition | |
def relabel_aliases(self, change_map): | |
super(SqlAggregate, self).relabel_aliases(change_map) | |
if self.has_condition: | |
self.condition.relabel_aliases(change_map) | |
def as_sql(self, qn, connection): | |
if self.has_condition: | |
self.sql_template = self.conditional_template | |
self.extra['condition'] = self._condition_as_sql(qn, connection) | |
return super(SqlAggregate, self).as_sql(qn, connection) | |
@property | |
def has_condition(self): | |
# Warning: bool(QuerySet) will hit the database | |
return self.condition is not None | |
def _condition_as_sql(self, qn, connection): | |
''' | |
Return sql for condition. | |
''' | |
def escape(value): | |
if isinstance(value, basestring): | |
value = qn(value) | |
if isinstance(value, bool): | |
value = int(value) | |
return value | |
sql, param = self.condition.query.where.as_sql(qn, connection) | |
param = map(escape, param) | |
return sql % tuple(param) | |
class SqlSum(SqlAggregate): | |
sql_function = 'SUM' | |
class SqlCount(SqlAggregate): | |
sql_function = 'COUNT' | |
class Aggregate(DjangoAggregate): | |
def __init__(self, lookup, only=None, **extra): | |
super(Aggregate, self).__init__(lookup, **extra) | |
self.only = only | |
self.condition = None | |
def add_to_query(self, query, alias, col, source, is_summary): | |
if self.only: | |
self.condition = query.model._default_manager.filter(self.only) | |
aggregate = self.sql_klass(col, source, is_summary, self.condition, **self.extra) | |
query.aggregates[alias] = aggregate | |
class Sum(Aggregate): | |
name = 'Sum' | |
sql_klass = SqlSum | |
class Count(Aggregate): | |
name = 'Count' | |
sql_klass = SqlCount |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment