@avinashselvam
Created May 9, 2020 12:32
example of forward and backward automatic differentiation on a computational graph
class Constant():
    def __init__(self, value):
        self.value = value
        self.gradient = None

    def evaluate(self):
        return self.value

    def derivative(self, wrt_variable):
        # derivative of a constant w.r.t. anything is 0
        return 0

    def backprop(self, prev_gradient):
        # constants accumulate no gradient
        pass


class Variable():
    count = 0

    def __init__(self, value):
        Variable.count += 1
        self.name = "var" + str(Variable.count)
        self.value = value
        self.gradient = 0

    def evaluate(self):
        return self.value

    def derivative(self, wrt_variable):
        # derivative w.r.t. itself is 1, otherwise 0
        return 1 if wrt_variable == self else 0

    def backprop(self, prev_gradient):
        # the variable may appear in many nodes of the graph,
        # so we add up all the contributions
        self.gradient += prev_gradient


class BinaryOperator():
    def __init__(self, a, b):
        self.a = a
        self.b = b
        self.cache = None


class Add(BinaryOperator):
    def evaluate(self):
        # compare against None: a cached result of 0 is falsy but still valid
        if self.cache is None:
            self.cache = self.a.evaluate() + self.b.evaluate()
        return self.cache

    def derivative(self, wrt_variable):
        # (f + g)' = f' + g'
        return self.a.derivative(wrt_variable) + self.b.derivative(wrt_variable)

    def backprop(self, prev_gradient):
        # addition passes the upstream gradient through unchanged
        self.a.backprop(prev_gradient)
        self.b.backprop(prev_gradient)


class Multiply(BinaryOperator):
    def evaluate(self):
        if self.cache is None:
            self.cache = self.a.evaluate() * self.b.evaluate()
        return self.cache

    def derivative(self, wrt_variable):
        # (uv)' = u'v + uv'
        return (self.a.derivative(wrt_variable) * self.b.evaluate()
                + self.a.evaluate() * self.b.derivative(wrt_variable))

    def backprop(self, prev_gradient):
        # chain rule: d(uv)/du = v and d(uv)/dv = u
        self.a.backprop(self.b.evaluate() * prev_gradient)
        self.b.backprop(self.a.evaluate() * prev_gradient)
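
# Any other binary op follows the same pattern: implement evaluate(),
# derivative(), and backprop() on a BinaryOperator subclass. As an
# illustrative sketch (not part of the original gist), a hypothetical
# Subtract could look like this:
class Subtract(BinaryOperator):
    def evaluate(self):
        if self.cache is None:
            self.cache = self.a.evaluate() - self.b.evaluate()
        return self.cache

    def derivative(self, wrt_variable):
        # (f - g)' = f' - g'
        return self.a.derivative(wrt_variable) - self.b.derivative(wrt_variable)

    def backprop(self, prev_gradient):
        # d(a - b)/da = 1 and d(a - b)/db = -1, so the second
        # operand receives the negated upstream gradient
        self.a.backprop(prev_gradient)
        self.b.backprop(-prev_gradient)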
"""
z(x) = v(u(x))
forward diff --> z'(x) = u'(x)*v'(u(x))
backward diff --> z'(x) = v'(u(x))*u'(x)
Let z = x**2 + xy + 2
"""
x = Variable(3.0)
y = Variable(5.0)

# build the graph for z = x**2 + x*y + 2
graph = Add(Add(Multiply(x, x), Multiply(x, y)), Constant(2))

# reverse mode: seed the output with dz/dz = 1 and propagate backwards
graph.backprop(1.0)

print(f"for x = 3, y = 5 we get z = {graph.evaluate()}")
print(f"by forward diff dzdx = {graph.derivative(x)}, dzdy = {graph.derivative(y)}")
print(f"by backward diff dzdx = {x.gradient}, dzdy = {y.gradient}")