Skip to content

Instantly share code, notes, and snippets.

@mattjj
Created December 7, 2020 22:07
Show Gist options
  • Save mattjj/a60e0991455965ae960a8d2dcddc3407 to your computer and use it in GitHub Desktop.
Save mattjj/a60e0991455965ae960a8d2dcddc3407 to your computer and use it in GitHub Desktop.
import itertools
import operator as op
from weakref import ref
import numpy as np
from graphviz import Digraph
import jax
from jax import core
from jax import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax.util import safe_map as map, safe_zip as zip
from jax.tree_util import tree_flatten, tree_unflatten
from jax.api_util import flatten_fun
# Graphviz node attribute sets, keyed by the role a node plays in the graph.
styles = dict(
    const=dict(style='filled', color='goldenrod1'),
    invar=dict(color='mediumspringgreen', style='filled'),
    outvar=dict(style='filled,dashed', fillcolor='indianred1', color='black'),
    op_node=dict(shape='box', color='lightskyblue', style='filled'),
    intermediate=dict(style='filled', color='cornflowerblue'),
)
def grad_graph(fun, *args):
    """Build a Graphviz diagram of the backward-pass jaxpr of ``fun`` at ``args``.

    Calls ``jax.vjp(fun, *args)``, extracts the jaxpr held inside the returned
    VJP closure, and draws one node per constvar, invar, literal, equation and
    jaxpr output, with edges following data flow.

    Args:
      fun: function to differentiate; passed to ``jax.vjp``.
      *args: primal arguments; passed to ``jax.vjp``.

    Returns:
      A ``graphviz.Digraph`` (``dot`` engine) describing the jaxpr.
    """
    _, fun_vjp = jax.vjp(fun, *args)
    # NOTE(review): this attribute path reaches into the private structure of
    # the closure returned by jax.vjp. It is tied to the JAX version this was
    # written against and may break on other versions.
    jaxpr = fun_vjp.args[0].func.args[1]

    # Unique names for nodes that have no variable of their own to be named by.
    fresh_names = (f'id{n}' for n in itertools.count())
    # Unused equation outputs are DropVars, which all print as '_'. Keying
    # nodes by str(v) would merge every one of them into a single node, so
    # they are detected here. The () fallback makes isinstance() return False.
    drop_var_type = getattr(core, 'DropVar', ())

    def atom_name(v):
        # Literals have no variable name; key their node on the identity of
        # the literal's value instead.
        return str(id(v.val)) if isinstance(v, core.Literal) else str(v)

    graph = Digraph(engine='dot')
    graph.attr(size='6,10!')

    for v in jaxpr.constvars:
        graph.node(str(v), core.raise_to_shaped(v.aval).str_short(), styles['const'])
    for v in jaxpr.invars:
        graph.node(str(v), v.aval.str_short(), styles['invar'])

    for eqn in jaxpr.eqns:
        for v in eqn.invars:
            if isinstance(v, core.Literal):
                graph.node(atom_name(v),
                           core.raise_to_shaped(core.get_aval(v.val)).str_short(),
                           styles['const'])
        if eqn.primitive.multiple_results:
            # A separate op node fans out to one node per result.
            op_name = next(fresh_names)
            graph.node(op_name, str(eqn.primitive), styles['op_node'])
            for v in eqn.invars:
                graph.edge(atom_name(v), op_name)
            for v in eqn.outvars:
                if isinstance(v, drop_var_type):
                    continue  # dropped output: nothing consumes it
                graph.node(str(v), v.aval.str_short(), styles['intermediate'])
                graph.edge(op_name, str(v))
        else:
            outv, = eqn.outvars
            # A single-result op is drawn as the node of its output variable.
            # A dropped output gets a fresh name so it is not merged with others.
            op_name = next(fresh_names) if isinstance(outv, drop_var_type) else str(outv)
            graph.node(op_name, str(eqn.primitive), styles['op_node'])
            for v in eqn.invars:
                graph.edge(atom_name(v), op_name)

    # Re-declaring an existing node restyles and relabels it as an output.
    for v in jaxpr.outvars:
        graph.node(str(v), "out", styles['outvar'])
    return graph
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment