Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@edelooff
Created June 12, 2020 22:50
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save edelooff/b71436f137a390545b42d1f18a7ad4f8 to your computer and use it in GitHub Desktop.
Save edelooff/b71436f137a390545b42d1f18a7ad4f8 to your computer and use it in GitHub Desktop.
Evaluating SQLAlchemy expressions in Python (for ORM objects and such)
from __future__ import annotations
import operator
from collections import deque
from enum import Enum, auto
from typing import Any, Dict, Iterator, NamedTuple, Optional, Tuple
from sqlalchemy import Column
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import (
AsBoolean,
BinaryExpression,
BindParameter,
BooleanClauseList,
ColumnElement,
Grouping,
Null,
UnaryExpression,
)
# Maps the SQLAlchemy operator carried by an ``AsBoolean`` element to the
# Python callable used to evaluate it. A value of ``None`` means no operator
# symbol is emitted during serialization, so the wrapped element's value is
# used unchanged.
OPERATOR_MAP = {
operators.istrue: None,
operators.isfalse: operator.not_,
}
class Expression:
    """Python-evaluable form of an SQLAlchemy column expression.

    The expression tree is flattened once, at construction time, into a tuple
    of ``Symbol`` objects in Polish (prefix) notation. ``evaluate`` then runs
    that sequence against concrete column values.
    """

    def __init__(self, expression: ColumnElement, force_bool: bool = False):
        """Serialize ``expression`` for later evaluation.

        Args:
            expression: The SQLAlchemy expression to serialize.
            force_bool: When true, a bare column is serialized as
                ``column IS NOT NULL`` and an inverted bare column as
                ``column IS NULL``.
        """
        self.serialized = tuple(self._serialize(expression, force_bool=force_bool))

    def __eq__(self, other: Any) -> bool:
        """Two expressions are equal when their serialized symbols match."""
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.serialized == other.serialized

    def evaluate(self, column_values: Dict[Column, Any]) -> Any:
        """Evaluate the expression with the given column values.

        Args:
            column_values: Mapping from each column used in the expression to
                the value substituted for it.

        Returns:
            The result of the outermost operator, or the single value of an
            expression that contains no operator.

        Raises:
            KeyError: If a column of the expression is missing from a plain
                ``dict`` passed as ``column_values``.
        """
        stack = Stack()
        # Walking the prefix sequence backwards pushes every operand before
        # its operator; popping then yields the operands left to right.
        for itype, arity, value in reversed(self.serialized):
            if itype is SymbolType.literal:
                stack.push(value)
            elif itype is SymbolType.column:
                stack.push(column_values[value])
            elif arity is not None and arity > 2:
                # Clause lists carry one operand per clause, but the callables
                # they hold (e.g. ``operator.and_``) are binary: fold them
                # left to right instead of passing every operand at once.
                operands = iter(stack.popn(arity))
                result = next(operands)
                for operand in operands:
                    result = value(result, operand)
                stack.push(result)
            else:
                stack.push(value(*stack.popn(arity)))
        return stack.pop()

    def _serialize(
        self, expr: ColumnElement, force_bool: bool = False
    ) -> Iterator[Symbol]:
        """Serializes an SQLAlchemy expression to Python functions.

        This takes an SQLAlchemy expression tree and converts it into an
        equivalent set of Python Symbols. The generated format is that
        of a Polish prefix notation. This allows the expression to be easily
        evaluated with column value substitutions.

        Args:
            expr: The (sub)expression to serialize.
            force_bool: Whether bare columns are rewritten into NULL checks.
                The flag is passed on through unary expressions and boolean
                clause lists, but not into the operands of binary
                expressions.

        Raises:
            TypeError: If ``expr`` is of a type that is not handled here.
            KeyError: If an ``AsBoolean`` element carries an operator that is
                not listed in ``OPERATOR_MAP``.
        """
        # Simple and direct value types
        if isinstance(expr, BindParameter):
            yield Symbol(expr.value)
        elif isinstance(expr, Grouping):
            # Collapses the grouped elements into one list literal.
            # NOTE(review): presumably the value list of an ``in_()`` call;
            # groupings whose members have no ``.value`` are not handled.
            value = [element.value for element in expr.element]
            yield Symbol(value)
        elif isinstance(expr, Null):
            yield Symbol(None)
        # Columns and column-wrapping functions
        elif isinstance(expr, Column):
            if force_bool:
                yield from self._serialize(expr.isnot(None))
            else:
                yield Symbol(expr)
        elif isinstance(expr, AsBoolean):
            # A ``None`` mapping means the wrapped element is used as-is.
            if (func := OPERATOR_MAP[expr.operator]) is not None:
                yield Symbol(func, arity=1)
            yield Symbol(expr.element)
        elif isinstance(expr, UnaryExpression):
            target = expr.element
            target_is_column = isinstance(target, Column)
            if force_bool and expr.operator == operator.inv and target_is_column:
                yield from self._serialize(target.is_(None))
            else:
                yield Symbol(expr.operator, arity=1)
                yield from self._serialize(target, force_bool=force_bool)
        # Multi-clause expressions
        elif isinstance(expr, BooleanClauseList):
            yield Symbol(expr.operator, arity=len(expr.clauses))
            for clause in expr.clauses:
                yield from self._serialize(clause, force_bool=force_bool)
        elif isinstance(expr, BinaryExpression):
            yield Symbol(expr.operator, arity=2)
            yield from self._serialize(expr.left)
            yield from self._serialize(expr.right)
        else:
            raise TypeError(
                f"Unsupported expression {expr} of type {type(expr).__name__}"
            )
class Stack:
    """Minimal last-in, first-out stack backed by a ``deque``."""

    def __init__(self):
        self._stack = deque()

    def push(self, frame: Any) -> None:
        """Place ``frame`` on top of the stack."""
        self._stack.append(frame)

    def pop(self) -> Any:
        """Remove and return the topmost item."""
        return self._stack.pop()

    def popn(self, argcount) -> Iterator[Any]:
        """Return a lazy iterator popping ``argcount`` items, topmost first.

        Nothing is removed from the stack until the iterator is consumed.
        """
        take = self._stack.pop
        return (take() for _ in range(argcount))
class Symbol:
    """One element of a serialized expression.

    Holds the raw value, the kind derived from it (column, operator or
    literal) and, for operators, the number of operands consumed.
    """

    __slots__ = "value", "type", "arity"

    def __init__(self, value: Any, arity: int = None):
        self.value = value
        self.type = self._determine_type(value)
        self.arity = arity

    def _determine_type(self, value: Any) -> SymbolType:
        """Classify ``value``: columns first, then callables, else literal."""
        if isinstance(value, Column):
            kind = SymbolType.column
        elif callable(value):
            kind = SymbolType.operator
        else:
            kind = SymbolType.literal
        return kind

    def __eq__(self, other: Any) -> bool:
        """Compare by (type, arity, value) with another ``Symbol``."""
        if isinstance(other, type(self)):
            return tuple(self) == tuple(other)
        return NotImplemented

    def __iter__(self):
        """Yield type, arity and value, in that order."""
        yield self.type
        yield self.arity
        yield self.value
class SymbolType(Enum):
    """Kinds of symbol found in a serialized expression."""

    # Values are consecutive integers starting at 1.
    column = 1
    literal = 2
    operator = 3
import pytest
from sqlalchemy import Table, MetaData, Column, Boolean, Integer, Text, null
from expression import Expression
class TestBooleanExpressions:
    """Evaluation of expressions built from boolean columns."""

    BOOL = Table(
        "bool_expr",
        MetaData(),
        Column("a", Boolean),
        Column("b", Boolean),
        Column("c", Boolean),
    )

    @pytest.fixture
    def cols(self):
        """Column collection of the boolean test table."""
        return self.BOOL.columns

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({BOOL.c.a: False}, False),
            ({BOOL.c.a: True}, True),
        ],
    )
    def test_direct_bool(self, cols, inputs, expected):
        """A bare boolean column evaluates to its own value."""
        outcome = Expression(cols.a).evaluate(inputs)
        assert outcome == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({BOOL.c.a: False}, True),
            ({BOOL.c.a: True}, False),
        ],
    )
    def test_negation(self, cols, inputs, expected):
        """An inverted boolean column evaluates to the opposite value."""
        outcome = Expression(~cols.a).evaluate(inputs)
        assert outcome == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({BOOL.c.a: False, BOOL.c.b: False}, False),
            ({BOOL.c.a: True, BOOL.c.b: False}, False),
            ({BOOL.c.a: False, BOOL.c.b: True}, False),
            ({BOOL.c.a: True, BOOL.c.b: True}, True),
        ],
    )
    def test_conjunction(self, cols, inputs, expected):
        """``a AND b`` follows the usual truth table."""
        outcome = Expression(cols.a & cols.b).evaluate(inputs)
        assert outcome == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({BOOL.c.a: False, BOOL.c.b: False}, False),
            ({BOOL.c.a: True, BOOL.c.b: False}, True),
            ({BOOL.c.a: False, BOOL.c.b: True}, True),
            ({BOOL.c.a: True, BOOL.c.b: True}, True),
        ],
    )
    def test_disjunction(self, cols, inputs, expected):
        """``a OR b`` follows the usual truth table."""
        outcome = Expression(cols.a | cols.b).evaluate(inputs)
        assert outcome == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({BOOL.c.a: False, BOOL.c.b: False, BOOL.c.c: False}, False),
            ({BOOL.c.a: True, BOOL.c.b: False, BOOL.c.c: False}, True),
            ({BOOL.c.a: False, BOOL.c.b: True, BOOL.c.c: False}, False),
            ({BOOL.c.a: True, BOOL.c.b: True, BOOL.c.c: False}, False),
            ({BOOL.c.a: False, BOOL.c.b: False, BOOL.c.c: True}, True),
            ({BOOL.c.a: True, BOOL.c.b: False, BOOL.c.c: True}, True),
            ({BOOL.c.a: False, BOOL.c.b: True, BOOL.c.c: True}, True),
            ({BOOL.c.a: True, BOOL.c.b: True, BOOL.c.c: True}, True),
        ],
    )
    def test_mixed_expression(self, cols, inputs, expected):
        """``(a AND NOT b) OR c`` over all eight input combinations."""
        combined = (cols.a & ~cols.b) | cols.c
        outcome = Expression(combined).evaluate(inputs)
        assert outcome == expected
class TestBooleanCoercion:
    """Serialization equivalences, with and without ``force_bool``."""

    COERCE = Table(
        "expr_coerce",
        MetaData(),
        Column("bool", Boolean),
        Column("number", Integer),
        Column("text", Text),
    )

    @pytest.fixture
    def cols(self):
        """Column collection of the coercion test table."""
        return self.COERCE.columns

    def test_sql_equivalences(self, cols):
        """``!= NULL`` and ``IS NOT NULL`` serialize identically."""
        via_operator = Expression(cols.text != null())
        via_method = Expression(cols.text.isnot(None))
        assert via_operator == via_method

    def test_coerce_simple_expression(self, cols):
        """A coerced bare column equals an explicit ``IS NOT NULL``."""
        coerced = Expression(cols.text, force_bool=True)
        explicit = Expression(cols.text.isnot(None))
        assert coerced == explicit

    def test_coerce_negated_expression(self, cols):
        """A coerced inverted column equals an explicit ``IS NULL``."""
        coerced = Expression(~cols.text, force_bool=True)
        explicit = Expression(cols.text.is_(None))
        assert coerced == explicit

    def test_do_not_coerce_nonbool(self, cols):
        """``force_bool`` leaves a negated ``IN`` expression untouched."""
        coerced = Expression(~cols.text.in_(["foo", "bar"]), force_bool=True)
        plain = Expression(~cols.text.in_(["foo", "bar"]))
        assert coerced == plain
class TestMathExpressions:
    """Evaluation of arithmetic expressions over integer columns."""

    MATH = Table(
        "expr_math",
        MetaData(),
        Column("a", Integer),
        Column("b", Integer),
        Column("c", Integer),
    )

    @pytest.fixture
    def cols(self):
        """Column collection of the arithmetic test table."""
        return self.MATH.columns

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({MATH.c.a: 0, MATH.c.b: 0}, 0),
            ({MATH.c.a: 2, MATH.c.b: 0}, 2),
            ({MATH.c.a: 0, MATH.c.b: 3}, 3),
            ({MATH.c.a: 5, MATH.c.b: 5}, 10),
            ({MATH.c.a: -2, MATH.c.b: -2}, -4),
        ],
    )
    def test_addition(self, cols, inputs, expected):
        """``a + b`` adds the substituted values."""
        outcome = Expression(cols.a + cols.b).evaluate(inputs)
        assert outcome == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({MATH.c.a: 0, MATH.c.b: 0}, 0),
            ({MATH.c.a: 2, MATH.c.b: 0}, 2),
            ({MATH.c.a: 0, MATH.c.b: 3}, -3),
            ({MATH.c.a: 10, MATH.c.b: 5}, 5),
        ],
    )
    def test_subtraction(self, cols, inputs, expected):
        """``a - b`` keeps operand order."""
        outcome = Expression(cols.a - cols.b).evaluate(inputs)
        assert outcome == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({MATH.c.a: 0, MATH.c.b: 0, MATH.c.c: 0}, 0),
            ({MATH.c.a: 2, MATH.c.b: 3, MATH.c.c: 0}, 6),
            ({MATH.c.a: 3, MATH.c.b: 4, MATH.c.c: 6}, 6),
            ({MATH.c.a: 3, MATH.c.b: 3, MATH.c.c: 10}, -1),
        ],
    )
    def test_mixed_match(self, cols, inputs, expected):
        """``a * b - c`` multiplies before subtracting."""
        product_minus_c = cols.a * cols.b - cols.c
        outcome = Expression(product_minus_c).evaluate(inputs)
        assert outcome == expected
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment