Created
November 23, 2019 07:56
-
-
Save ttumiel/338ad8d10ac88bc6c4a462c2e9046ca0 to your computer and use it in GitHub Desktop.
Symbolic Differentiation in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"Symbolic differentiator" | |
# Possible things to work on in an interview:
# - Add substitution of real numbers into an expression
# - If both operands are the exact same symbol, use a power instead of a product (compare with __eq__)
# - Add a simplify method to tidy up the messy output that complex expressions produce
# - Parse an expression from a string
# - Add support for a^x
class Symbol():
    """
    A named variable that derivatives can be calculated with respect to.

    Every arithmetic operator wraps the Symbol in an Expression and delegates
    to the corresponding Expression operator, so the result of any arithmetic
    on a Symbol is an Expression.
    """

    def __init__(self, symbol):
        """Store the variable name; `symbol` must be a str."""
        assert isinstance(symbol, str)
        self.symbol = symbol

    def __mul__(self, other):
        return Expression(self) * other

    def __rmul__(self, other):
        # Expression.__mul__ renders the right operand first, so `3 * x`
        # still comes out as "3*x".
        return self * other

    def __add__(self, other):
        return Expression(self) + other

    def __radd__(self, other):
        # Reflected addition, so `3 + x` works the same way `3 * x` does.
        # The left operand is converted first to keep the written order.
        return Expression(self).check_type_and_create_expr(other) + self

    def __sub__(self, other):
        return Expression(self) - other

    def __rsub__(self, other):
        # Reflected subtraction: `3 - x` is rendered as "3+-1*x".
        return Expression(self).check_type_and_create_expr(other) - self

    def __pow__(self, exp):
        return Expression(self)**exp

    def __truediv__(self, other):
        return Expression(self) / other

    def __rtruediv__(self, other):
        return other/Expression(self)

    def __neg__(self):
        return -Expression(self)

    def __repr__(self):
        return self.symbol

    def __str__(self):
        return self.__repr__()

    def backward(self):
        """Return the derivative of the symbol with respect to itself, as a string."""
        return "1"
class Expression():
    """
    An Expression holds a sequence of Symbols and mathematical operators on
    those symbols. Expressions can be added, multiplied, etc. like Symbols.

    Both the expression and its derivative are kept as plain strings:
    `expression` is the rendered formula and `derivative` is a zero-argument
    callable that returns the rendered derivative when `backward()` is called.

    NOTE(review): operands are concatenated without parentheses, so a sum
    multiplied by something (e.g. ``(x + 1) * x``) renders as ``x*x+1``.
    The exact output strings are relied on by the tests in this file, so the
    format is left unchanged here.
    """

    def __init__(self, expression, derivative=None):
        """
        Build an Expression from a number, a Symbol or another Expression.

        When `derivative` is given it is used as-is and `expression` may be
        any object (only its `str()` is stored). Otherwise the derivative is
        chosen from the type of `expression`.

        Raises:
            ValueError: if `derivative` is None and `expression` is not an
                int, float, Expression or Symbol.
        """
        self.expression = str(expression)
        if derivative is not None:
            self.derivative = derivative
            return

        if isinstance(expression, (float, int)):
            # Constants differentiate to zero.
            self.derivative = lambda: "0"
        elif isinstance(expression, Expression):
            self.derivative = expression.derivative
        elif isinstance(expression, Symbol):
            self.derivative = expression.backward
        else:
            raise ValueError(f"Can't create Expression of type {type(expression)}")

    def __mul__(self, other):
        if isinstance(other, Expression):
            # Product rule: (f*g)' = f'*g + f*g'
            derivative = lambda: "(" + self.backward() + "*" + other.expression + "+" + self.expression + "*" + other.backward() + ")"
            # The right operand is rendered first, which makes `x * 3` (and
            # therefore `3 * x` via __rmul__) come out as "3*x".
            return Expression(other.expression + "*" + self.expression, derivative)
        other = self.check_type_and_create_expr(other)
        return self * other

    def __rmul__(self, other):
        return self * other

    def __add__(self, other):
        if isinstance(other, Expression):
            # Sum rule: (f+g)' = f' + g'
            derivative = lambda: self.backward() + "+" + other.backward()
            return Expression(self.expression + "+" + other.expression, derivative)
        other = self.check_type_and_create_expr(other)
        return self + other

    def __radd__(self, other):
        # Reflected addition so `3 + expr` works; the left operand is
        # converted first to keep the written order ("3+x").
        return self.check_type_and_create_expr(other) + self

    def __sub__(self, other):
        # Subtraction is rendered as addition of -1 times the operand.
        return self + (-1*other)

    def __rsub__(self, other):
        # Reflected subtraction so `3 - expr` works ("3+-1*x").
        return self.check_type_and_create_expr(other) - self

    def __pow__(self, exp):
        """
        Raise the expression to a numeric power.

        NOTE(review): the derivative is rendered as ``n*(f*f')^(n-1)``, which
        equals the usual power rule ``n*f^(n-1)*f'`` only when ``f'`` is 1.
        The tests in this file pin this exact string, so it is kept as-is.
        """
        assert isinstance(exp, (int, float))
        if exp == 1:
            return self
        derivative = lambda: str(exp) + "*" + "(" + (self.expression + "*" + self.backward()) + f")^{exp-1}"
        return Expression(self.expression + f"^{exp}", derivative)

    def __truediv__(self, other):
        # Division is multiplication by the operand raised to -1; for a plain
        # number this evaluates numerically (e.g. 4 -> 0.25).
        return self * (other**-1.)

    def __rtruediv__(self, other):
        return other * self**-1

    def __neg__(self):
        return -1*self

    def __repr__(self):
        return self.expression

    def __str__(self):
        return self.__repr__()

    def check_type_and_create_expr(self, other):
        """
        Wrap a number or Symbol in an Expression.

        Raises:
            ValueError: if `other` is not an int, float or Symbol.
        """
        if isinstance(other, (float, int)):
            return Expression(other)
        if isinstance(other, Symbol):
            return Expression(other)
        raise ValueError(f"Can't apply to type {type(other)}")

    def backward(self):
        """Return the derivative of this expression as a string."""
        return self.derivative()
################# Tests ################### | |
def test_symbol():
    """Check rendering and derivatives of expressions built from a single Symbol."""
    x = Symbol('x')

    # (built expression, expected rendering)
    rendering_cases = [
        (x, "x"),
        (3*x, "3*x"),
        (x*x, 'x*x'),
        (-x, "-1*x"),
        (3/x, "3*x^-1"),
        (x+x, "x+x"),
        (x-x, "x+-1*x"),
        (x**3, "x^3"),
        (x/4, "0.25*x"),
    ]
    for built, expected in rendering_cases:
        assert str(built) == expected

    # (built expression, expected derivative string); the trailing comment
    # gives the simplified mathematical value.
    derivative_cases = [
        (x, "1"),
        (3*x, "(1*3+x*0)"),                 # 3
        (x*x, '(1*x+x*1)'),                 # 2x
        (-x, "(1*-1+x*0)"),                 # -1
        (3/x, "(-1*(x*1)^-2*3+x^-1*0)"),    # -3/x^2
        (x+x, "1+1"),                       # 2
        (x-x, "1+(1*-1+x*0)"),              # 0
        (x**3, "3*(x*1)^2"),                # 3x^2
        (x/4, "(1*0.25+x*0)"),              # 0.25
    ]
    for built, expected in derivative_cases:
        assert built.backward() == expected
def test_expression():
    """Check rendering and derivatives of compound expressions."""
    x = Symbol('x')

    doubled = 2*x
    assert str(doubled) == "2*x"  # 2x

    # Multiplying an existing Expression by a Symbol: 2x*x
    product = doubled * x
    rendered_product = str(product)
    assert rendered_product == "x*2*x"  # 2x^2
    product_derivative = str(product.backward())
    assert product_derivative == "((1*2+x*0)*x+2*x*1)"  # 4x

    # Mixed powers, products, subtraction and division in one expression.
    combined = x**2 + 2*x - 45 / x
    assert str(combined) == "x^2+2*x+-1*45*x^-1"  # x^2 + 2x - 45/x
    expected_derivative = "2*(x*1)^1+(1*2+x*0)+((-1*(x*1)^-2*45+x^-1*0)*-1+45*x^-1*0)"
    assert combined.backward() == expected_derivative  # 2x + 2 + 45/(x^2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment