@sradc · Last active November 13, 2024
Automatic Differentiation in 26 lines of Python
import math

class Var:
    def __init__(self, val: float, local_gradients=()):
        self.val = val
        self.local_gradients = local_gradients
        self.grad = 0

    def backward(self, path_value: float = 1):
        for child_var, local_gradient in self.local_gradients:
            child_var.grad += path_value * local_gradient
            child_var.backward(path_value * local_gradient)

Var.__add__ = lambda a, b: Var(a.val + b.val, [(a, 1), (b, 1)])
Var.__truediv__ = lambda a, b: Var(a.val / b.val, [(a, 1 / b.val), (b, -a.val / b.val**2)])
Var.__mul__ = lambda a, b: Var(a.val * b.val, [(a, b.val), (b, a.val)])
Var.__neg__ = lambda a: Var(-a.val, [(a, -1)])
Var.__sub__ = lambda a, b: Var(a.val - b.val, [(a, 1), (b, -1)])
Var.__pow__ = lambda a, k: Var(a.val ** k.val, [(a, k.val * a.val ** (k.val - 1)), (k, (a.val ** k.val) * math.log(a.val))])
exp = lambda a: Var(math.exp(a.val), [(a, math.exp(a.val))])
log = lambda a: Var(math.log(a.val), [(a, 1 / a.val)])
sin = lambda a: Var(math.sin(a.val), [(a, math.cos(a.val))])
cos = lambda a: Var(math.cos(a.val), [(a, -math.sin(a.val))])
sig_ = lambda a: math.exp(a) / (math.exp(a) + 1)
sigmoid = lambda a: Var(sig_(a.val), [(a, sig_(a.val) * (1 - sig_(a.val)))])
relu = lambda a: Var(max(a.val, 0), [(a, (a.val > 0) * 1)])


Inspired by the gist Automatic Differentiation in 38 lines of Haskell (Hacker News link). However, unlike that gist, we are doing reverse-mode autodiff here: the method used by PyTorch, TensorFlow, etc.

The implementation here is a more concise version of the one in this blog post. The same concept, but operating on NumPy/CuPy arrays, can be found in this repo (with a few neural network examples).
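
As a quick sketch of how the reverse pass works (using the Var class above), differentiating y = x * x should give x.grad == 2 * x.val after calling y.backward():

# Quick sketch: reverse-mode gives d(x * x)/dx = 2 * x
x = Var(3)
y = x * x          # y.val = 9, local_gradients = [(x, 3), (x, 3)]
y.backward()       # each of the two paths back to x contributes 3
print(f"{x.grad = }")  # x.grad = 6

Each Var stores its child variables together with the local derivative with respect to each child; backward walks back through these edges, multiplying local derivatives along each path and accumulating the products into grad.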

Example usage:

# Test an arbitrary function composed of the above functions:
a, b, c = Var(2), Var(5), Var(-3)
y = relu(log(sin(a * b) - exp(c) + Var(5)) ** a / b * cos(c) + sigmoid(a))
y.backward()
print(f"{y.val = }")
print(f"{a.grad = }")
print(f"{b.grad = }")
print(f"{c.grad = }")
This prints:

y.val = 0.44533482351666853
a.grad = 0.4925564216709931
b.grad = 0.3107593846787471
c.grad = 0.06870937895756869

Check the result using JAX's autodiff:

import jax
import jax.numpy as jnp

def f(a, b, c):
    return jax.nn.relu(jnp.log(jnp.sin(a * b) - jnp.exp(c) + 5) ** a / b * jnp.cos(c) + jax.nn.sigmoid(a))

grad_f = jax.value_and_grad(f, argnums=(0, 1, 2))
y_val, (a_grad, b_grad, c_grad) = grad_f(float(a.val), float(b.val), float(c.val))
print(f"{y_val = }")
print(f"{a_grad = }")
print(f"{b_grad = }")
print(f"{c_grad = }")
This prints:

y_val = DeviceArray(0.44533476, dtype=float32, weak_type=True)
a_grad = DeviceArray(0.49255645, dtype=float32, weak_type=True)
b_grad = DeviceArray(0.31075937, dtype=float32, weak_type=True)
c_grad = DeviceArray(0.06870938, dtype=float32, weak_type=True)
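
As a further check that doesn't need JAX, a central finite difference over the same expression can approximate a.grad (a rough sketch; the helper f_val and the step size eps are introduced here just for illustration):

# Approximate df/da numerically with a central finite difference.
def f_val(a_val, b_val, c_val):
    a, b, c = Var(a_val), Var(b_val), Var(c_val)
    return relu(log(sin(a * b) - exp(c) + Var(5)) ** a / b * cos(c) + sigmoid(a)).val

eps = 1e-6
a_grad_fd = (f_val(2 + eps, 5, -3) - f_val(2 - eps, 5, -3)) / (2 * eps)
print(f"{a_grad_fd = }")  # should land close to a.grad = 0.4925...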