# PyTorch graph visualization
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.autograd import Variable, Function
from collections import defaultdict
import graphviz
"""
This is a rather distorted implementation of graph visualization in PyTorch.

This implementation is distorted because PyTorch's autograd is undergoing refactoring right now:
- neither func.next_functions nor func.previous_functions can be relied upon
- BatchNorm's C backend does not follow the Python Function interface
- I'm not even sure whether to use var.creator or var.grad_fn (apparently the source tree and wheel builds use
  different interfaces now)

As a result, we are forced to manually trace the graph, using 2 redundant mechanisms:
- Function.__call__: this allows us to trace all Function creations. Function corresponds to Op in TF
- Module.forward_hook: this is needed because the above method doesn't work for BatchNorm, as the current C backend
  does not follow the Python Function interface.

To do graph visualization, follow these steps:
1. register hooks on model: register_vis_hooks(model)
2. pass data through model: output = model(input)
3. remove hooks: remove_vis_hooks()
4. perform visualization: save_visualization(name, format='svg')  # name is a string without extension
"""
# Keep a reference to the unpatched Function.__call__ so the tracing wrapper
# can delegate to it and remove_vis_hooks() can restore it.
old_function__call__ = Function.__call__
def register_creator(inputs, creator, output):
    """Record one traced operation in the module-level graph tables.

    In the forward pass, the patched Function.__call__ and the BatchNorm
    forward hook both call this function to register creators.

    Args:
        inputs: list of input variables consumed by ``creator``.
        creator: the object that produced ``output``; one of
            - Function
            - BatchNorm module
        output: a single output variable (not a tuple, list or dict).

    Raises:
        AssertionError: if ``output`` is a tuple, list or dict.

    The tables ``vars``, ``funcs``, ``var_trace`` and ``func_trace`` are
    module globals that register_vis_hooks() creates; they must exist before
    this function is called.
    """
    cid = id(creator)
    oid = id(output)
    if oid in vars:
        # This output is already part of the graph (presumably reported by
        # the other, redundant tracing mechanism); keep the first record.
        return
    # Validate before touching any table so that a rejected output does not
    # leave half-registered entries behind.
    assert type(output) not in [tuple, list, dict]
    # connect creator to its inputs and register every input
    for inp in inputs:
        iid = id(inp)
        func_trace[cid][iid] = inp
        vars[iid] = inp
    # connect output to creator
    var_trace[oid][cid] = creator
    # register creator and output
    vars[oid] = output
    funcs[cid] = creator
# Module-level list of forward-hook handles; remove_vis_hooks() iterates over it.
hooks = []
def register_vis_hooks(model):
    """Start tracing forward passes of ``model`` for graph visualization.

    Resets the module-level trace tables, attaches a forward hook to every
    submodule whose class name contains 'BatchNorm', and replaces
    ``Function.__call__`` with a wrapper that records each invocation through
    register_creator().

    Args:
        model: object whose ``modules()`` yields the submodules to inspect
            (presumably a torch.nn.Module -- confirm against callers).
    """
    global var_trace, func_trace, vars, funcs
    var_trace = defaultdict(dict)   # map oid to {cid:creator}
    func_trace = defaultdict(dict)  # map cid to {iid:input}
    vars = {}                       # map vid to Variable/Parameter
    funcs = {}                      # map cid to Function/BatchNorm module

    # Detach handles left over from an earlier registration, then reuse the
    # module-level list in place. Rebinding a local ``hooks`` here would leave
    # the list that remove_vis_hooks() reads permanently empty.
    for stale in hooks:
        stale.remove()
    del hooks[:]

    def hook_func(module, inputs, output):
        # Check the module the hook fired on. The enclosing loop variable
        # would refer to whichever submodule happened to be iterated last.
        assert 'BatchNorm' in module.__class__.__name__  # batchnorms don't have shared superclass
        inputs = list(inputs)
        # weight and bias, when present, are inputs of the BatchNorm as well
        inputs.extend(p for p in (module.weight, module.bias) if p is not None)
        register_creator(inputs, module, output)

    for mod in model.modules():
        if 'BatchNorm' in mod.__class__.__name__:  # batchnorms don't have shared superclass
            hooks.append(mod.register_forward_hook(hook_func))

    def new_function__call__(self, *args, **kwargs):
        # Only Variable arguments become graph inputs; everything else is
        # passed through to the original call untouched.
        inputs = [a for a in args if isinstance(a, Variable)]
        inputs += [a for a in kwargs.values() if isinstance(a, Variable)]
        output = old_function__call__(self, *args, **kwargs)
        register_creator(inputs, self, output)
        return output

    Function.__call__ = new_function__call__
def remove_vis_hooks():
    """Stop tracing: detach the forward hooks and restore ``Function.__call__``."""
    for hook in hooks:
        hook.remove()
    # Forget the handles so a later call does not try to detach them again.
    del hooks[:]
    Function.__call__ = old_function__call__
def save_visualization(name, format='svg'):
    """Render the traced graph with graphviz and save it to disk.

    Variables are drawn as ellipses (red for Parameters, light blue for other
    Variables) labelled with their size; creators are drawn as orange
    rectangles labelled with their class name.

    Args:
        name: output file name, without extension.
        format: graphviz output format, e.g. 'svg'.

    Returns:
        The value returned by ``graphviz.Digraph.render`` (the path of the
        rendered file in current graphviz releases -- confirm for the
        installed version).

    Raises:
        AssertionError: if a registered variable is neither a Parameter nor
            a Variable.
    """
    g = graphviz.Digraph(format=format)

    def sizestr(var):
        return str([int(dim) for dim in var.size()])

    # add variable nodes
    for vid, var in vars.items():
        if isinstance(var, nn.Parameter):
            fillcolor = 'red'
        elif isinstance(var, Variable):
            fillcolor = 'lightblue'
        else:
            # Raised explicitly so the check survives ``python -O``.
            raise AssertionError(var.__class__)
        g.node(str(vid), label=sizestr(var), shape='ellipse', style='filled', fillcolor=fillcolor)

    # add creator nodes; iterate funcs (not func_trace) so that creators
    # without any Variable input still get a labelled node
    for cid, creator in funcs.items():
        g.node(str(cid), label=str(creator.__class__.__name__), shape='rectangle', style='filled',
               fillcolor='orange')

    # add edges between creator and inputs
    for cid, creator_inputs in func_trace.items():
        for iid in creator_inputs:
            g.edge(str(iid), str(cid))

    # add edges between outputs and creators
    for oid, creators in var_trace.items():
        for cid in creators:
            g.edge(str(cid), str(oid))

    return g.render(name)
# dandelin commented Sep 15, 2017
#
# This code loses its forward trace when a non-Function operation is called on a Variable (e.g., Variable.view()).
# Any ideas?

