Niki Howe nikihowe

nikihowe / jaxpr_graph.py
Created August 4, 2021 13:52 — forked from manuel-delverme/jaxpr_graph.py
visualizing jaxprs
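A jaxpr is the intermediate representation JAX builds when it traces a function. It has input variables (plus any captured constants), a list of equations that each apply one primitive, and output variables. Those roles are what the `const`, `invar` and `outvar` styles in this file color differently. What follows is a minimal sketch of the idea, not the gist's own implementation. The helper `jaxpr_to_dot` and the `op` style are hypothetical. The sketch writes Graphviz DOT text by hand instead of using `graphviz.Digraph`, and it draws only the top-level jaxpr, so sub-jaxprs inside primitives such as `cond` or `scan` stay a single node.

```python
import itertools

import jax

# 'const', 'invar' and 'outvar' are copied from the gist's `styles` dict;
# 'op' is a hypothetical addition for primitive-application nodes.
STYLES = {
    'const': dict(style='filled', color='goldenrod1'),
    'invar': dict(color='mediumspringgreen', style='filled'),
    'outvar': dict(style='filled,dashed', fillcolor='indianred1', color='black'),
    'op': dict(shape='box', style='rounded'),
}


def _attrs(**kw):
    # Render keyword arguments as a DOT attribute list: [k="v", ...]
    return '[' + ', '.join(f'{k}="{v}"' for k, v in kw.items()) + ']'


def _is_literal(v):
    # Duck-typed check: jaxpr Literals carry a `.val`, Vars do not. This
    # avoids depending on where `Literal` is importable from, which has
    # moved between JAX versions.
    return hasattr(v, 'val')


def jaxpr_to_dot(fun, *args):
    """Hypothetical helper: DOT source for the top-level jaxpr of `fun`."""
    jaxpr = jax.make_jaxpr(fun)(*args).jaxpr
    counter = itertools.count()
    ids = {}  # jaxpr Var -> DOT node id
    out = ['digraph jaxpr {', '  rankdir=LR;']

    def var_node(v, style):
        ids[v] = f'v{next(counter)}'
        out.append(f'  {ids[v]} {_attrs(label=str(v.aval), **STYLES[style])};')

    for v in jaxpr.constvars:
        var_node(v, 'const')
    for v in jaxpr.invars:
        var_node(v, 'invar')

    # One box per equation, with edges from its inputs and to its outputs.
    for eqn in jaxpr.eqns:
        op = f'e{next(counter)}'
        out.append(f'  {op} {_attrs(label=eqn.primitive.name, **STYLES["op"])};')
        for v in eqn.invars:
            if _is_literal(v):
                lit = f'l{next(counter)}'
                out.append(f'  {lit} {_attrs(label=str(v.val), shape="plaintext")};')
                out.append(f'  {lit} -> {op};')
            else:
                out.append(f'  {ids[v]} -> {op};')
        for v in eqn.outvars:
            ids[v] = f'v{next(counter)}'
            out.append(f'  {ids[v]} {_attrs(label=str(v.aval))};')
            out.append(f'  {op} -> {ids[v]};')

    # Separate dashed node for each result of the traced function.
    for v in jaxpr.outvars:
        res = f'o{next(counter)}'
        label = str(v.val) if _is_literal(v) else str(v.aval)
        out.append(f'  {res} {_attrs(label=label, **STYLES["outvar"])};')
        if not _is_literal(v):
            out.append(f'  {ids[v]} -> {res};')

    out.append('}')
    return '\n'.join(out)
```

For example, `print(jaxpr_to_dot(lambda x, y: jax.lax.sin(x) * y + 2.0, 1.0, 2.0))` produces a graph with `sin`, `mul` and `add` nodes, which the Graphviz `dot` command can render (e.g. `dot -Tsvg`). Emitting DOT text directly keeps the sketch dependency-free. The gist itself builds the graph through the `graphviz` package's `Digraph`.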
import jax
from jax import core
from graphviz import Digraph
import itertools
styles = {
    'const': dict(style='filled', color='goldenrod1'),
    'invar': dict(color='mediumspringgreen', style='filled'),
    'outvar': dict(style='filled,dashed', fillcolor='indianred1', color='black'),