Skip to content

Instantly share code, notes, and snippets.

@amidvidy
Last active April 10, 2020 05:43
Show Gist options
  • Save amidvidy/16433b938a629687a478dc81618184a9 to your computer and use it in GitHub Desktop.
Save amidvidy/16433b938a629687a478dc81618184a9 to your computer and use it in GitHub Desktop.
polynomial evaluator
from typing import Mapping, List, Optional, Set
from collections import Counter
from functools import reduce
class Expr:
    """Base class for nodes of a symbolic expression tree.

    Subclasses implement evaluation, differentiation, printing, constant
    detection and structural equality, and provide a ``make`` factory.
    """

    def eval(self, env: Mapping[str, float]) -> float:
        """Evaluate this expression with respect to the variable substitutions in ``env``."""
        raise NotImplementedError()

    def deriv(self, by: str) -> 'Expr':
        """Return the derivative of this expression with respect to the variable ``by``."""
        raise NotImplementedError()

    def __str__(self) -> str:
        """Render the expression as text."""
        raise NotImplementedError()

    def is_const(self) -> bool:
        """Return True if the expression is constant. Used for constant folding."""
        raise NotImplementedError()

    def __eq__(self, other) -> bool:
        raise NotImplementedError()

    def __hash__(self):
        # Bit of a hack but it works well enough: hash the printed form.
        # Subclasses that override __eq__ should override this alongside it.
        return hash(self.__str__())

    @staticmethod
    def make(*args, **kwargs) -> 'Expr':
        """Static factory method.

        Used instead of the constructor because some subclasses apply
        simplifications and may return an instance of a different class than
        the one ``make`` was called on. Each subclass defines its own argument
        list, hence the generic signature here (the previous ``make(self)``
        wrongly suggested an instance argument on a static method).
        """
        raise NotImplementedError()
class Const(Expr):
    """A numeric literal; the value is always stored as a float."""

    def __init__(self, value):
        super().__init__()
        self._value = float(value)

    def eval(self, env: Mapping[str, float]) -> float:
        # A literal ignores the environment entirely.
        return self._value

    def deriv(self, by: str) -> Expr:
        # The derivative of a constant is zero regardless of the variable.
        return Const(0.0)

    def is_const(self) -> bool:
        return True

    def __str__(self):
        return str(self._value)

    def __eq__(self, other) -> bool:
        if not isinstance(other, Const):
            return False
        return other._value == self._value

    def __hash__(self):
        return hash(self._value)

    @staticmethod
    def make(value: float) -> Expr:
        """Factory counterpart of the constructor; always returns a Const."""
        return Const(value)
def constify(e: Expr) -> Expr:
    """Collapse a constant expression into a plain ``Const``.

    Non-constant expressions are returned unchanged (the same object).
    """
    return Const(e.eval({})) if e.is_const() else e
class Var(Expr):
    """A named variable whose value is looked up in the evaluation environment."""

    def __init__(self, name: str):
        self._name = name
        super().__init__()

    def eval(self, env: Mapping[str, float]) -> float:
        # An unbound variable is an error.
        if self._name not in env:
            raise Exception(f'Undefined variable: {self._name}')
        return env[self._name]

    def deriv(self, by: str) -> Expr:
        # d(x)/dx is 1; the derivative with respect to any other variable is 0.
        return Const(1.0) if by == self._name else Const(0.0)

    def is_const(self) -> bool:
        return False

    def __str__(self):
        return f'{self._name}'

    def __eq__(self, other) -> bool:
        if isinstance(other, Var):
            return other._name == self._name
        return False

    def __hash__(self):
        return hash(self._name)

    @staticmethod
    def make(name: str):
        """Factory counterpart of the constructor; always returns a Var."""
        return Var(name)
class Sum(Expr):
    """An n-ary sum of non-constant terms plus one folded constant term.

    Build instances with ``Sum.make``, which flattens nested sums, folds
    constants and combines like terms; the constructor only stores its inputs.
    """

    def __init__(self, exprs: List[Expr], const: Const):
        # Invariant: every entry of ``exprs`` is non-constant; all constant
        # contributions live in ``const``.
        for expr in exprs:
            assert not expr.is_const()
        assert const.is_const()
        self._exprs = exprs
        self._const = const
        super().__init__()

    def eval(self, env: Mapping[str, float]) -> float:
        """Sum every term evaluated under ``env``, plus the constant term."""
        return sum(expr.eval(env) for expr in self._exprs) + self._const.eval(env)

    def deriv(self, by: str) -> Expr:
        """Differentiate term by term; the constant term contributes 0."""
        return Sum.make([expr.deriv(by) for expr in self._exprs])

    def is_const(self) -> bool:
        return len(self._exprs) == 0

    def __eq__(self, other) -> bool:
        # Structural, order-sensitive comparison. The length check is needed
        # because zip() truncates: without it a sum compared equal to any sum
        # whose terms were a prefix of its own (and then their hashes differed).
        return (isinstance(other, Sum)
                and self._const == other._const
                and len(self._exprs) == len(other._exprs)
                and all(lhs == rhs for lhs, rhs in zip(self._exprs, other._exprs)))

    def __hash__(self):
        return hash(frozenset(self._exprs + [self._const]))

    def __repr__(self) -> str:
        return str(self)

    def __str__(self):
        exprs = self._exprs
        # A zero constant term is omitted from the printed form.
        if self._const.eval({}) != 0.0:
            exprs = exprs + [self._const]
        return '+'.join(str(e) for e in exprs)

    @staticmethod
    def make(exprs: List[Expr]) -> Expr:
        """Build a simplified sum of ``exprs``.

        Nested sums are flattened (associativity), constant terms are folded
        into one constant, and like terms ``a*e + b*e`` are combined into
        ``(a+b)*e`` (distributivity); terms whose coefficients cancel to zero
        are dropped. Returns a ``Const`` when no non-constant term survives,
        the single surviving term when the constant is zero, and a ``Sum``
        otherwise.
        """
        # (1) Partition into constant and non-constant terms. These must be
        # lists: ``consts`` is appended to in step (2), which a ``filter``
        # iterator does not support (that used to raise AttributeError as soon
        # as a nested Sum was passed in, e.g. from Product.deriv).
        consts = [e for e in exprs if e.is_const()]
        nonconst_exprs = [e for e in exprs if not e.is_const()]

        # (2) Merge nested sums into this one (associativity).
        terms = []
        for e in nonconst_exprs:
            if isinstance(e, Sum):
                terms.extend(e._exprs)
                consts.append(e._const)
            else:
                terms.append(e)

        # (3) Evaluate the constant terms and add them together.
        const = float(sum(c.eval({}) for c in consts))

        # (4) No non-constant terms: the whole sum is a constant.
        if not terms:
            return Const(const)

        # (5) Combine like terms, e.g. 4*x + 2*x -> 6*x (distributivity).
        coeffs = Counter()
        for term in terms:
            assert not term.is_const()
            if isinstance(term, Product):
                base = Product(term._exprs, Const.make(1.0))
                coeffs[base] += term._const.eval({})
            else:
                coeffs[term] += 1.0

        collapsed = []
        for base, coeff in coeffs.items():
            if coeff == 0.0:
                # The term cancelled out (e.g. x + -1*x). Product.make would
                # return Const(0.0) here, which Sum's constructor rejects.
                continue
            combined = Product.make([Const.make(coeff), base])
            if combined.is_const():
                # Defensive: a hand-built Product whose factors cancel can
                # still fold to a constant; keep it out of the term list.
                const += combined.eval({})
            else:
                collapsed.append(combined)

        if not collapsed:
            return Const(const)
        if len(collapsed) == 1 and const == 0.0:
            return collapsed[0]
        return Sum(collapsed, Const.make(const))
class Product(Expr):
    """An n-ary product of non-constant factors times one folded constant.

    Build instances with ``Product.make``, which flattens nested products,
    folds constants and merges repeated factors into powers.
    """

    def __init__(self, exprs: List[Expr], const: Const):
        # Invariant: every entry of ``exprs`` is non-constant; all constant
        # factors are folded into ``const``.
        for expr in exprs:
            assert not expr.is_const()
        assert const.is_const()
        self._exprs = exprs
        self._const = const
        super().__init__()

    def eval(self, env: Mapping[str, float]) -> float:
        """Multiply every factor evaluated under ``env``, times the constant."""
        vals = [expr.eval(env) for expr in self._exprs]
        return reduce(lambda x, y: x*y, vals) * self._const.eval({})

    def deriv(self, by: str) -> Expr:
        """Product rule: sum, over each factor, of that factor's derivative
        times all the other factors.

        The constant factor takes part too; its derivative is 0, so its term
        folds away inside ``Product.make``.
        """
        factors = self._exprs + [self._const]
        terms = [
            Product.make([f.deriv(by=by) if j == i else f
                          for j, f in enumerate(factors)])
            for i in range(len(factors))
        ]
        return Sum.make(terms)

    def __str__(self) -> str:
        # Parenthesize Sum factors so 2*(x+1) doesn't print as 2.0*x+1.0.
        parts = [f'({e})' if isinstance(e, Sum) else str(e) for e in self._exprs]
        # A unit coefficient is omitted from the printed form.
        if self._const.eval({}) != 1.0:
            parts = [str(self._const)] + parts
        return '*'.join(parts)

    def __hash__(self):
        return hash(frozenset(self._exprs + [self._const]))

    def __repr__(self) -> str:
        return str(self)

    def __eq__(self, other) -> bool:
        # Structural, order-sensitive comparison. The length check is needed
        # because zip() truncates: without it a product compared equal to any
        # product whose factors were a prefix of its own.
        return (isinstance(other, Product)
                and self._const == other._const
                and len(self._exprs) == len(other._exprs)
                and all(lhs == rhs for lhs, rhs in zip(self._exprs, other._exprs)))

    def is_const(self) -> bool:
        assert len(self._exprs) > 0
        return False

    @staticmethod
    def make(exprs: List[Expr]) -> Expr:
        """Build a simplified product of ``exprs``.

        Nested products are flattened (associativity), constant factors are
        multiplied into one coefficient, and repeated bases are merged into
        powers (x * x^2 -> x^3). Returns ``Const(0.0)`` if any constant factor
        is zero, a ``Const`` when no non-constant factor survives, the single
        surviving factor when the coefficient is 1, and a ``Product`` otherwise.
        """
        # This mirrors Sum.make closely but differs enough that it wasn't
        # worth deduplicating. If more operators are added, an NaryOperator
        # base class parameterized on associativity/distributivity might be
        # worthwhile.

        # (1) Partition into constant and non-constant factors.
        consts = [e for e in exprs if e.is_const()]
        nonconst_exprs = [e for e in exprs if not e.is_const()]

        # (2) Merge nested products into this one (associativity).
        factors = []
        for e in nonconst_exprs:
            if isinstance(e, Product):
                factors.extend(e._exprs)
                consts.append(e._const)
            else:
                factors.append(e)

        # (3) Multiply the constant factors together; the empty product is 1.
        const = float(reduce(lambda x, y: x*y, (c.eval({}) for c in consts), 1.0))

        # Short-circuit: a zero factor annihilates the whole product.
        if const == 0.0:
            return Const(0.0)
        # Short-circuit: only constant factors, so the product is a constant.
        if not factors:
            return Const(const)

        # (4) Merge repeated bases into powers, e.g. x * x^2 => x^3, x*x => x^2.
        powers = Counter()
        for factor in factors:
            assert not factor.is_const()
            if isinstance(factor, Exp):
                powers[factor._base] += factor._exp.eval({})
            else:
                powers[factor] += 1.0

        collapsed = []
        for base, power in powers.items():
            if power == 0.0:
                # Exponents cancelled (e.g. x * x^-1): base^0 == 1, so the
                # factor contributes nothing. Passing it through would put a
                # constant into Product's factor list and trip its assertion.
                continue
            collapsed.append(Exp.make(base, Const.make(power)))

        if not collapsed:
            return Const(const)
        if len(collapsed) == 1 and const == 1.0:
            return collapsed[0]
        return Product(collapsed, Const.make(const))
class Exp(Expr):
    """``base ^ exp`` where the exponent must be a constant expression."""

    def __init__(self, base: Expr, exp: Expr):
        # Fold constant sub-expressions so the exponent is a plain Const that
        # later code can evaluate with an empty environment.
        self._base = constify(base)
        self._exp = constify(exp)
        assert self._exp.is_const(
        ), 'only constant expressions are supported for the exponent currently.'
        super().__init__()

    def eval(self, env: Mapping[str, float]) -> float:
        ebase = self._base.eval(env)
        eexp = self._exp.eval(env)
        return ebase ** eexp

    def deriv(self, by: str) -> Expr:
        """Power rule combined with the chain rule: d(b^n) = n * b^(n-1) * db.

        The chain-rule factor ``db`` (the base's own derivative) is what makes
        d(x^3)/dy come out as 0 rather than 3*x^2, and d((2*x)^2)/dx include
        the inner factor of 2.
        """
        return Product.make([
            self._exp,
            Exp.make(self._base, Sum.make([self._exp, Const(-1.0)])),
            self._base.deriv(by),
        ])

    @staticmethod
    def make(base: Expr, exp: Expr) -> Expr:
        """Build ``base ^ exp``, simplifying exponents 0 and 1.

        Returns ``Const(1.0)`` for a zero exponent, ``base`` itself for an
        exponent of one, and otherwise an ``Exp`` (folded to a ``Const`` when
        the base is constant too).
        """
        if exp.is_const():
            eexp = exp.eval({})
            if eexp == 0.0:
                # b^0 == 1 (Python likewise evaluates 0.0 ** 0.0 as 1.0).
                return Const.make(1.0)
            elif eexp == 1.0:
                return base
        return constify(Exp(base, exp))

    def is_const(self) -> bool:
        return self._base.is_const() and self._exp.is_const()

    def __eq__(self, other) -> bool:
        # Must be overridden: the inherited Expr.__eq__ raises
        # NotImplementedError. Product.__eq__ compares factors element-wise,
        # so a Counter lookup in Sum.make that hit two equal powers
        # (e.g. 2*x^3 + 3*x^3) used to crash.
        return (isinstance(other, Exp)
                and self._base == other._base
                and self._exp == other._exp)

    def __hash__(self):
        # Consistent with __eq__: equal base and exponent give equal hashes.
        return hash((self._base, self._exp))

    def __repr__(self) -> str:
        return str(self)

    def __str__(self):
        base = str(self._base)
        # Parenthesize compound bases so (x+1)^2 doesn't print as x+1.0^2.0.
        if isinstance(self._base, (Sum, Product)):
            base = f'({base})'
        return f'{base}^{self._exp}'
if __name__ == '__main__':
    # Demo: build a few polynomials in x, print their simplified forms,
    # then print some derivatives.
    x = Var('x')

    def scaled(coeff, expr):
        # coeff * expr, simplified through Product.make.
        return Product.make([Const.make(coeff), expr])

    def x_pow(exponent):
        # x ^ exponent, simplified through Exp.make.
        return Exp.make(x, Const.make(exponent))

    term0 = scaled(2.0, x_pow(3.0))    # 2x^3
    term1 = scaled(-3.0, x_pow(2.0))   # -3x^2
    term2 = scaled(4.0, x)             # 4x
    term3 = Const.make(5.0)            # 5
    term4 = scaled(5.0, x)             # 5x
    term5 = Product.make([x, x])       # x*x, collapses to x^2

    poly = Sum.make([term0, term1, term2, term3])
    print(poly)
    poly2 = Sum.make([term2, term4])
    print(poly2)
    print(term5)
    print(term5.deriv(by='x'))
    print(poly.deriv(by='x'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment