Skip to content

Instantly share code, notes, and snippets.

@ttumiel
Created November 23, 2019 07:56
Show Gist options
  • Save ttumiel/338ad8d10ac88bc6c4a462c2e9046ca0 to your computer and use it in GitHub Desktop.
Save ttumiel/338ad8d10ac88bc6c4a462c2e9046ca0 to your computer and use it in GitHub Desktop.
Symbolic Differentiation in Python
"Symbolic differentiator"
# Possible things to work on in interview:
# - Add substitution into expression for real numbers
# - If instance is the exact same symbol, use power instead of multiply. Use __eq__
# - Add a simplify method to simplify the mess that results from ugly expressions
# - Convert an expression from string
# - Add support for a^x
class Symbol():
    """
    Class that holds the variable which derivatives can
    be calculated with respect to.

    Arithmetic on a Symbol promotes it to an ``Expression``. ``backward``
    always returns ``"1"``, so every Symbol is treated as the variable of
    differentiation.
    """

    def __init__(self, symbol):
        assert isinstance(symbol, str)
        self.symbol = symbol

    def __mul__(self, other):
        return Expression(self) * other

    def __rmul__(self, other):
        return self * other

    def __add__(self, other):
        return Expression(self) + other

    def __radd__(self, other):
        # Called for ``number + symbol`` (and by ``sum()``, which starts at 0).
        # Keep the left operand first so the printed order matches the source.
        return Expression(other) + self

    def __sub__(self, other):
        return Expression(self) - other

    def __rsub__(self, other):
        # Called for ``number - symbol``.
        return Expression(other) - self

    def __pow__(self, exp):
        return Expression(self)**exp

    def __truediv__(self, other):
        return Expression(self) / other

    def __rtruediv__(self, other):
        return other/Expression(self)

    def __neg__(self):
        return -Expression(self)

    def __repr__(self):
        return self.symbol

    def __str__(self):
        return self.__repr__()

    def backward(self):
        """Return the derivative of this symbol with respect to itself: ``"1"``."""
        return "1"


class Expression():
    """
    An Expression holds a sequence of Symbols and mathematical operators on those
    symbols. Expressions can be added, multiplied, etc. like Symbols.

    ``expression`` holds the printed form and ``derivative`` is a zero-argument
    callable returning the printed derivative. ``backward`` calls it, so the
    derivative string is only built on demand. Compound operands are
    parenthesized so the printed string keeps its grouping: ``^`` binds
    tighter than ``*``, which binds tighter than ``+``.
    """

    def __init__(self, expression, derivative=None):
        """
        Wrap ``expression``.

        Args:
            expression: a Symbol, Expression, int or float; or any object whose
                ``str`` is the printed form when ``derivative`` is supplied.
            derivative: optional zero-argument callable returning the printed
                derivative. When omitted it is inferred from ``expression``.

        Raises:
            ValueError: if ``derivative`` is omitted and ``expression`` is not
                a Symbol, Expression, int or float.
        """
        self.expression = str(expression)
        if derivative is not None:
            self.derivative = derivative
            return
        if isinstance(expression, Symbol):
            self.derivative = expression.backward
        elif isinstance(expression, Expression):
            self.derivative = expression.derivative
        elif isinstance(expression, (float, int)):
            self.derivative = lambda: "0"
        else:
            # Report the offending argument (previously an undefined name -> NameError).
            raise ValueError(f"Can't create Expression of type {type(expression)}")

    @staticmethod
    def _has_top_level(text, operators):
        """Return True if a character of ``operators`` occurs in ``text`` outside parentheses."""
        depth = 0
        for char in text:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
            elif depth == 0 and char in operators:
                return True
        return False

    @classmethod
    def _as_factor(cls, text):
        """Parenthesize ``text`` if it is a sum, so it can be used as an operand of ``*``."""
        return f"({text})" if cls._has_top_level(text, "+") else text

    @classmethod
    def _as_base(cls, text):
        """Parenthesize ``text`` unless it is a single unsigned atom, so it can be raised to a power."""
        if text.startswith("-") or cls._has_top_level(text, "+*^"):
            return f"({text})"
        return text

    def __mul__(self, other):
        """Return ``self * other``, printed as ``other*self``; differentiated by the product rule."""
        if isinstance(other, Expression):
            def derivative():
                # Product rule: (f*g)' = f'*g + f*g'.
                return ("(" + self._as_factor(self.backward()) + "*" + self._as_factor(other.expression)
                        + "+" + self._as_factor(self.expression) + "*" + self._as_factor(other.backward()) + ")")
            return Expression(self._as_factor(other.expression) + "*" + self._as_factor(self.expression), derivative)
        other = self.check_type_and_create_expr(other)
        return self * other

    def __rmul__(self, other):
        return self * other

    def __add__(self, other):
        """Return ``self + other``; differentiated by the sum rule."""
        if isinstance(other, Expression):
            # '+' has the lowest precedence, so neither side needs parentheses.
            derivative = lambda: self.backward() + "+" + other.backward()
            return Expression(self.expression + "+" + other.expression, derivative)
        other = self.check_type_and_create_expr(other)
        return self + other

    def __radd__(self, other):
        # Called for ``number + expression`` (and by ``sum()``).
        return self.check_type_and_create_expr(other) + self

    def __sub__(self, other):
        return self + (-1*other)

    def __rsub__(self, other):
        # Called for ``number - expression``.
        return self.check_type_and_create_expr(other) - self

    def __pow__(self, exp):
        """Raise to the numeric power ``exp``; symbolic exponents are not supported."""
        assert isinstance(exp, (int, float))
        if exp == 1:
            return self
        base = self._as_base(self.expression)

        def derivative():
            # Power rule combined with the chain rule: (f^n)' = n * f^(n-1) * f'.
            return f"{exp}*{base}^{exp - 1}*{self._as_factor(self.backward())}"

        return Expression(f"{base}^{exp}", derivative)

    def __truediv__(self, other):
        # A numeric divisor folds to a float factor; Symbols/Expressions become ^-1.0 powers.
        return self * (other**-1.)

    def __neg__(self):
        return -1*self

    def __repr__(self):
        return self.expression

    def __str__(self):
        return self.__repr__()

    def __rtruediv__(self, other):
        return other * self**-1

    def check_type_and_create_expr(self, other):
        """Promote a Symbol/number operand to an Expression, or raise ValueError."""
        if isinstance(other, (Symbol, float, int)):
            return Expression(other)
        else:
            raise ValueError(f"Can't apply to type {type(other)}")

    def backward(self):
        """Return the printed derivative of this expression."""
        return self.derivative()


################# Tests ###################
def test_symbol():
    x = Symbol('x')
    assert str(x) == "x"
    assert str(3*x) == "3*x"
    assert str(x*x) == 'x*x'
    assert str(-x) == "-1*x"
    assert str(3/x) == "3*x^-1"
    assert str(x+x) == "x+x"
    assert str(x-x) == "x+-1*x"
    assert str(x**3) == "x^3"
    assert str(x/4) == "0.25*x"
    assert str(1+x) == "1+x"
    assert str(5-x) == "5+-1*x"
    assert x.backward() == "1"
    assert (3*x).backward() == "(1*3+x*0)" # 3
    assert (x*x).backward() == '(1*x+x*1)' # 2x
    assert (-x).backward() == "(1*-1+x*0)" # -1
    assert (3/x).backward() == "(-1*x^-2*1*3+x^-1*0)" # -3/x^2
    assert (x+x).backward() == "1+1" # 2
    assert (x-x).backward() == "1+(1*-1+x*0)" # 0
    assert (x**3).backward() == "3*x^2*1" # 3x^2
    assert (x/4).backward() == "(1*0.25+x*0)" # 0.25
    assert (1+x).backward() == "0+1" # 1
    assert (5-x).backward() == "0+(1*-1+x*0)" # -1


def test_expression():
    x = Symbol('x')
    expr = 2*x
    assert str(expr) == "2*x" # 2x
    expr2 = expr * x # 2x*x
    assert str(expr2) == "x*2*x" # 2x^2
    assert str(expr2.backward()) == "((1*2+x*0)*x+2*x*1)" # 4x
    expr = x**2 + 2*x - 45 / x
    assert str(expr) == "x^2+2*x+-1*45*x^-1" # x^2 + 2x - 45/x
    assert expr.backward() == "2*x^1*1+(1*2+x*0)+((-1*x^-2*1*45+x^-1*0)*-1+45*x^-1*0)" # 2x + 2 + 45/(x^2)
    # Compound operands are parenthesized so the printed grouping is preserved.
    expr = (x + 1) * x
    assert str(expr) == "x*(x+1)" # x^2 + x
    assert expr.backward() == "((1+0)*x+(x+1)*1)" # 2x + 1
    expr = (2*x)**3
    assert str(expr) == "(2*x)^3" # 8x^3
    assert expr.backward() == "3*(2*x)^2*(1*2+x*0)" # 24x^2
    assert str(1 + 2*x) == "1+2*x"
    assert str(sum([x, x])) == "0+x+x"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment