# A minimal reverse-mode autodiff (vjp / grad) with trace levels for nested
# differentiation.
# Source: GitHub gist mattjj/52914908ac22d9ad57b76b685d19acb8 by @mattjj,
# created February 3, 2022.
from __future__ import annotations

import functools
from contextlib import contextmanager
from typing import NamedTuple, Callable, Optional, Any, Sequence

import numpy as np
# Placeholder type for primal values. It is left as Any because a value may be
# a plain/numpy number (e.g. an array) or a Tracer from an enclosing trace level.
Array = Any
class Node(NamedTuple):
    """A vertex of the computation graph recorded while tracing.

    Leaf nodes (traced inputs and lifted constants) have ``vjp=None`` and no
    parents. Interior nodes are created by ``process``, one per primitive
    application.
    """
    # Maps this node's output cotangent to a list of cotangents, one per
    # parent. None for leaf nodes.
    vjp: Optional[Callable]
    # Input nodes of the primitive that produced this node. ``process`` stores
    # the tuple produced by zip(), and leaf nodes use an empty list, so the
    # field is annotated as a generic Sequence.
    parents: Sequence["Node"]
def parentless_node():
    """Create a fresh leaf Node with no VJP rule and an empty parent list."""
    return Node(vjp=None, parents=[])
class Tracer(NamedTuple):
    """A value being traced at a particular trace level.

    Arithmetic operators (``+``, ``*``, unary ``-``) are attached to this class
    after the primitives are defined. Expressions on Tracers therefore go
    through ``primitive`` and record a graph Node.

    Field order matters: ``vjp`` unpacks a Tracer positionally as
    ``(level, val, node)``.
    """
    # Trace level this value belongs to. find_top_level picks the maximum level
    # among a primitive's arguments.
    level: int
    # Primal value. It may itself be a Tracer of a lower level when a value
    # from an enclosing trace is lifted into a nested one.
    val: Array
    # Graph node recording how this value was produced.
    node: Node
def primitive(f):
    """Wrap the raw numpy function ``f`` so calls on Tracers are recorded.

    With no Tracer arguments the wrapper simply calls ``f``. Otherwise every
    argument is lifted to the highest trace level present, and ``process``
    evaluates ``f`` through its entry in ``vjp_rules``, producing a new Tracer
    whose Node remembers the inputs and the VJP closure.
    """
    # Preserve f's name/docstring so wrapped primitives show up as "sin",
    # "add", ... in reprs and tracebacks instead of "wrapped".
    @functools.wraps(f)
    def wrapped(*args):
        level = find_top_level(args)
        if not level:
            return f(*args)
        tracers = [lift(level, x) for x in args]
        return process(level, f, tracers)
    return wrapped
# Traceable versions of the numpy functions that have VJP rules.
sin = primitive(np.sin)
cos = primitive(np.cos)

# Route Python arithmetic on Tracers through the primitives. Tracer is a
# NamedTuple, so without these overrides ``+`` and ``*`` would be tuple
# concatenation and repetition. The reflected forms reuse the same primitive
# because addition and multiplication are commutative.
add = Tracer.__add__ = Tracer.__radd__ = primitive(np.add)
mul = Tracer.__mul__ = Tracer.__rmul__ = primitive(np.multiply)
neg = Tracer.__neg__ = primitive(np.negative)


def sub(x, y):
    """Return ``x - y``, built from ``add`` and ``neg`` so no extra VJP rule is needed."""
    return add(x, neg(y))


# Tuples have no subtraction, so ``a - b`` on Tracers would otherwise raise.
Tracer.__sub__ = sub
Tracer.__rsub__ = lambda self, other: sub(other, self)

# Tracer is a tuple. In mixed expressions such as ``np.float64(2.) * tracer``
# or ``ndarray * tracer`` (these arise in nested VJPs, e.g. ``cos(x) * g`` when
# only g is an outer-level Tracer), numpy would convert the Tracer to an array
# instead of deferring to Tracer.__rmul__. Setting __array_ufunc__ to None
# makes numpy's binary operators return NotImplemented for Tracers (NEP 13).
Tracer.__array_ufunc__ = None
def find_top_level(args):
    """Return the largest ``level`` among the Tracers in ``args``, or 0 if there are none."""
    levels = [arg.level for arg in args if isinstance(arg, Tracer)]
    return max(levels) if levels else 0
def lift(level, x):
    """Return ``x`` as a Tracer at ``level``.

    A Tracer already at ``level`` is returned unchanged. Anything else (a raw
    value or a Tracer from a lower level) is boxed with a fresh parentless
    node, i.e. treated as a constant leaf at this level.
    """
    if not (isinstance(x, Tracer) and x.level == level):
        return Tracer(level=level, val=x, node=parentless_node())
    return x
def process(level, prim, tracers):
    """Apply primitive ``prim`` to same-level Tracers and record a graph node.

    The VJP rule for ``prim`` receives the tracers' values and returns the
    output value plus a backward closure. That closure and the tracers' nodes
    become the new Node attached to the returned Tracer.
    """
    primals, parents = zip(*((tracer.val, tracer.node) for tracer in tracers))
    result, pullback = vjp_rules[prim](*primals)
    node = Node(vjp=pullback, parents=parents)
    return Tracer(level=level, val=result, node=node)
def vjp(f, *args):
    """Evaluate ``f(*args)`` and return ``(output_value, f_vjp)``.

    Each positional argument is boxed as a leaf Tracer at a fresh trace level.
    ``f`` is called on those Tracers, which records a computation graph.
    ``f_vjp(g)`` runs the backward pass from output cotangent ``g`` and returns
    a list with one cotangent per positional argument.

    If the output does not depend on an argument (e.g. ``f`` returns a
    constant), that argument's cotangent is a zero shaped like its value
    rather than ``None``. The result can then be used in further arithmetic,
    such as nested ``grad`` calls.
    """
    with new_trace_level() as level:
        in_tracers = [Tracer(level=level, val=x, node=parentless_node())
                      for x in args]
        out = f(*in_tracers)
        # A constant output (or one from an enclosing level) is lifted with a
        # parentless node, so the backward pass finds no path to the inputs.
        _, out_val, out_node = lift(level, out)
    in_nodes = [t.node for t in in_tracers]

    def f_vjp(g):
        cotangents = backward_pass(in_nodes, out_node, g)
        return [_zeros_like_primal(x) if ct is None else ct
                for x, ct in zip(args, cotangents)]

    return out_val, f_vjp


def _zeros_like_primal(x):
    """Return a constant zero with the shape and dtype of ``x``'s underlying value."""
    # Arguments may be Tracers from an enclosing level (nested grad). A missing
    # cotangent is a true constant zero, so strip every Tracer layer first.
    while isinstance(x, Tracer):
        x = x.val
    zeros = np.zeros_like(x)
    # Return 0-d results as plain Python scalars. Those defer to Tracer's
    # reflected operators in mixed arithmetic.
    return zeros.item() if zeros.ndim == 0 else zeros
# Depth of the innermost active trace; 0 means no trace is active.
trace_level = 0


@contextmanager
def new_trace_level():
    """Enter a new, one-deeper trace level for the duration of a ``with`` block.

    Yields the new level number. The global counter is decremented on exit,
    including when the body raises, so nested traces unwind like a stack.
    """
    global trace_level
    trace_level += 1
    entered = trace_level
    try:
        yield entered
    finally:
        trace_level -= 1
def backward_pass(in_nodes, out_node, g):
    """Propagate the cotangent ``g`` from ``out_node`` back to ``in_nodes``.

    Interior nodes are visited in the order given by ``toposort``. For each
    node, its accumulated cotangent is removed from the table and pushed
    through the node's ``vjp``, and the results are summed into its parents'
    entries.

    Returns one cotangent per node in ``in_nodes``. An input that received no
    contribution yields ``None``.
    """
    # Keyed by id(): Nodes hold lists/closures (so may be unhashable), and
    # distinct leaf nodes compare equal by value.
    cotangents = {id(out_node): g}
    for node in toposort(out_node):
        node_bar = cotangents.pop(id(node))
        parent_bars = node.vjp(node_bar)
        for parent, parent_bar in zip(node.parents, parent_bars):
            key = id(parent)
            cotangents[key] = add_grads(cotangents.get(key), parent_bar)
    return [cotangents.get(id(leaf)) for leaf in in_nodes]
def add_grads(g1, g2):
    """Sum two cotangent contributions; ``g1 is None`` means "nothing accumulated yet"."""
    if g1 is None:
        return g2
    return g1 + g2
def toposort(end_node):
return reversed([n for n in _toposort(set(), end_node) if n.parents])
def _toposort(seen, node):
if id(node) not in seen:
seen.add(id(node))
for p in node.parents:
yield from _toposort(seen, p)
yield node
# VJP rules, keyed by the raw numpy function. Each rule takes the primal inputs
# and returns (primal output, vjp), where vjp maps the output cotangent to a
# list with one cotangent per input. They are written with the traceable
# sin/cos and Python operators, so they also work when inputs are Tracers from
# an enclosing trace level.
vjp_rules = {
    np.sin: lambda x: (sin(x), lambda g: [cos(x) * g]),
    np.cos: lambda x: (cos(x), lambda g: [-sin(x) * g]),
    np.add: lambda x, y: (x + y, lambda g: [g, g]),
    np.multiply: lambda x, y: (x * y, lambda g: [g * y, x * g]),
    np.negative: lambda x: (-x, lambda g: [-g]),
}
def grad(f):
    """Return a function computing the derivative of ``f`` w.r.t. its first argument.

    The backward pass is seeded with the cotangent ``1.``, so this is intended
    for scalar-valued ``f``. All positional arguments are traced, but only the
    first argument's cotangent is returned.
    """
    def f_grad(*args):
        pullback = vjp(f, *args)[1]
        cotangents = pullback(1.)
        return cotangents[0]
    return f_grad
# --- Demo ---------------------------------------------------------------


def f(x):
    """Example function: f(x) = sin(sin(x)) + x."""
    return sin(sin(x)) + x


# Only run the demo when executed as a script, so importing this module has
# no printing side effects.
if __name__ == "__main__":
    print(f(3.))
    print(grad(f)(3.))        # f'(x) = cos(sin(x)) * cos(x) + 1
    print(grad(grad(f))(3.))  # second derivative via nested trace levels
    # d/dx [d/dy (x * y)] = d/dx x = 1. Separate trace levels keep the inner
    # and outer variables apart, so this should be 1.0.
    print(grad(lambda x: grad(lambda y: x * y)(1.))(1.))
# Comments on the original gist require signing in to GitHub.