Last active
April 10, 2020 05:43
-
-
Save amidvidy/16433b938a629687a478dc81618184a9 to your computer and use it in GitHub Desktop.
polynomial evaluator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Mapping, List, Optional, Set | |
from collections import Counter | |
from functools import reduce | |
class Expr:
    """Abstract base for every node in the expression tree.

    Concrete subclasses override the methods below; the base versions only
    raise ``NotImplementedError``.
    """

    def eval(self, env: Mapping[str, float]) -> float:
        """Evaluate the expression, substituting variable values from *env*."""
        raise NotImplementedError

    def deriv(self, by: str) -> 'Expr':
        """Return the derivative of this expression with respect to variable *by*."""
        raise NotImplementedError

    def __str__(self) -> str:
        """Render the expression as text."""
        raise NotImplementedError

    def is_const(self) -> bool:
        """Report whether the expression is constant; used for constant folding."""
        raise NotImplementedError

    def __eq__(self, other) -> bool:
        raise NotImplementedError

    def __hash__(self):
        # Hash the rendered text: crude, but expressions that print the same
        # way hash the same way.
        return hash(str(self))

    @staticmethod
    def make(self) -> 'Expr':
        """Static factory hook.

        Subclasses supply their own ``make`` because simplification may
        return an instance of a different class than the one ``make`` is
        called on.
        """
        raise NotImplementedError
class Const(Expr):
    """A numeric literal; the value is always stored as a float."""

    def __init__(self, value):
        self._value = float(value)
        super().__init__()

    def eval(self, env: Mapping[str, float]) -> float:
        # A literal ignores the environment entirely.
        return self._value

    def deriv(self, by: str) -> Expr:
        # The derivative of a constant is zero whatever the variable.
        return Const(0.0)

    def is_const(self) -> bool:
        return True

    def __str__(self):
        return str(self._value)

    def __eq__(self, other) -> bool:
        if not isinstance(other, Const):
            return False
        return self._value == other._value

    def __hash__(self):
        return hash(self._value)

    @staticmethod
    def make(value: float) -> Expr:
        """Factory mirroring the other node types; simply wraps *value*."""
        return Const(value)
def constify(e: Expr) -> Expr:
    """Fold *e* into a ``Const`` when it is constant; otherwise return *e* unchanged."""
    return Const(e.eval({})) if e.is_const() else e
class Var(Expr):
    """A named variable whose value is looked up in the evaluation environment."""

    def __init__(self, name: str):
        self._name = name
        super().__init__()

    def eval(self, env: Mapping[str, float]) -> float:
        # Explicit membership test (rather than catching KeyError) so mappings
        # with side-effecting __getitem__, e.g. defaultdict, are not mutated.
        if self._name not in env:
            raise Exception(f'Undefined variable: {self._name}')
        return env[self._name]

    def deriv(self, by: str) -> Expr:
        # d(name)/d(by) is 1 for the same variable, 0 for any other.
        return Const(1.0 if by == self._name else 0.0)

    def is_const(self) -> bool:
        return False

    def __str__(self):
        return '{}'.format(self._name)

    def __eq__(self, other) -> bool:
        return isinstance(other, Var) and self._name == other._name

    def __hash__(self):
        return hash(self._name)

    @staticmethod
    def make(name: str):
        """Factory mirroring the other node types; simply wraps *name*."""
        return Var(name)
class Sum(Expr):
    """An n-ary sum of non-constant terms plus a single constant offset.

    Build instances through ``Sum.make``, which flattens nested sums, folds
    constants and combines like terms; the constructor only validates.
    """

    def __init__(self, exprs: List[Expr], const: Const):
        # Invariant: every term is non-constant; all constants live in `const`.
        for expr in exprs:
            assert not expr.is_const()
        assert const.is_const()
        self._exprs = exprs
        self._const = const
        super().__init__()

    def eval(self, env: Mapping[str, float]) -> float:
        return sum(expr.eval(env) for expr in self._exprs) + self._const.eval(env)

    def deriv(self, by: str) -> Expr:
        # Differentiation is linear; the constant offset differentiates to 0.
        return Sum.make([expr.deriv(by) for expr in self._exprs])

    def is_const(self) -> bool:
        return len(self._exprs) == 0

    def __eq__(self, other) -> bool:
        # Addition is commutative, so compare terms as multisets. This also
        # rejects sums with a different number of terms, which the previous
        # zip()-based comparison silently treated as equal, and it agrees
        # with the order-insensitive __hash__ below.
        return (isinstance(other, Sum)
                and self._const == other._const
                and Counter(self._exprs) == Counter(other._exprs))

    def __hash__(self):
        return hash(frozenset(self._exprs + [self._const]))

    def __str__(self):
        exprs = self._exprs
        if self._const.eval({}) != 0.0:
            exprs = exprs + [self._const]
        return '+'.join(str(e) for e in exprs)

    @staticmethod
    def make(exprs: List[Expr]) -> Expr:
        """Build a simplified sum of *exprs*.

        Flattens nested sums (associativity), folds all constants into one
        offset, and merges like terms by adding their coefficients
        (distributivity, e.g. ``4*x + 2*x -> 6*x``).

        Returns a ``Const`` when no non-constant term survives, the lone
        term when there is exactly one and the offset is 0, else a ``Sum``.
        """
        exprs = list(exprs)
        # (1) partition by constant and non-constant. These are real lists
        # (not filter() iterators) because `consts` is appended to below.
        consts = [e for e in exprs if e.is_const()]
        nonconst_exprs = [e for e in exprs if not e.is_const()]
        # (2) splice nested sums into this one (associativity).
        terms = []
        for e in nonconst_exprs:
            if isinstance(e, Sum):
                terms.extend(e._exprs)
                consts.append(e._const)
            else:
                terms.append(e)
        # (3) fold all constants into a single offset.
        const = float(sum(c.eval({}) for c in consts))
        # (4) no non-constant exprs: the whole sum is a constant.
        if not terms:
            return Const(const)
        # (5) group like terms by their non-constant part and add the
        # coefficients. A product with a single factor is keyed by that
        # factor itself, so e.g. 4*x and a bare x land in the same group.
        coeffs = Counter()
        for term in terms:
            assert not term.is_const()
            if isinstance(term, Product):
                factors = term._exprs
                if len(factors) == 1:
                    base = factors[0]
                else:
                    base = Product(factors, Const.make(1.0))
                coeffs[base] += term._const.eval({})
            else:
                coeffs[term] += 1.0
        # (6) rebuild each group. Groups that cancel (e.g. x + -1*x) become
        # constants; fold them into the offset instead of handing a constant
        # term to Sum.__init__, which would fail its assertion.
        collapsed = []
        for base, coeff in coeffs.items():
            term = Product.make([Const.make(coeff), base])
            if term.is_const():
                const += term.eval({})
            else:
                collapsed.append(term)
        if not collapsed:
            return Const(const)
        if len(collapsed) == 1 and const == 0.0:
            return collapsed[0]
        return Sum(collapsed, Const.make(const))
class Product(Expr):
    """An n-ary product of non-constant factors times a constant coefficient.

    Build instances through ``Product.make``, which flattens nested
    products, folds constants and merges repeated factors into powers; the
    constructor only validates.
    """

    def __init__(self, exprs: List[Expr], const: Const):
        # Invariant: every factor is non-constant; all constants live in `const`.
        for expr in exprs:
            assert not expr.is_const()
        assert const.is_const()
        self._exprs = exprs
        self._const = const
        super().__init__()

    def eval(self, env: Mapping[str, float]) -> float:
        # The 1.0 initializer keeps reduce() defined even with no factors.
        vals = [expr.eval(env) for expr in self._exprs]
        return reduce(lambda x, y: x * y, vals, 1.0) * self._const.eval({})

    def deriv(self, by: str) -> Expr:
        # Product rule: sum over i of factor_i' times all the other factors.
        # The coefficient is treated as a factor too; its derivative is 0, so
        # that summand folds away inside Product.make.
        factors = self._exprs + [self._const]
        out = []
        for i in range(len(factors)):
            summand = [f.deriv(by=by) if j == i else f
                       for j, f in enumerate(factors)]
            out.append(Product.make(summand))
        return Sum.make(out)

    def __str__(self) -> str:
        exprs = self._exprs
        if self._const.eval({}) != 1.0:
            exprs = [self._const] + exprs
        return '*'.join(str(e) for e in exprs)

    def __hash__(self):
        return hash(frozenset(self._exprs + [self._const]))

    def __repr__(self) -> str:
        return str(self)

    def __eq__(self, other) -> bool:
        # Multiplication is commutative, so compare factors as multisets.
        # This also rejects products with a different number of factors,
        # which the previous zip()-based comparison treated as equal, and it
        # agrees with the order-insensitive __hash__ above.
        return (isinstance(other, Product)
                and self._const == other._const
                and Counter(self._exprs) == Counter(other._exprs))

    def is_const(self) -> bool:
        assert len(self._exprs) > 0
        return False

    @staticmethod
    def make(exprs: List[Expr]) -> Expr:
        """Build a simplified product of *exprs*.

        Flattens nested products (associativity), multiplies all constants
        into one coefficient, and merges repeated bases into powers
        (``x * x^2 -> x^3``). Returns ``Const(0.0)`` if the coefficient is 0,
        a ``Const`` when no non-constant factor survives, the lone factor
        when there is exactly one and the coefficient is 1, else a
        ``Product``.

        This mirrors Sum.make; if more operators are added, a shared n-ary
        operator base class parameterized on associativity/distributivity
        could remove the duplication.
        """
        exprs = list(exprs)
        # (1) partition by constant and non-constant.
        consts = [e for e in exprs if e.is_const()]
        nonconst_exprs = [e for e in exprs if not e.is_const()]
        # (2) splice nested products into this one (associativity).
        factors = []
        for e in nonconst_exprs:
            if isinstance(e, Product):
                factors.extend(e._exprs)
                consts.append(e._const)
            else:
                factors.append(e)
        # (3) multiply all constants into a single coefficient.
        const = 1.0
        if consts:
            const = float(reduce(lambda x, y: x * y,
                                 (c.eval({}) for c in consts)))
        # Short circuit: a zero coefficient annihilates the whole product.
        if const == 0.0:
            return Const(0.0)
        if not factors:
            return Const(const)
        # (4) accumulate exponents per base, e.g. x * x^2 => x^3.
        powers = Counter()
        for factor in factors:
            assert not factor.is_const()
            if isinstance(factor, Exp):
                powers[factor._base] += factor._exp.eval({})
            else:
                powers[factor] += 1.0
        # (5) rebuild the factors. A base whose exponents cancel (x * x^-1)
        # contributes b^0 == 1 and is dropped; any factor that still turns
        # out constant is folded into the coefficient rather than handed to
        # Product.__init__, which would fail its assertion.
        collapsed = []
        for base, power in powers.items():
            if power == 0.0:
                continue
            factor = Exp.make(base, Const.make(power))
            if factor.is_const():
                const *= factor.eval({})
            else:
                collapsed.append(factor)
        if const == 0.0:
            return Const(0.0)
        if not collapsed:
            return Const(const)
        if len(collapsed) == 1 and const == 1.0:
            return collapsed[0]
        return Product(collapsed, Const.make(const))
class Exp(Expr):
    """``base ^ exp``; the exponent must currently be a constant expression."""

    def __init__(self, base: Expr, exp: Expr):
        self._base = constify(base)
        self._exp = constify(exp)
        assert self._exp.is_const(
        ), 'only constant expressions are supported for the exponent currently.'
        super().__init__()

    def eval(self, env: Mapping[str, float]) -> float:
        ebase = self._base.eval(env)
        eexp = self._exp.eval(env)
        return ebase ** eexp

    def deriv(self, by: str) -> Expr:
        # Power rule combined with the chain rule: d(b^n) = n * b^(n-1) * db.
        # Multiplying by the base's derivative makes the result 0 when the
        # base does not depend on `by` (previously `by` was ignored), and
        # correct for compound bases such as (x+1)^2.
        return Product.make([
            self._exp,
            Exp.make(self._base, Sum.make([self._exp, Const(-1.0)])),
            self._base.deriv(by),
        ])

    @staticmethod
    def make(base: Expr, exp: Expr) -> Expr:
        """Build ``base ^ exp``, simplifying trivial exponents.

        ``b^0`` becomes ``Const(1.0)`` and ``b^1`` becomes ``b``; a result
        that is fully constant is folded to a ``Const``.
        """
        if exp.is_const():
            eexp = exp.eval({})
            if eexp == 0.0:
                # Anything to the zeroth power is 1 (this previously returned 0).
                return Const.make(1.0)
            elif eexp == 1.0:
                return base
        return constify(Exp(base, exp))

    def is_const(self) -> bool:
        return self._base.is_const() and self._exp.is_const()

    def __eq__(self, other) -> bool:
        # Defined explicitly: the inherited Expr.__eq__ raises
        # NotImplementedError, which broke dict/Counter lookups in
        # Sum.make / Product.make whenever two equal-hashing Exps met.
        return (isinstance(other, Exp)
                and self._base == other._base
                and self._exp == other._exp)

    def __hash__(self):
        return hash((self._base, self._exp))

    def __str__(self):
        base = str(self._base)
        # Parenthesize compound bases so (x+1)^2 does not print as x+1^2.
        if isinstance(self._base, (Sum, Product, Exp)):
            base = f'({base})'
        return f'{base}^{self._exp}'
if __name__ == '__main__':
    # Small demo: build a few polynomials in x, print them and their derivatives.
    x = Var('x')

    def scaled(coeff, expr):
        """Return the simplified product coeff * expr."""
        return Product.make([Const.make(coeff), expr])

    cubic_term = scaled(2.0, Exp.make(x, Const.make(3.0)))    # 2x^3
    square_term = scaled(-3.0, Exp.make(x, Const.make(2.0)))  # -3x^2
    four_x = scaled(4.0, x)                                   # 4x
    constant = Const.make(5.0)                                # 5
    five_x = scaled(5.0, x)                                   # 5x
    x_times_x = Product.make([x, x])                          # x*x -> x^2

    poly = Sum.make([cubic_term, square_term, four_x, constant])
    print(poly)
    like_terms = Sum.make([four_x, five_x])
    print(like_terms)
    print(x_times_x)
    print(x_times_x.deriv(by='x'))
    print(poly.deriv(by='x'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment