Skip to content

Instantly share code, notes, and snippets.

@funnydman
Created December 1, 2020 12:56
Show Gist options
  • Save funnydman/376480af6500cc43f706df94e6e6ab33 to your computer and use it in GitHub Desktop.
Save funnydman/376480af6500cc43f706df94e6e6ab33 to your computer and use it in GitHub Desktop.
Combined-expressions helper for Django
# Source discussion this helper was written for:
# https://stackoverflow.com/questions/58877390/how-to-collect-results-into-array-from-annotation
"""
Collect annotation results into an array using Django query expressions.

Some different ways to do it are shown below:

* ``CombinedExpressions`` - joins any number of expressions with one SQL
  connector (e.g. ``'||'``, used here to concatenate arrays).
* ``FruitManager`` - example manager that concatenates two ``ArrayAgg``
  results into a single ``result`` annotation.
"""
class CombinedExpressions(SQLiteNumericMixin, Expression):
    """Join any number of expressions with a single SQL connector.

    Renders ``(<e1> <connector> <e2> ... <eN>)``. For example, with
    ``connector='||'`` several ``ArrayAgg`` results can be concatenated into
    one array.

    Args:
        *expressions: Operand expressions, rendered in the given order.
        connector: SQL operator placed between operands (keyword-only).
        output_field: Optional explicit output field, passed to the
            ``Expression`` base class; when omitted the base class infers it
            from the operands.

    Raises:
        NotImplementedError: From ``as_sql`` when any operand's output field
            is a DateField, DateTimeField, TimeField or DurationField.
    """

    # Internal field types rejected by as_sql().
    # NOTE(review): presumably because temporal arithmetic needs
    # backend-specific SQL that plain combine_expression() does not
    # generate -- confirm.
    _unsupported_internal_types = frozenset(
        {'DateField', 'DateTimeField', 'TimeField', 'DurationField'}
    )

    def __init__(self, *expressions, connector, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        # Keep a list so the attribute has the same type before and after
        # resolve_expression()/set_source_expressions() replace it.
        self.expressions = list(expressions)

    def get_source_expressions(self):
        return self.expressions

    def set_source_expressions(self, exprs):
        # Django's custom-expression API pairs this with
        # get_source_expressions(); helpers that rebuild the expression tree
        # (e.g. relabeled_clone()) call it, so it must be overridden too.
        self.expressions = list(exprs)

    def as_sql(self, compiler, connection, **extra_context):
        """Compile to ``(<sql1> <connector> <sql2> ...)`` plus merged params.

        ``extra_context`` is accepted (and unused) so callers that forward
        keyword context to ``as_sql`` keep working.
        """
        internal_types = set()
        for expr in self.expressions:
            # _output_field_or_none yields None (instead of raising
            # FieldError) when the operand's type resolves to nothing, so
            # such operands are simply skipped by the type check.
            field = getattr(expr, '_output_field_or_none', None)
            if field is not None:
                internal_types.add(field.get_internal_type())
        unsupported = internal_types & self._unsupported_internal_types
        if unsupported:
            raise NotImplementedError(
                'Unsupported output field type(s): '
                f'{", ".join(sorted(unsupported))}'
            )

        sql_parts = []
        params = []
        for expr in self.expressions:
            part_sql, part_params = compiler.compile(expr)
            sql_parts.append(part_sql)
            params.extend(part_params)

        sql = connection.ops.combine_expression(self.connector, sql_parts)
        # Parenthesize so the whole combination binds as one operand when
        # nested inside a larger expression (order of precedence).
        return f'({sql})', params

    def resolve_expression(self, query=None, allow_joins=True, reuse=None,
                           summarize=False, for_save=False):
        """Return a copy whose operands are all resolved against ``query``."""
        c = self.copy()
        c.is_summary = summarize
        c.set_source_expressions([
            expr.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for expr in c.get_source_expressions()
        ])
        return c
class FruitManager(Manager):
    """Manager whose querysets carry a ``result`` array annotation.

    ``result`` is two ``ArrayAgg`` aggregates over CASE expressions, joined
    by ``CombinedExpressions`` with the ``'||'`` connector:

    * rows with ``type='tropicals'`` contribute ``'This fruit is tropical...'``;
      that Case has no ``default``, so other rows presumably contribute NULL;
    * rows with ``country_of_import='Africa'`` contribute
      ``'This fruit is citrus...'``, all other rows ``'here we go'``.

    NOTE(review): ``ArrayAgg`` is an aggregate, so every queryset from this
    manager is presumably grouped -- confirm that is intended.
    NOTE(review): ``'||'`` array concatenation looks PostgreSQL-specific;
    imports are not visible here -- verify the target backend.
    """

    def get_queryset(self):
        base = super().get_queryset()

        tropical_notes = ArrayAgg(
            Case(
                When(type='tropicals', then=Value('This fruit is tropical...')),
                output_field=CharField(),
            )
        )
        citrus_notes = ArrayAgg(
            Case(
                When(
                    country_of_import='Africa',
                    then=Value('This fruit is citrus...'),
                ),
                output_field=CharField(),
                default=Value('here we go'),
            )
        )

        combined = CombinedExpressions(
            tropical_notes,
            citrus_notes,
            connector='||',
        )
        return base.annotate(result=combined)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment