A simple implementation of reverse accumulation (back-propagation)
"""
Implementation of reverse accumulation (backpropagation), discussed in
https://youtu.be/BcCk8I6YAqw
Most of this script is exposition code. The reverse accumulation code still runs if you delete everything except the following 5 classes:
* ValAndDVal, ExprBaseRevAcc, VarExprRevAcc, PlusExprRevAcc, MultExprRevAcc
validated using the following 2 functions:
* p, run_VarExprRevAcc_gradient_descent_demo
For motivation, see Hotz's "tinygrad" autograd/tensor library:
* https://github.com/tinygrad/tinygrad#quick-example-comparing-to-pytorch
- Itself motivated by micrograd by Karpathy, linked therein
- For pytorch, see also https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html
- https://www.google.com/search?q=autograd
See also
* https://en.wikipedia.org/wiki/Backpropagation
* https://en.wikipedia.org/wiki/Automatic_differentiation
- That article ends with a concise implementation using dual numbers (not covered in this script)
This script:
* Implementation of
- ExprBase, for forward accumulation when computing polynomial derivatives
- ExprBaseRevAcc, for backward accumulation when computing polynomial derivatives (see Wikipedia for a similar example)
* Use of ExprBaseRevAcc for a gradient descent example
* Starts out with a motivational LayerStack class, illustrating the forward and backward motion of data.
- A linear computational model, as opposed to the tree structure found in the polynomial expressions
"""
def p(d, x, y):
"""
This is the polynomial we're going to compute derivatives of, at various points.
See
https://www.wolframalpha.com/input?i=z+%3D+x+*+%28%28x+%2B+y%29+%2B+2%29+%2B+y+*+y
Mathematica code:
d = 2;
p[x_, y_] := x (x + y + d) + y^2
pos = {2, 3}
p[2, 3] == 2 (2 + 3 + 2) + 3 3 == 23
(D[p[x, y], x] /.{x->2, y->3}) == 9
(D[p[x, y], y] /.{x->2, y->3}) == 8
D[p[x[t], y[t]], t] == (2 + 2 * x[t] + y[t]) x'[t] + (x[t] + 2 y[t]) y'[t] //Simplify
"""
return x * (x + y + d) + y * y
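# Quick sanity check with plain numbers (added illustration, mirroring the Mathematica values in the docstring above):
# p(2, 2, 3) == 2 * (2 + 3 + 2) + 3 * 3 == 23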
class ExprTorch:
"""
Tensor wrapper enabling the notations 'f * g' and 'f + g' for torch tensors, so that p() can be reused unchanged.
(Technical note: each operation returns a new ExprTorch wrapping the resulting tensor.)
"""
def __init__(self, tensor):
self.tensor = tensor
def __add__(self, other):
return ExprTorch(self.tensor + other.tensor)
def __mul__(self, other):
return ExprTorch(self.tensor.matmul(other.tensor))
def torch_example():
def torch_scalar(scalar):
matrix_1x1 = [float(scalar)]
import torch # Note: If you don't have pytorch installed, just don't run torch_example
return torch.tensor([matrix_1x1], requires_grad=True)
def torch_p(d, x, y):
return p(ExprTorch(d), ExprTorch(x), ExprTorch(y)).tensor
x = torch_scalar(2)
y = torch_scalar(3)
d = torch_scalar(2)
z = torch_p(d, x, y)
z.backward()
print(f"[torch example] z = {float(z)}")
print(f"[torch example] ∂z/∂x = {float(x.grad)} at pos=(2,3)")
print(f"[torch example] ∂z/∂y = {float(y.grad)} at pos=(2,3)")
print(f"[torch example] ∂z/∂d = {float(d.grad)}")
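# Expected output of torch_example (assuming torch is installed; the values follow from the docstring of p):
# z = 23.0, ∂z/∂x = 9.0, ∂z/∂y = 8.0, ∂z/∂d = 2.0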
class Layer:
def __init__(self, layer_name, func, prev_layer):
self.layer_name = layer_name
self.data = None
self.func = func
self.prev_layer = prev_layer
print(f"[{self.layer_name}] Initialized.")
def set_data(self, data):
self.data = data
print(f"[{self.layer_name}] Set data {data}.")
def eval_f(self, x):
self.data = self.func(x)
print(f"[{self.layer_name}] Got input x={x},\tSetting data and returning layer.data = layer.func(x) = {self.data}")
return self.data
def push_down_signal(self, signal):
new_signal = signal + "|" + str(self.data)
# Note: If .forward() is not called before .push_down_signal (i.e. when .data are not set), then the logs will show a lot of 'None'
if self.prev_layer is None:
print(f"[{self.layer_name}] Got signal '{signal}'.\tBut prev_layer is None and so ending on signal '{new_signal}'.")
else:
print(f"[{self.layer_name}] Got signal '{signal}'.\tNow pushing down '{new_signal}' to previous layer '{self.prev_layer.layer_name}'.")
self.prev_layer.push_down_signal(new_signal)
print(f"[{self.layer_name}] Done pushing down.")
class LayerStack:
def __init__(self):
print(f"[LayerStack START] Initializing stack.")
self.layer_0 = Layer("layer 0", None, None) # .func of layer 0 will not be used
self.layer_1 = Layer("layer 1", lambda x: x + 4 * 10 ** 1, self.layer_0)
self.layer_2 = Layer("layer 2", lambda x: x + 5 * 10 ** 2, self.layer_1)
self.layer_3 = Layer("layer 3", lambda x: x + 6 * 10 ** 3, self.layer_2)
self.layer_4 = Layer("layer 4", lambda x: x + 7 * 10 ** 4, self.layer_3)
print(f"[LayerStack END] Initialized stack.\n")
def set_layer_0_data(self, data):
print(f"[LayerStack START] Setting stack input_value (layer_0_data) to {data}.")
self.layer_0.set_data(data)
print(f"[LayerStack END] Set stack input_value (layer_0_data) to {data}.\n")
def forward(self):
r0 = self.layer_0.data
print(f"[LayerStack START] Forwarding {r0} through all layers.")
r1 = self.layer_1.eval_f(r0)
r2 = self.layer_2.eval_f(r1)
r3 = self.layer_3.eval_f(r2)
r4 = self.layer_4.eval_f(r3)
print(f"[LayerStack END] Forwarded {r0} through all layers and arrived with {r4}.\n")
return r4
def push_down_signal(self, signal):
print(f"[LayerStack START] Pushing down signal={signal} using last layer {self.layer_4.layer_name}")
self.layer_4.push_down_signal(signal)
print(f"[LayerStack END] Pushed down signal={signal} using last layer {self.layer_4.layer_name}")
def run_stack_demo():
INPUT = 3
ls = LayerStack()
ls.set_layer_0_data(INPUT)
_r = ls.forward()
ls.push_down_signal("foo")
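# Added note: with INPUT = 3 the forward pass computes 43, 543, 6543, 76543 layer by layer, and the
# "foo" signal then walks back down, ending as 'foo|76543|6543|543|43|3'. This is a toy picture of how
# back-propagation revisits cached forward values in reverse order.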
class ValAndDVal:
def __init__(self, val):
self.__val = val
self.__dval = None # Set via set_dval() before any get_dval() call
def get_val(self):
return self.__val
def get_dval(self):
return self.__dval
def set_dval(self, dval): # Only in ExprBaseRevAcc is this not just called right after __init__
self.__dval = dval
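# Added note: a ValAndDVal pairs a value with a derivative value, much like the two components of a
# dual number a + b*eps mentioned in the Wikipedia article linked above; here, though, dval is set
# explicitly rather than propagated via eps-arithmetic.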
class ExprBase:
"""
Enables the notations 'f * g' and 'f + g'.
(Technical note: recursively defined, in the sense that it uses PlusExpr(ExprBase).)
"""
def __add__(self, other):
return PlusExpr(self, other)
def __mul__(self, other):
return MultExpr(self, other)
def make_val_and_dval(self, var):
# Note: Both (!!) val and dval will be (recursively) forwarded in this step
raise NotImplementedError("Combined eval and getter function must be implemented by subclasses")
def get_val(self): # Auxiliary getter for printing
VAR = None # Not interested in derivative expression here, so passing auxiliary value
return self.make_val_and_dval(VAR).get_val()
class VarExpr(ExprBase):
def __init__(self, val):
self.__val = val
def make_val_and_dval(self, var): # Evaluate val and first-order derivative (Note: there are no higher ones for primitive variables)
# Note: var == None (or passing any other var that's not a VarExpr) is allowed to access val (when not caring about dval)
res = ValAndDVal(self.__val)
res.set_dval(int(self == var)) # kronecker_delta(self, var)
return res
class PlusExpr(ExprBase): # Note: Does not have any numerical value members
def __init__(self, expr_l, expr_r):
self.expr_l = expr_l
self.expr_r = expr_r
def make_val_and_dval(self, var): # Evaluate and get val and first-order derivative val
l_dl = self.expr_l.make_val_and_dval(var)
r_dr = self.expr_r.make_val_and_dval(var)
res = ValAndDVal(l_dl.get_val() + r_dr.get_val())
res.set_dval(l_dl.get_dval() + r_dr.get_dval()) # Derivative distributes over addition
return res
class MultExpr(ExprBase): # Note: Does not have any numerical value members
def __init__(self, expr_l, expr_r):
self.expr_l = expr_l
self.expr_r = expr_r
def make_val_and_dval(self, var): # Evaluate and get val and first-order derivative val
# Compute instances used several times in expression (Same 3 lines as in PlusExpr.make_val_and_dval)
# See also https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py#L28
l_dl = self.expr_l.make_val_and_dval(var)
r_dr = self.expr_r.make_val_and_dval(var)
l_at_v = l_dl.get_val()
r_at_v = r_dr.get_val()
res = ValAndDVal(l_at_v * r_at_v)
res.set_dval(r_at_v * l_dl.get_dval() + l_at_v * r_dr.get_dval()) # Product rule for multiplication, (a*b)' = (a')*b+a*(b')
return res
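# A minimal forward-accumulation sketch (added illustration, not one of the original demos):
# x, y = VarExpr(2), VarExpr(3)
# z = x * y + y # i.e. PlusExpr(MultExpr(x, y), y)
# z.make_val_and_dval(x).get_dval() == 3 # ∂z/∂x = y = 3
# z.make_val_and_dval(y).get_dval() == 3 # ∂z/∂y = x + 1 = 3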
class ExprBaseRevAcc:
# Similar to ExprBase above, but primitive subexpressions will accumulate dval (partial derivative) values
def __add__(self, other):
return PlusExprRevAcc(self, other)
def __mul__(self, other):
return MultExprRevAcc(self, other)
def eval(self):
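# Two passes: forward() fills in the vals bottom-up, then the derivative of the root with respect
# to itself (= 1) is pushed back down, accumulating partial derivatives in the leaf variables.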
self.forward()
DT_OVER_DT = 1 # ∂t/∂t = kronecker_delta(foo, foo) = 1
self.accumulate_derivative(DT_OVER_DT)
class VarExprRevAcc(ExprBaseRevAcc):
# Note: This var will also accumulate a dval. Compare to the earlier VarExpr, which didn't keep a dval member field!
def __init__(self, val):
self.val_and_dval = ValAndDVal(val)
self.val_and_dval.set_dval(0)
def forward(self):
pass # Need no forward, since value was already set in __init__(val)
def get_val(self):
return self.val_and_dval.get_val()
def get_dval(self):
return self.val_and_dval.get_dval()
def accumulate_derivative(self, dval): # Accumulation of derivative values coming in
s = self.val_and_dval.get_dval()
self.val_and_dval.set_dval(s + dval)
def step(self, step_size):
correction = -step_size * self.get_dval() # -grad
corrected_val = self.get_val() + correction
self.val_and_dval = ValAndDVal(corrected_val)
self.val_and_dval.set_dval(0)
class PlusExprRevAcc(ExprBaseRevAcc):
def __init__(self, expr_l, expr_r):
self.expr_l = expr_l
self.expr_r = expr_r
self.val = None # This expr class has a record for its own val, but no dval
def get_val(self):
return self.val
def forward(self): # Also a setter function w.r.t. whatever vals are set in the expression
self.expr_l.forward()
self.expr_r.forward()
self.val = self.expr_l.get_val() + self.expr_r.get_val()
def accumulate_derivative(self, dval):
# Note: Addition '+' not implemented here, as both .expr's push back into the same accumulating variables.
#
# Linearity of derivative:
# e := e1 + e2
# de/dt * dT = de1/dt * (1.0 * dT) + de2/dt * (1.0 * dT)
self.expr_l.accumulate_derivative(1.0 * dval)
self.expr_r.accumulate_derivative(1.0 * dval)
class MultExprRevAcc(ExprBaseRevAcc):
# See comments made about PlusExprRevAcc above. MultExprRevAcc is similar, if slightly more complicated
def __init__(self, expr_l, expr_r):
self.expr_l = expr_l
self.expr_r = expr_r
self.val = None
def get_val(self):
return self.val
def forward(self):
self.expr_l.forward()
self.expr_r.forward()
self.val = self.expr_l.get_val() * self.expr_r.get_val()
def accumulate_derivative(self, dval):
# Note: Makes use of the vals, so forward() must be run before pushing back!
#
# Product rule (multiplication and expression switcheroo):
# e := e1 * e2
# de/dt * dT = de1/dt * (e2 * dT) + de2/dt * (e1 * dT)
self.expr_l.accumulate_derivative(self.expr_r.get_val() * dval)
self.expr_r.accumulate_derivative(self.expr_l.get_val() * dval)
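# A minimal single-step sketch (added illustration, using the classes above):
# x = VarExprRevAcc(2.0)
# z = x * x # z = x^2, so dz/dx = 2x = 4
# z.eval() # forward pass, then x accumulates dval = 4.0
# x.step(0.1) # x.get_val() becomes 2.0 - 0.1 * 4.0 = 1.6 and its dval resets to 0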
def run_comparison_demo():
# Example using VarExpr: Finding the dvals of x * ((x + y) + d) + y * y, at (x, y) = (2, 3)
x = VarExpr(2)
y = VarExpr(3)
d = VarExpr(2) # Const.
z = p(d, x, y)
dz_dx = z.make_val_and_dval(x)
dz_dy = z.make_val_and_dval(y)
dz_dd = z.make_val_and_dval(d)
z_VarExpr_val = z.get_val()
print("[VarExpr example] z =", z_VarExpr_val)
print("[VarExpr example] ∂z/∂x =", dz_dx.get_dval())
print("[VarExpr example] ∂z/∂y =", dz_dy.get_dval())
print("[VarExpr example] ∂z/∂d =", dz_dd.get_dval())
print()
# Example using VarExprRevAcc
x = VarExprRevAcc(2)
y = VarExprRevAcc(3)
d = VarExprRevAcc(2)
z = p(d, x, y)
z.eval()
z_VarExprRevAcc_val = z.get_val()
print("[VarExprRevAcc example] z =", z_VarExprRevAcc_val)
print("[VarExprRevAcc example] ∂z/∂x =", x.get_dval())
print("[VarExprRevAcc example] ∂z/∂y =", y.get_dval())
print("[VarExprRevAcc example] ∂z/∂d =", d.get_dval())
print()
assert z_VarExprRevAcc_val == z_VarExpr_val
def run_VarExprRevAcc_gradient_descent_demo():
class Config:
ITERATIONS = 5000
EPS = 1e-3
MU = EPS
x = VarExprRevAcc(-3)
y = VarExprRevAcc(-5)
d = VarExprRevAcc(2) # Const.
for idx in range(Config.ITERATIONS):
z = p(d, x, y)
z.eval()
x.step(Config.MU)
y.step(Config.MU)
if (idx < 1000 and idx % 100 == 0) or idx % 1000 == 0:
print(f"[gradient_descent_demo] idx #{idx}, z = {z.get_val()}")
# Validation
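# Derivation note (added): ∇p = (2x + y + d, x + 2y) = (0, 0) gives (x, y) = (-4/3, 2/3) for d = 2,
# where p attains its minimum -4/3.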
MIN_Z_GT = -4/3 # Minimum of p
assert abs(z.get_val() - MIN_Z_GT) < Config.EPS
if __name__ == '__main__':
torch_example()
print()
run_stack_demo()
print()
run_comparison_demo()
print()
run_VarExprRevAcc_gradient_descent_demo()
print()