Skip to content

Instantly share code, notes, and snippets.

Created May 14, 2020 16:41
Show Gist options
  • Save OniDaito/f5bfb83b6677835219feeb95ddb1e4ad to your computer and use it in GitHub Desktop.
Save OniDaito/f5bfb83b6677835219feeb95ddb1e4ad to your computer and use it in GitHub Desktop.
Given a PyTorch tensor (e.g. a model's loss), print its autograd graph to the console and draw it with Graphviz to see what is going on.
# Our graph-drawing functions. We rely on (and have borrowed from) the
# following python libraries: graphviz (drawing) and rich (console output).
def draw_graph(start, watch=None, filename=None):
    """Print and draw the autograd graph that produced ``start``.

    Walks backwards from ``start.grad_fn`` (via ``_draw_graph``), printing a
    one-line summary per autograd node to the console and adding a node and
    edge for each one to a graphviz ``Digraph``.

    Args:
        start: An object with a ``grad_fn`` attribute, typically a loss tensor.
        watch: Optional list of ``(name, tensor)`` pairs; leaf tensors found
            in this list are highlighted and labelled with ``name``.
        filename: Optional output path. When given, the graph is rendered
            with ``Digraph.render(filename=...)``.

    Returns:
        The populated ``graphviz.Digraph``, resized to fit its content.

    Raises:
        AssertionError: if ``start`` has no ``grad_fn`` attribute.
    """
    from graphviz import Digraph

    if watch is None:
        watch = []

    # Node styling as used by pytorchviz's make_dot.
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    graph = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

    assert hasattr(start, "grad_fn")
    if start.grad_fn is not None:
        # Pass a fresh ``seen`` list each call so that state from a previous
        # call cannot suppress nodes in this one.
        _draw_graph(start.grad_fn, graph, watch=watch, seen=[])

    size_per_element = 0.15
    min_size = 12

    # Scale the drawing with the number of DOT statements, which roughly
    # counts nodes plus edges.
    num_rows = len(graph.body)
    content_size = num_rows * size_per_element
    size = max(min_size, content_size)
    size_str = str(size) + "," + str(size)
    graph.graph_attr.update(size=size_str)

    if filename is not None:
        graph.render(filename=filename)
    return graph
def _draw_graph(var, graph, watch=None, seen=None, indent="", pobj=None):
    """Recursively walk an autograd graph, printing and drawing each node.

    For every non-None function in ``var.next_functions`` this adds a node to
    ``graph``, prints a one-line summary prefixed by ``indent``, and recurses
    into it with one more "." of indent. A node is visited only once, but an
    edge from ``pobj`` is added for every parent/child pair, so shared
    sub-graphs keep all their incoming edges.

    Leaf variables (nodes with a ``variable`` attribute whose tensor
    ``is_leaf``) get a leaf emoji and a green fill, plus the tensor's shape
    and variance. If the leaf also appears in ``watch``, it gets its name and
    a blue fill.

    Args:
        var: An autograd node, e.g. ``tensor.grad_fn``. Objects without a
            ``next_functions`` attribute are treated as having no children.
        graph: A ``graphviz.Digraph`` to which nodes and edges are added.
        watch: Optional list of ``(name, tensor)`` pairs to highlight.
        seen: List of already-visited autograd nodes, shared by all
            recursive calls. A new list is created when None.
        indent: Prefix printed before each label.
        pobj: Parent node used as the source of edges. When None (the
            top-level call), no edges are drawn. Recursive calls pass the
            current node.
    """
    from rich import print as rprint
    from rich.markup import escape

    if watch is None:
        watch = []
    if seen is None:
        seen = []

    # Recursion depth equals graph depth; very deep graphs may hit
    # Python's recursion limit.
    for fun in getattr(var, "next_functions", ()):
        joy = fun[0]
        if joy is None:
            continue

        if joy not in seen:
            seen.append(joy)

            # "<class 'AccumulateGrad'>" -> "<AccumulateGrad>"
            label = str(type(joy)).replace(
                "class", "").replace("'", "").replace(" ", "")
            label_graph = label
            colour_graph = ""

            if hasattr(joy, 'variable'):
                happy = joy.variable
                if happy.is_leaf:
                    label += " \U0001F343"
                    colour_graph = "green"

                    for name, obj in watch:
                        if obj is happy:
                            # Rich markup: a bare hex colour is a valid
                            # style; "color=#..." is not. escape() stops
                            # brackets in the name being read as markup.
                            label += (" \U000023E9 [bold underline #FF00FF]"
                                      + escape(name) + "[/]")
                            label_graph += name
                            colour_graph = "blue"

                    # Use this leaf's own shape. Previously the leftover
                    # loop variable ``obj`` was used, which reported the
                    # last watched tensor's shape (or raised NameError when
                    # ``watch`` was empty).
                    dims = ", ".join(str(d) for d in happy.shape)
                    label += " [[" + dims + "]]"
                    # detach() keeps this summary out of the autograd graph.
                    # Tensor.var() is unbiased by default, so a
                    # single-element tensor prints nan.
                    label += " " + str(happy.detach().var())

            graph.node(str(joy), label_graph, fillcolor=colour_graph)
            rprint(indent + label)
            _draw_graph(joy, graph, watch, seen, indent + ".", joy)

        if pobj is not None:
            graph.edge(str(pobj), str(joy))
Copy link

@OniDaito - How do I read the output of this? I had the SAME issue with using .view(), but I want to prove to myself that I "broke the computation graph", but when I switch between using movedim() vs. view(), I don't see the difference in the printed output of your scripts:

This is the output when I use .view(), which results in a model that doesn't learn:

..<AccumulateGrad> 🍃 [[1, 250]] tensor(nan, grad_fn=<VarBackward0>)
.....<AccumulateGrad> 🍃 [[1, 250]] tensor(0.0015, grad_fn=<VarBackward0>)
................<AccumulateGrad> 🍃 ⏩ embedding [[5000, 50]] tensor(1.0022, grad_fn=<VarBackward0>)
.............<AccumulateGrad> 🍃 ⏩ conv1 [[250, 50, 3]] tensor(0.0022, grad_fn=<VarBackward0>)
............<AccumulateGrad> 🍃 [[1, 250]] tensor(0.0020, grad_fn=<VarBackward0>)
......<AccumulateGrad> 🍃 ⏩ linear1 [[250, 250]] tensor(0.0013, grad_fn=<VarBackward0>)
...<AccumulateGrad> 🍃 ⏩ output [[1, 250]] tensor(0.0012, grad_fn=<VarBackward0>)

Would LOVE your help / to understand this visualization!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment