@OniDaito
Created May 14, 2020 16:41
Given a PyTorch model, print the autograd graph to the console and render it with graphviz to see what is going on.
# Our drawing graph functions. We rely / have borrowed from the following
# python libraries:
# https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py
# https://github.com/willmcgugan/rich
# https://graphviz.readthedocs.io/en/stable/
def draw_graph(start, watch=[]):
    from graphviz import Digraph

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    graph = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

    assert hasattr(start, "grad_fn")
    if start.grad_fn is not None:
        # Walk the autograd graph from the starting tensor (typically the loss).
        _draw_graph(start.grad_fn, graph, watch=watch)

    # Get the approximate number of nodes and edges and scale the image to fit.
    size_per_element = 0.15
    min_size = 12
    num_rows = len(graph.body)
    content_size = num_rows * size_per_element
    size = max(min_size, content_size)
    size_str = str(size) + "," + str(size)
    graph.graph_attr.update(size=size_str)
    # render() appends the format extension, so pass the stem and format separately.
    graph.render(filename='net_graph', format='jpg')
def _draw_graph(var, graph, watch=[], seen=None, indent="", pobj=None):
    ''' Recursive function going through the hierarchical graph, printing off
    what we need to see what autograd is doing.'''
    from rich import print

    if seen is None:
        seen = []  # avoid a shared mutable default across separate calls

    if hasattr(var, "next_functions"):
        for fun in var.next_functions:
            joy = fun[0]
            if joy is not None:
                if joy not in seen:
                    label = str(type(joy)).replace(
                        "class", "").replace("'", "").replace(" ", "")
                    label_graph = label
                    colour_graph = ""
                    seen.append(joy)

                    if hasattr(joy, 'variable'):
                        happy = joy.variable
                        if happy.is_leaf:
                            label += " \U0001F343"
                            colour_graph = "green"

                            for (name, obj) in watch:
                                if obj is happy:
                                    label += " \U000023E9 " + \
                                        "[b][u][color=#FF00FF]" + name + \
                                        "[/color][/u][/b]"
                                    label_graph += name
                                    colour_graph = "blue"
                                    break

                            # Shape and variance of the leaf tensor itself.
                            vv = [str(happy.shape[x])
                                  for x in range(len(happy.shape))]
                            label += " [["
                            label += ', '.join(vv)
                            label += "]]"
                            label += " " + str(happy.var())

                    graph.node(str(joy), label_graph, fillcolor=colour_graph)
                    print(indent + label)
                    _draw_graph(joy, graph, watch, seen, indent + ".", joy)

                    if pobj is not None:
                        graph.edge(str(pobj), str(joy))
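
A minimal end-to-end sketch of how draw_graph might be called. The toy model, shapes, and labels here are illustrative assumptions, not part of the original gist:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy two-layer model, purely for illustration.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
x = torch.randn(2, 8)
target = torch.randn(2, 4)

output = model(x)
loss = F.l1_loss(output, target)

# Watch a couple of parameters by name; each entry is a (label, tensor) pair.
watching = [("fc1", model[0].weight), ("fc2", model[2].weight)]
draw_graph(loss, watching)  # prints the tree and writes net_graph.jpg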
@HarshVardhanKumar

How do I use it? Where should I input my model?

@OniDaito (Author) commented Feb 15, 2021

So 'start' is where you want to begin printing from; I use the loss.

Here is a quick example from my model:

output = model.forward(target, loaded_points,
                       offset_stack, stretch_axis)
output = output.reshape(1, 1, 128, 128)
output = output.to(device)
loss = F.l1_loss(output, target)
loss.backward(create_graph=True)

watching = [("Points", loaded_points), ("fc2", model.fc2.weight),
            ("fc1", model.fc1.weight),
            ("conv1", model.conv1.weight),
            ("conv2", model.conv2.weight),
            ("conv3", model.conv3.weight),
            ("conv4", model.conv4.weight),
            ("conv5", model.conv5.weight),
            ("conv6", model.conv6.weight)]

watching.append(("stretch_axis", stretch_axis))
watching.append(("Offsets", offset_stack))
draw_graph(loss, watching)

So I set up a few things I want to watch specifically, but for the most part I use the loss as the starting point.
Cheers,
B

@seyeeet commented Sep 30, 2021

Would it be possible to do it for all the elements in the model instead of defining the watching list?
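
One way to do that (an untested sketch, not from the gist author; it assumes a model and loss as in the example above) is to build the watch list from PyTorch's standard named_parameters() iterator rather than by hand:

# Watch every named parameter of the model instead of listing them manually.
watching = [(name, param) for name, param in model.named_parameters()]
draw_graph(loss, watching)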

@cocoaaa commented Jan 8, 2022

Thank you for sharing the code. However, could you check if the posted code is correct?
I'm encountering errors when I run it, e.g. line 67 "obj" is undefined -- probably an indentation mistake.

@lashmore

@OniDaito - How do I read the output of this? I had the same issue with using .view(), and I want to prove to myself that I "broke the computation graph", but when I switch between movedim() and view(), I don't see a difference in the printed output of your script:

This is the output when I use .view(), which results in a model that doesn't learn:

<SqueezeBackward1>
.<AddmmBackward>
..<AccumulateGrad> 🍃 [[1, 250]] tensor(nan, grad_fn=<VarBackward0>)
..<ReluBackward0>
...<MulBackward0>
....<AddmmBackward>
.....<AccumulateGrad> 🍃 [[1, 250]] tensor(0.0015, grad_fn=<VarBackward0>)
.....<ViewBackward>
......<SqueezeBackward1>
.......<AdaptiveMaxPool2DBackward>
........<UnsqueezeBackward0>
.........<ReluBackward0>
..........<SqueezeBackward1>
...........<MkldnnConvolutionBackward>
............<UnsqueezeBackward0>
.............<ViewBackward>
..............<MulBackward0>
...............<EmbeddingBackward>
................<AccumulateGrad> 🍃 ⏩ embedding [[5000, 50]] tensor(1.0022, grad_fn=<VarBackward0>)
............<UnsqueezeBackward0>
.............<AccumulateGrad> 🍃 ⏩ conv1 [[250, 50, 3]] tensor(0.0022, grad_fn=<VarBackward0>)
............<AccumulateGrad> 🍃 [[1, 250]] tensor(0.0020, grad_fn=<VarBackward0>)
.....<TBackward>
......<AccumulateGrad> 🍃 ⏩ linear1 [[250, 250]] tensor(0.0013, grad_fn=<VarBackward0>)
..<TBackward>
...<AccumulateGrad> 🍃 ⏩ output [[1, 250]] tensor(0.0012, grad_fn=<VarBackward0>)

Would LOVE your help / to understand this visualization!
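
On reading the output: 🍃 marks a leaf tensor and ⏩ marks an entry from the watch list. Note that both view() and movedim() are differentiable and record backward nodes, so neither should break the computation graph; the trees look similar because gradients flow through both, and a graph is usually broken by detach(), .data, or a numpy round-trip instead. A small sketch (illustrative, not from the gist; exact grad_fn names vary by PyTorch version):

import torch

x = torch.randn(2, 3, requires_grad=True)
a = x.view(3, 2)            # records a view backward node
b = x.movedim(0, 1)         # records a permute-style backward node
c = x.detach().view(3, 2)   # detach() severs the graph

print(a.grad_fn)  # e.g. <ViewBackward0 ...>
print(b.grad_fn)  # e.g. <PermuteBackward0 ...>
print(c.grad_fn)  # None -- this is what a broken graph looks like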
