@OniDaito
Created May 14, 2020 16:41
Given a model from PyTorch, print its autograd graph to the console and render it with graphviz to see what is going on
# Our drawing graph functions. We rely / have borrowed from the following
# python libraries:
# https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py
# https://github.com/willmcgugan/rich
# https://graphviz.readthedocs.io/en/stable/
def draw_graph(start, watch=[]):
    from graphviz import Digraph

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    graph = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

    assert hasattr(start, "grad_fn")
    if start.grad_fn is not None:
        _draw_graph(start.grad_fn, graph, watch=watch)

    # Scale the graphviz canvas with the approximate number of nodes and edges.
    size_per_element = 0.15
    min_size = 12
    num_rows = len(graph.body)
    content_size = num_rows * size_per_element
    size = max(min_size, content_size)
    size_str = str(size) + "," + str(size)
    graph.graph_attr.update(size=size_str)
    graph.render(filename='net_graph', format='jpg')


def _draw_graph(var, graph, watch=[], seen=[], indent="", pobj=None):
    ''' recursive function going through the hierarchical graph printing off
    what we need to see what autograd is doing.'''
    from rich import print

    if hasattr(var, "next_functions"):
        for fun in var.next_functions:
            joy = fun[0]
            if joy is not None:
                if joy not in seen:
                    label = str(type(joy)).replace(
                        "class", "").replace("'", "").replace(" ", "")
                    label_graph = label
                    colour_graph = ""
                    seen.append(joy)
                    if hasattr(joy, 'variable'):
                        # AccumulateGrad nodes hold the actual tensor; mark
                        # leaves with a leaf emoji and highlight watched ones.
                        happy = joy.variable
                        if happy.is_leaf:
                            label += " \U0001F343"
                            colour_graph = "green"
                            for (name, obj) in watch:
                                if obj is happy:
                                    label += " \U000023E9 " + \
                                        "[b][u][color=#FF00FF]" + name + \
                                        "[/color][/u][/b]"
                                    label_graph += name
                                    colour_graph = "blue"
                                    break
                            vv = [str(happy.shape[x])
                                  for x in range(len(happy.shape))]
                            label += " [["
                            label += ', '.join(vv)
                            label += "]]"
                            label += " " + str(happy.var())
                    graph.node(str(joy), label_graph, fillcolor=colour_graph)
                    print(indent + label)
                    _draw_graph(joy, graph, watch, seen, indent + ".", joy)
                    if pobj is not None:
                        graph.edge(str(pobj), str(joy))
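
For reference, a minimal usage sketch (not part of the original gist; the model, tensors, and parameter names below are hypothetical). draw_graph expects the tensor you would call .backward() on, plus a watch list of (name, tensor) pairs to highlight:

import torch
import torch.nn as nn

# A hypothetical model and dummy batch, just to produce a loss with a grad_fn.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
x = torch.randn(4, 10)
loss = model(x).sum()

# Highlight the two weight matrices by name in the console and graphviz output.
watching = [("linear1", model[0].weight), ("linear2", model[2].weight)]
draw_graph(loss, watch=watching)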
@seyeeet commented Sep 30, 2021

Would it be possible to do it for all the elements in the model instead of defining the watch list?
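
A hedged sketch of one way to do that, assuming a standard nn.Module (the name model below is hypothetical): named_parameters() already yields the (name, tensor) pairs that draw_graph expects, so every registered parameter can be watched at once.

# Watch all parameters instead of a hand-picked list.
draw_graph(loss, watch=list(model.named_parameters()))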

@cocoaaa commented Jan 8, 2022

Thank you for sharing the code. However, could you check if the posted code is correct?
I'm encountering errors when I run it, e.g. line 67 "obj" is undefined -- probably an indentation mistake.

@lashmore

@OniDaito - How do I read the output of this? I had the SAME issue with using .view(), and I want to prove to myself that I "broke the computation graph", but when I switch between movedim() and view(), I don't see any difference in the printed output of your script:

This is the output when I use .view(), which results in a model that doesn't learn:

<SqueezeBackward1>
.<AddmmBackward>
..<AccumulateGrad> 🍃 [[1, 250]] tensor(nan, grad_fn=<VarBackward0>)
..<ReluBackward0>
...<MulBackward0>
....<AddmmBackward>
.....<AccumulateGrad> 🍃 [[1, 250]] tensor(0.0015, grad_fn=<VarBackward0>)
.....<ViewBackward>
......<SqueezeBackward1>
.......<AdaptiveMaxPool2DBackward>
........<UnsqueezeBackward0>
.........<ReluBackward0>
..........<SqueezeBackward1>
...........<MkldnnConvolutionBackward>
............<UnsqueezeBackward0>
.............<ViewBackward>
..............<MulBackward0>
...............<EmbeddingBackward>
................<AccumulateGrad> 🍃 ⏩ embedding [[5000, 50]] tensor(1.0022, grad_fn=<VarBackward0>)
............<UnsqueezeBackward0>
.............<AccumulateGrad> 🍃 ⏩ conv1 [[250, 50, 3]] tensor(0.0022, grad_fn=<VarBackward0>)
............<AccumulateGrad> 🍃 [[1, 250]] tensor(0.0020, grad_fn=<VarBackward0>)
.....<TBackward>
......<AccumulateGrad> 🍃 ⏩ linear1 [[250, 250]] tensor(0.0013, grad_fn=<VarBackward0>)
..<TBackward>
...<AccumulateGrad> 🍃 ⏩ output [[1, 250]] tensor(0.0012, grad_fn=<VarBackward0>)

Would LOVE your help / to understand this visualization!
