Last active
June 10, 2022 19:28
-
-
Save DavidBuchanan314/1507cb241878578005ce8631dcd1d82b to your computer and use it in GitHub Desktop.
A very basic implementation of symbolic expressions. Operator overloading is fun!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numbers
import operator
from abc import ABC, abstractmethod
class Expression(ABC):
    """Abstract base for nodes of a symbolic expression tree.

    The overloaded operators below do not compute anything; they build new
    ``UnaryOp``/``BinaryOp`` nodes. Call ``evaluate`` to compute a value.
    Operands that are not Expressions (e.g. ints) are wrapped with ``cast``.
    """

    # Default (absent) type tag. UnaryOp, BinaryOp and Symbol assign their own
    # per-instance value; Literal keeps this None.
    type = None

    @staticmethod
    def cast(value):
        """Return *value* as an Expression.

        Expressions are returned unchanged; anything else is wrapped in a
        ``Literal``.
        """
        if isinstance(value, Expression):
            return value
        return Literal(value)

    # Forward operators: ``self <op> other``.

    def __mul__(self, other):
        return BinaryOp(operator.mul, self, other)

    def __add__(self, other):
        return BinaryOp(operator.add, self, other)

    def __sub__(self, other):
        return BinaryOp(operator.sub, self, other)

    def __neg__(self):
        return UnaryOp(operator.neg, self)

    def __rshift__(self, other):
        return BinaryOp(operator.rshift, self, other)

    def __lshift__(self, other):
        return BinaryOp(operator.lshift, self, other)

    def __and__(self, other):
        return BinaryOp(operator.and_, self, other)

    # Reflected operators: Python calls these for ``other <op> self`` when
    # ``other`` (e.g. an int) cannot handle an Expression operand. Without
    # them ``5 * x`` raised TypeError even though ``x * 5`` worked. The
    # operand order is swapped back so non-commutative ops (sub, shifts)
    # keep their meaning.

    def __rmul__(self, other):
        return BinaryOp(operator.mul, other, self)

    def __radd__(self, other):
        return BinaryOp(operator.add, other, self)

    def __rsub__(self, other):
        return BinaryOp(operator.sub, other, self)

    def __rrshift__(self, other):
        return BinaryOp(operator.rshift, other, self)

    def __rlshift__(self, other):
        return BinaryOp(operator.lshift, other, self)

    def __rand__(self, other):
        return BinaryOp(operator.and_, other, self)

    @abstractmethod
    def evaluate(self, symbols):
        """Compute the numeric value of this expression.

        `symbols` should be a dictionary mapping symbol names (the ``name``
        of each ``Symbol``) to values, which may or may not be expressions
        themselves. Should return a numeric value.
        """
class UnaryOp(Expression):
    """Expression node that applies a one-argument callable to a sub-expression."""

    def __init__(self, operator, operand):
        self.operator = operator
        # Plain values are wrapped so that every child can be evaluated the same way.
        child = Expression.cast(operand)
        self.operand = child
        # A unary node simply inherits its operand's type tag.
        self.type = child.type

    def evaluate(self, symbols):
        """Evaluate the operand against *symbols*, then apply the operator."""
        inner_value = self.operand.evaluate(symbols)
        return self.operator(inner_value)
class BinaryOp(Expression):
    """Expression node that combines two sub-expressions with a two-argument callable.

    The node's ``type`` is the left operand's type when that is not None,
    otherwise the right operand's type.
    """

    def __init__(self, operator, left, right):
        self.operator = operator
        # Wrap plain values (left side first, then right side).
        self.left, self.right = Expression.cast(left), Expression.cast(right)
        lhs_type = self.left.type
        self.type = lhs_type if lhs_type is not None else self.right.type

    def evaluate(self, symbols):
        """Evaluate left then right operand against *symbols* and combine them."""
        lhs = self.left.evaluate(symbols)
        rhs = self.right.evaluate(symbols)
        return self.operator(lhs, rhs)
class Literal(Expression):
    """A constant integer leaf of an expression tree."""

    def __init__(self, value):
        """Store *value* as an int.

        Raises:
            ValueError: if *value* is a number that is not integral (e.g. 2.5),
                or if ``int(value)`` itself raises ValueError.
            TypeError: propagated from ``int(value)`` for unsupported types.
        """
        # We only support int literals, for now. int() truncates non-integral
        # numbers (int(2.5) == 2), which would silently change the meaning of
        # an expression such as ``x * 2.5``; refuse such lossy conversions.
        # Non-numeric inputs (e.g. "12") keep int()'s parsing behaviour.
        as_int = int(value)
        if isinstance(value, numbers.Number) and as_int != value:
            raise ValueError(
                f"Literal only supports integer values, got {value!r}"
            )
        self.value = as_int

    def evaluate(self, symbols):
        """Return the stored integer; *symbols* is ignored."""
        _ = symbols  # unused
        return self.value
class Symbol(Expression):
    """A named variable whose value is looked up at evaluation time."""

    def __init__(self, name, type=None):
        self.name, self.type = name, type

    def evaluate(self, symbols):
        """Look up this symbol's name in *symbols* and evaluate the bound value.

        The bound value is passed through ``Expression.cast`` and evaluated with
        the same *symbols* mapping, so it may itself reference other symbols.
        A binding that refers back to its own symbol recurses without bound.
        A missing name raises KeyError.
        """
        bound_value = symbols[self.name]
        return Expression.cast(bound_value).evaluate(symbols)
class SymbolFactory:
    """Create Symbols through attribute access.

    ``SymbolFactory(type=t).x`` returns ``Symbol("x", type=t)``. Any
    non-dunder attribute name, including ``type``, produces a fresh Symbol.
    """

    def __init__(self, type=None):
        self.type = type

    def __getattribute__(self, name):
        # Dunder names use normal lookup so that Python protocols keep working
        # (vars(), copy.copy, pickle, obj.__class__, ...); previously these
        # returned Symbols and broke those protocols.
        if name.startswith("__") and name.endswith("__"):
            return object.__getattribute__(self, name)
        # We can't just do self.type because that would call __getattribute__
        # again and recurse infinitely, so bypass it explicitly.
        type_ = object.__getattribute__(self, "type")
        return Symbol(name, type=type_)
if __name__ == "__main__":
    # Demo: build x*5 + 9 - 2 symbolically and compare it with the same
    # arithmetic done directly on plain ints.
    factory = SymbolFactory(type="foobar")
    expr = factory.x * 5 + 9 - 2
    bindings = {"x": 7}
    print(bindings["x"] * 5 + 9 - 2)  # reference value
    print(expr.evaluate(bindings))    # same value via the expression tree
    print(expr.type)                  # type tag propagated from the symbol
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment