-
-
Save djhsu/214ae1048ca8719e35a0f577d95bfd85 to your computer and use it in GitHub Desktop.
A simple implementation of reverse-mode automatic differentiation (autodiff).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class Var:
    """A node in a computation graph for reverse-mode automatic differentiation.

    A leaf node holds a caller-supplied ``value`` and has ``children=None``.
    An interior node stores the operation ``op``, its operand nodes
    ``children`` and one partial-derivative function per operand in ``pds``;
    its ``value`` is (re)computed by :meth:`forward`.
    """

    def __init__(self, value=None, deriv=None, op=None, children=None, pds=None):
        # Numeric value; for interior nodes it is recomputed by forward().
        self.value = value
        # Derivative of the node backward() was called on w.r.t. this node.
        self.deriv = deriv
        # Callable mapping the children's values to this node's value.
        self.op = op
        # Operand nodes, or None for a leaf.
        self.children = children
        # pds[i](*child_values) is the partial derivative of op w.r.t. children[i].
        self.pds = pds
        # Cached topological order of the graph rooted here (filled by forward()).
        self.order = None

    def __repr__(self):
        return f'Var({self.value})'

    def forward(self):
        """Evaluate every node of the graph rooted at this node.

        Nodes are visited in topological order, so each interior node's
        ``op`` sees up-to-date child values.  Every node's ``deriv`` is reset
        to ``0.`` in preparation for :meth:`backward`.

        The topological order is computed on the first call and cached in
        ``self.order``; if the graph reachable from this node is changed
        afterwards, set ``order`` back to ``None`` to rebuild it.
        """
        if self.order is None:
            self.order = topological_sort(self)
        for u in self.order:
            u.deriv = 0.
            if u.children is not None:
                u.value = u.op(*[c.value for c in u.children])

    def backward(self):
        """Propagate derivatives of this node to every node in its graph.

        After the call, ``node.deriv`` holds d(self)/d(node) for every node
        reachable from ``self`` (and ``self.deriv`` is ``1.``).  If
        :meth:`forward` has never been run on this node it is run first, so
        node values are current and derivs start from zero.  Calling
        ``backward`` again without an intervening ``forward`` adds to the
        operand nodes' existing ``deriv`` values.
        """
        if self.order is None:
            # No cached order means forward() never ran: values may be stale
            # and derivs are not zeroed, so evaluate the graph first.
            self.forward()
        for v in reversed(self.order):
            # Identity check: the root seeds the chain rule with d(self)/d(self) = 1.
            if v is self:
                v.deriv = 1.
            if v.children is not None:
                local_values = [c.value for c in v.children]
                for c, pd in zip(v.children, v.pds):
                    # Chain rule; += because a node may feed several parents.
                    c.deriv += v.deriv * pd(*local_values)
def add(a, b):
    """Build a graph node for the elementwise sum ``a + b``.

    Both operands share the same local partial derivative, ``__add_pd``.
    """
    operands = [a, b]
    partials = [__add_pd] * 2
    return Var(op=np.add, children=operands, pds=partials)
def subtract(a, b):
    """Build a graph node for the elementwise difference ``a - b``.

    d(a - b)/da is the constant 1, the same rule as for addition, so
    ``__add_pd`` is reused; d(a - b)/db is given by ``__subtract_pd_b``.
    """
    return Var(op=np.subtract, children=[a, b], pds=[__add_pd, __subtract_pd_b])
def _node(op, children, pds):
    """Create an interior graph node that applies ``op`` to ``children``.

    ``pds[i]`` is called with all of the children's values and returns the
    partial derivative of ``op`` with respect to ``children[i]``.
    """
    return Var(op=op, children=children, pds=pds)


def multiply(a, b):
    """Build a graph node for the elementwise product ``a * b``."""
    return _node(np.multiply, [a, b], [__multiply_pd_a, __multiply_pd_b])


def divide(a, b):
    """Build a graph node for the elementwise quotient ``a / b``."""
    return _node(np.divide, [a, b], [__divide_pd_a, __divide_pd_b])


def negative(a):
    """Build a graph node for the elementwise negation ``-a``."""
    return _node(np.negative, [a], [__negative_pd])


def square(a):
    """Build a graph node for the elementwise square ``a ** 2``."""
    return _node(np.square, [a], [__square_pd])


def exp(a):
    """Build a graph node for the elementwise exponential of ``a``."""
    return _node(np.exp, [a], [__exp_pd])


def sin(a):
    """Build a graph node for the elementwise sine of ``a``."""
    return _node(np.sin, [a], [__sin_pd])


def cos(a):
    """Build a graph node for the elementwise cosine of ``a``."""
    return _node(np.cos, [a], [__cos_pd])
# Local partial-derivative rules.  Each receives every operand value of the
# node (positionally, in children order) and returns the derivative with
# respect to one operand.

def __add_pd(lhs, rhs):
    """Partial of ``lhs + rhs`` with respect to either operand: 1."""
    return 1.0


def __subtract_pd_b(lhs, rhs):
    """Partial of ``lhs - rhs`` with respect to ``rhs``: -1."""
    return -1.0


def __multiply_pd_a(lhs, rhs):
    """Partial of ``lhs * rhs`` with respect to ``lhs``."""
    return rhs


def __multiply_pd_b(lhs, rhs):
    """Partial of ``lhs * rhs`` with respect to ``rhs``."""
    return lhs


def __divide_pd_a(lhs, rhs):
    """Partial of ``lhs / rhs`` with respect to ``lhs``."""
    return 1.0 / rhs


def __divide_pd_b(lhs, rhs):
    """Partial of ``lhs / rhs`` with respect to ``rhs``."""
    return -lhs / (rhs * rhs)


def __negative_pd(x):
    """Derivative of ``-x``: the constant -1."""
    return -1.0


def __square_pd(x):
    """Derivative of ``x ** 2``."""
    return 2.0 * x


def __exp_pd(x):
    """Derivative of ``exp(x)``, which is ``exp(x)`` itself."""
    return np.exp(x)


def __sin_pd(x):
    """Derivative of ``sin(x)``."""
    return np.cos(x)


def __cos_pd(x):
    """Derivative of ``cos(x)``."""
    return -np.sin(x)
def topological_sort(v):
    """Return every node reachable from ``v`` in topological order.

    Nodes are listed in depth-first post-order over ``children`` (visited
    left to right), so each node appears after all of its children and ``v``
    itself is last.  A node shared by several parents is listed only once.

    The traversal uses an explicit stack instead of recursion, so long chains
    of operations do not hit Python's recursion limit.
    """
    def pending_children(node):
        # Leaves have children=None; treat that as "nothing to visit".
        return iter(() if node.children is None else node.children)

    # Track identities rather than the nodes themselves so the traversal does
    # not depend on how nodes implement __eq__/__hash__.  All nodes stay
    # alive through the graph, so their ids are unique for the duration.
    visited = {id(v)}
    vertices = []
    # Each entry: a node being explored plus an iterator over its
    # not-yet-examined children (resumed where it left off).
    stack = [(v, pending_children(v))]
    while stack:
        node, pending = stack[-1]
        for child in pending:
            if id(child) not in visited:
                visited.add(id(child))
                stack.append((child, pending_children(child)))
                break
        else:
            # All children done: emit the node after its descendants.
            stack.pop()
            vertices.append(node)
    return vertices
if __name__ == '__main__':
    # Demo: minimise v7(w) = exp(x*w + sin(x*w)) + sin(x*w)**2 * w over w by
    # plain gradient descent, using forward()/backward() for dv7/dw.
    x = Var(value=1)
    w = Var(value=4)
    v1 = multiply(x, w)
    v2 = sin(v1)
    v3 = add(v1, v2)
    v4 = square(v2)
    v5 = exp(v3)
    v6 = multiply(v4, w)
    v7 = add(v5, v6)
    learning_rate = 0.1
    for _ in range(30):
        v7.forward()
        v7.backward()
        print(f'w={w.value}, v7={v7.value}, (dv7)/(dw)={w.deriv}')
        w.value -= learning_rate * w.deriv
    # Re-evaluate at the final w so the last line reports v7 and its gradient
    # for the w it prints; otherwise they would be one update step stale.
    v7.forward()
    v7.backward()
    print(f'w={w.value}, v7={v7.value}, (dv7)/(dw)={w.deriv}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment