Inspired by the gist Automatic Differentiation in 38 lines of Haskell (Hacker News link). Unlike that gist, however, we are doing reverse-mode autodiff here, the method used by PyTorch, TensorFlow, etc.
The implementation here is a more concise version of the one in this blog post. The same concept, but operating on NumPy/CuPy arrays, can be found in this repo (with a few neural network examples).
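To illustrate the reverse-mode pattern in isolation: each operation records its inputs together with the local derivatives, and backward() pushes the output gradient back through those records, accumulating into .grad at the leaves. The snippet below is only a simplified sketch of that idea (a hypothetical TinyVar, not the Var class defined above):

# Simplified sketch of reverse-mode autodiff (hypothetical TinyVar, for illustration only)
class TinyVar:
    def __init__(self, val, parents=()):
        self.val = val
        self.grad = 0.0
        self.parents = parents  # pairs of (parent node, local derivative d(self)/d(parent))

    def __add__(self, other):
        return TinyVar(self.val + other.val, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        return TinyVar(self.val * other.val, [(self, other.val), (other, self.val)])

    def backward(self, seed=1.0):
        # Accumulate the upstream gradient, then send each parent its share (chain rule).
        self.grad += seed
        for parent, local in self.parents:
            parent.backward(seed * local)

x, y = TinyVar(2.0), TinyVar(3.0)
z = x * y + x
z.backward()
print(x.grad, y.grad)  # 4.0 2.0  (dz/dx = y + 1, dz/dy = x)

A more careful implementation traverses each node once (e.g. in reverse topological order); the naive recursion above revisits shared subexpressions but still accumulates the correct gradients.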
# Test an arbitrary function composed of the above functions:
a, b, c = Var(2), Var(5), Var(-3)
y = relu(log(sin(a * b) - exp(c) + Var(5)) ** a / b * cos(c) + sigmoid(a))
y.backward()
print(f"{y.val = }")
print(f"{a.grad = }")
print(f"{b.grad = }")
print(f"{c.grad = }")
y.val = 0.44533482351666853
a.grad = 0.4925564216709931
b.grad = 0.3107593846787471
c.grad = 0.06870937895756869
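As an extra sanity check that does not rely on any autodiff library, the same gradients can be approximated with central finite differences. A quick sketch using only the standard math module (the sigmoid and numerical_grad helpers below are defined just for this check):

import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def f(a, b, c):
    return max(0.0, math.log(math.sin(a * b) - math.exp(c) + 5) ** a / b * math.cos(c) + sigmoid(a))

def numerical_grad(f, args, i, eps=1e-6):
    # Central difference along argument i: (f(x + eps) - f(x - eps)) / (2 * eps)
    lo, hi = list(args), list(args)
    lo[i] -= eps
    hi[i] += eps
    return (f(*hi) - f(*lo)) / (2 * eps)

args = (2.0, 5.0, -3.0)
print([numerical_grad(f, args, i) for i in range(3)])

These estimates should agree with a.grad, b.grad, and c.grad above to several decimal places.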
Check the result using JAX's autodiff.
import jax
import jax.numpy as jnp
def f(a, b, c):
    return jax.nn.relu(jnp.log(jnp.sin(a * b) - jnp.exp(c) + 5) ** a / b * jnp.cos(c) + jax.nn.sigmoid(a))
grad_f = jax.value_and_grad(f, argnums=(0, 1, 2))
y_val, (a_grad, b_grad, c_grad) = grad_f(float(a.val), float(b.val), float(c.val))
print(f"{y_val = }")
print(f"{a_grad = }")
print(f"{b_grad = }")
print(f"{c_grad = }")
y_val = DeviceArray(0.44533476, dtype=float32, weak_type=True)
a_grad = DeviceArray(0.49255645, dtype=float32, weak_type=True)
b_grad = DeviceArray(0.31075937, dtype=float32, weak_type=True)
c_grad = DeviceArray(0.06870938, dtype=float32, weak_type=True)
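The JAX values match ours up to single precision; JAX defaults to float32, which accounts for the small differences in the trailing digits. If a closer comparison is wanted, double precision can be enabled before running any JAX computation:

import jax
jax.config.update("jax_enable_x64", True)  # must be set before the first JAX computation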