Evaluating SQLAlchemy expressions in Python (for ORM objects and such)
from __future__ import annotations | |
import operator | |
from collections import deque | |
from enum import Enum, auto | |
from typing import Any, Dict, Iterator, NamedTuple, Optional, Tuple | |
from sqlalchemy import Column | |
from sqlalchemy.sql import operators | |
from sqlalchemy.sql.elements import ( | |
AsBoolean, | |
BinaryExpression, | |
BindParameter, | |
BooleanClauseList, | |
ColumnElement, | |
Grouping, | |
Null, | |
UnaryExpression, | |
) | |
# Python stand-ins for the operators carried by ``AsBoolean`` wrappers, as
# looked up by ``Expression._serialize``. A ``None`` entry means the operator
# is dropped and the wrapped element is used directly; ``isfalse`` becomes a
# logical ``not``.
OPERATOR_MAP = dict(
    [
        (operators.istrue, None),
        (operators.isfalse, operator.not_),
    ]
)
class Expression:
    """A SQLAlchemy column expression compiled into an evaluable Python form.

    On construction the expression tree is flattened into a tuple of
    :class:`Symbol` objects in Polish prefix notation (``self.serialized``).
    :meth:`evaluate` then computes the expression's value for a mapping of
    columns to Python values, without touching a database.

    Two instances compare equal when their serialized forms are equal.
    """

    # SQLAlchemy operator callables that cannot be applied directly to plain
    # Python values, mapped to Python equivalents; any operator not listed is
    # used unchanged (e.g. ``operator.add``, ``operator.eq``).
    # NOTE(review): in SQLAlchemy 1.3-2.0, ``is_``/``isnot``/``in_op``/
    # ``notin_op`` are defined as ``a.is_(b)``-style method calls, which plain
    # values lack. ``operator.inv`` is what SQLAlchemy uses for ``~expr``
    # (rendered as SQL NOT), but in Python ``~True == -2``. Confirm on upgrade.
    # Entries live on the class so every instance shares the same callables,
    # which keeps ``Expression.__eq__`` meaningful.
    _PYTHON_OPERATORS = {
        operator.inv: operator.not_,
        operators.is_: operator.is_,
        operators.isnot: operator.is_not,
        operators.in_op: lambda value, options: value in options,
        operators.notin_op: lambda value, options: value not in options,
    }

    def __init__(self, expression: ColumnElement, force_bool: bool = False):
        """Serialize *expression*.

        :param expression: the SQLAlchemy expression to compile.
        :param force_bool: coerce bare non-boolean columns into NULL checks
            (``col`` -> ``col IS NOT NULL``, ``~col`` -> ``col IS NULL``).
        :raises TypeError: if the expression contains unsupported elements.
        """
        self.serialized = tuple(self._serialize(expression, force_bool=force_bool))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.serialized == other.serialized

    def evaluate(self, column_values: Dict[Column, Any]) -> Any:
        """Compute the expression's value for the given column values.

        :param column_values: maps each column referenced by the expression
            to the Python value substituted for it.
        :returns: the result of applying the serialized operators.
        :raises KeyError: if a referenced column is missing from
            ``column_values``.
        """
        stack = Stack()
        # Prefix notation is evaluated right-to-left: operands are pushed
        # first, so when an operator is reached its leftmost operand is on
        # top of the stack and ``popn`` yields the operands in source order.
        for itype, arity, value in reversed(self.serialized):
            if itype is SymbolType.literal:
                stack.push(value)
            elif itype is SymbolType.column:
                stack.push(column_values[value])
            else:
                stack.push(value(*stack.popn(arity)))
        return stack.pop()

    @classmethod
    def _python_operator(cls, op):
        """Return the Python callable to evaluate SQLAlchemy operator *op*."""
        return cls._PYTHON_OPERATORS.get(op, op)

    def _serialize(self, expr, force_bool=False) -> Iterator[Symbol]:
        """Serialize a SQLAlchemy expression to Python Symbols.

        This takes an SQLAlchemy expression tree and converts it into an
        equivalent sequence of Python Symbols in Polish prefix notation, so
        that it can be evaluated later with column value substitutions.

        :param expr: the (sub-)expression to serialize.
        :param force_bool: when true, bare non-boolean columns are coerced
            into truth tests (``IS NOT NULL`` / ``IS NULL`` for ``~col``).
        :raises TypeError: for expression types that cannot be evaluated.
        """
        # Simple and direct value types
        if isinstance(expr, BindParameter):
            yield Symbol(expr.value)
        elif isinstance(expr, Grouping):
            inner = expr.element
            if isinstance(inner, ColumnElement):
                # A parenthesized sub-expression such as ``(a | b)`` in
                # ``(a | b) & c``; the parentheses only shape the tree.
                yield from self._serialize(inner, force_bool=force_bool)
            else:
                # A literal list, e.g. the right-hand side of ``col.in_([...])``
                # when SQLAlchemy renders it as a grouped list of bind params.
                yield Symbol([element.value for element in inner])
        elif isinstance(expr, Null):
            yield Symbol(None)
        # Columns and column-wrapping functions
        elif isinstance(expr, Column):
            if force_bool:
                yield from self._serialize(expr.isnot(None))
            else:
                yield Symbol(expr)
        elif isinstance(expr, AsBoolean):
            # A ``None`` mapping drops the operator (the element is used as-is).
            if (func := OPERATOR_MAP[expr.operator]) is not None:
                yield Symbol(func, arity=1)
            # Recurse without coercion (the wrapped value is already boolean)
            # so an unsupported element raises instead of becoming a literal.
            yield from self._serialize(expr.element)
        elif isinstance(expr, UnaryExpression):
            target = expr.element
            target_is_column = isinstance(target, Column)
            if force_bool and expr.operator == operator.inv and target_is_column:
                yield from self._serialize(target.is_(None))
            else:
                yield Symbol(self._python_operator(expr.operator), arity=1)
                yield from self._serialize(target, force_bool=force_bool)
        # Multi-clause expressions
        elif isinstance(expr, BooleanClauseList):
            clauses = list(expr.clauses)
            if not clauses:
                raise TypeError(
                    f"Unsupported expression {expr} of type "
                    f"{type(expr).__name__}: it has no clauses"
                )
            # The and/or operators take exactly two arguments, so N clauses
            # become N-1 nested binary applications: ``op op c1 c2 c3``
            # evaluates as ``(c1 op c2) op c3``.
            func = self._python_operator(expr.operator)
            for _ in range(len(clauses) - 1):
                yield Symbol(func, arity=2)
            for clause in clauses:
                yield from self._serialize(clause, force_bool=force_bool)
        elif isinstance(expr, BinaryExpression):
            # Operands of a comparison/arithmetic are never coerced to bool.
            yield Symbol(self._python_operator(expr.operator), arity=2)
            yield from self._serialize(expr.left)
            yield from self._serialize(expr.right)
        else:
            raise TypeError(
                f"Unsupported expression {expr} of type {type(expr).__name__}"
            )
class Stack:
    """A small LIFO stack backed by a ``deque``, used during evaluation."""

    def __init__(self):
        self._stack = deque()

    def push(self, frame: Any) -> None:
        """Put *frame* on top of the stack."""
        self._stack.append(frame)

    def pop(self) -> Any:
        """Remove and return the top item; raises IndexError when empty."""
        return self._stack.pop()

    def popn(self, argcount) -> Iterator[Any]:
        """Return a lazy iterator popping *argcount* items, topmost first.

        Items are only removed from the stack as the iterator is consumed.
        """
        take = self._stack.pop
        steps = range(argcount)
        return (take() for _ in steps)
class Symbol:
    """One element of a serialized expression: a column, literal or operator.

    ``type`` is derived from ``value``; ``arity`` records how many operands an
    operator consumes and is ``None`` for columns and literals.
    """

    __slots__ = ("value", "type", "arity")

    def __init__(self, value: Any, arity: Optional[int] = None):
        self.value = value
        self.type = self._determine_type(value)
        self.arity = arity

    def _determine_type(self, value: Any) -> SymbolType:
        """Classify *value*: columns first, then callables, else a literal."""
        if isinstance(value, Column):
            kind = SymbolType.column
        elif callable(value):
            kind = SymbolType.operator
        else:
            kind = SymbolType.literal
        return kind

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, type(self)):
            return tuple(self) == tuple(other)
        return NotImplemented

    def __iter__(self):
        # Unpacks as (type, arity, value), the order Expression.evaluate uses.
        return iter((self.type, self.arity, self.value))
class SymbolType(Enum):
    """The kinds of Symbol found in a serialized expression."""

    column = 1  # replaced by the caller-supplied value for that column
    literal = 2  # pushed onto the evaluation stack unchanged
    operator = 3  # callable applied to operands popped off the stack
import pytest | |
from sqlalchemy import Table, MetaData, Column, Boolean, Integer, Text, null | |
from expression import Expression | |
class TestBooleanExpressions:
    """Evaluation of pure boolean expressions over Boolean columns."""

    BOOL = Table(
        "bool_expr",
        MetaData(),
        Column("a", Boolean),
        Column("b", Boolean),
        Column("c", Boolean),
    )

    @pytest.fixture
    def cols(self):
        """Column collection of the BOOL table."""
        table = self.BOOL
        return table.columns

    @pytest.mark.parametrize(
        ("inputs", "expected"),
        [
            ({BOOL.c.a: False}, False),
            ({BOOL.c.a: True}, True),
        ],
    )
    def test_direct_bool(self, cols, inputs, expected):
        result = Expression(cols.a).evaluate(inputs)
        assert result == expected

    @pytest.mark.parametrize(
        ("inputs", "expected"),
        [
            ({BOOL.c.a: False}, True),
            ({BOOL.c.a: True}, False),
        ],
    )
    def test_negation(self, cols, inputs, expected):
        result = Expression(~cols.a).evaluate(inputs)
        assert result == expected

    # Truth table for ``a AND b``.
    @pytest.mark.parametrize(
        ("inputs", "expected"),
        [
            ({BOOL.c.a: False, BOOL.c.b: False}, False),
            ({BOOL.c.a: True, BOOL.c.b: False}, False),
            ({BOOL.c.a: False, BOOL.c.b: True}, False),
            ({BOOL.c.a: True, BOOL.c.b: True}, True),
        ],
    )
    def test_conjunction(self, cols, inputs, expected):
        result = Expression(cols.a & cols.b).evaluate(inputs)
        assert result == expected

    # Truth table for ``a OR b``.
    @pytest.mark.parametrize(
        ("inputs", "expected"),
        [
            ({BOOL.c.a: False, BOOL.c.b: False}, False),
            ({BOOL.c.a: True, BOOL.c.b: False}, True),
            ({BOOL.c.a: False, BOOL.c.b: True}, True),
            ({BOOL.c.a: True, BOOL.c.b: True}, True),
        ],
    )
    def test_disjunction(self, cols, inputs, expected):
        result = Expression(cols.a | cols.b).evaluate(inputs)
        assert result == expected

    # Truth table for ``(a AND NOT b) OR c``.
    @pytest.mark.parametrize(
        ("inputs", "expected"),
        [
            ({BOOL.c.a: False, BOOL.c.b: False, BOOL.c.c: False}, False),
            ({BOOL.c.a: True, BOOL.c.b: False, BOOL.c.c: False}, True),
            ({BOOL.c.a: False, BOOL.c.b: True, BOOL.c.c: False}, False),
            ({BOOL.c.a: True, BOOL.c.b: True, BOOL.c.c: False}, False),
            ({BOOL.c.a: False, BOOL.c.b: False, BOOL.c.c: True}, True),
            ({BOOL.c.a: True, BOOL.c.b: False, BOOL.c.c: True}, True),
            ({BOOL.c.a: False, BOOL.c.b: True, BOOL.c.c: True}, True),
            ({BOOL.c.a: True, BOOL.c.b: True, BOOL.c.c: True}, True),
        ],
    )
    def test_mixed_expression(self, cols, inputs, expected):
        result = Expression((cols.a & ~cols.b) | cols.c).evaluate(inputs)
        assert result == expected
class TestBooleanCoercion:
    """Serialization equivalences and ``force_bool`` coercion."""

    COERCE = Table(
        "expr_coerce",
        MetaData(),
        Column("bool", Boolean),
        Column("number", Integer),
        Column("text", Text),
    )

    @pytest.fixture
    def cols(self):
        """Column collection of the COERCE table."""
        table = self.COERCE
        return table.columns

    def test_sql_equivalences(self, cols):
        # ``!= NULL`` and ``IS NOT NULL`` must serialize identically.
        assert Expression(cols.text != null()) == Expression(cols.text.isnot(None))

    def test_coerce_simple_expression(self, cols):
        coerced = Expression(cols.text, force_bool=True)
        assert coerced == Expression(cols.text.isnot(None))

    def test_coerce_negated_expression(self, cols):
        coerced = Expression(~cols.text, force_bool=True)
        assert coerced == Expression(cols.text.is_(None))

    def test_do_not_coerce_nonbool(self, cols):
        # A NOT IN comparison is already boolean and must be left alone.
        coerced = Expression(~cols.text.in_(["foo", "bar"]), force_bool=True)
        assert coerced == Expression(~cols.text.in_(["foo", "bar"]))
class TestMathExpressions:
    """Evaluation of arithmetic expressions over Integer columns."""

    MATH = Table(
        "expr_math",
        MetaData(),
        Column("a", Integer),
        Column("b", Integer),
        Column("c", Integer),
    )

    @pytest.fixture
    def cols(self):
        """Column collection of the MATH table."""
        table = self.MATH
        return table.columns

    @pytest.mark.parametrize(
        ("inputs", "expected"),
        [
            ({MATH.c.a: 0, MATH.c.b: 0}, 0),
            ({MATH.c.a: 2, MATH.c.b: 0}, 2),
            ({MATH.c.a: 0, MATH.c.b: 3}, 3),
            ({MATH.c.a: 5, MATH.c.b: 5}, 10),
            ({MATH.c.a: -2, MATH.c.b: -2}, -4),
        ],
    )
    def test_addition(self, cols, inputs, expected):
        total = Expression(cols.a + cols.b).evaluate(inputs)
        assert total == expected

    @pytest.mark.parametrize(
        ("inputs", "expected"),
        [
            ({MATH.c.a: 0, MATH.c.b: 0}, 0),
            ({MATH.c.a: 2, MATH.c.b: 0}, 2),
            ({MATH.c.a: 0, MATH.c.b: 3}, -3),
            ({MATH.c.a: 10, MATH.c.b: 5}, 5),
        ],
    )
    def test_subtraction(self, cols, inputs, expected):
        difference = Expression(cols.a - cols.b).evaluate(inputs)
        assert difference == expected

    # ``a * b - c``: checks operator precedence is preserved.
    @pytest.mark.parametrize(
        ("inputs", "expected"),
        [
            ({MATH.c.a: 0, MATH.c.b: 0, MATH.c.c: 0}, 0),
            ({MATH.c.a: 2, MATH.c.b: 3, MATH.c.c: 0}, 6),
            ({MATH.c.a: 3, MATH.c.b: 4, MATH.c.c: 6}, 6),
            ({MATH.c.a: 3, MATH.c.b: 3, MATH.c.c: 10}, -1),
        ],
    )
    def test_mixed_match(self, cols, inputs, expected):
        outcome = Expression(cols.a * cols.b - cols.c).evaluate(inputs)
        assert outcome == expected
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment