Skip to content

Instantly share code, notes, and snippets.

@wangg12
Last active December 18, 2022 20:54
Show Gist options
  • Star 24 You must be signed in to star a gist
  • Fork 5 You must be signed in to fork a gist
  • Save wangg12/f11258583ffcc4728eb71adc0f38e832 to your computer and use it in GitHub Desktop.
from graphviz import Digraph
from torch.autograd import Variable
import torch
def make_dot(var, params=None):
    """Build a Graphviz ``Digraph`` of the autograd graph that produced ``var``.

    Starting from ``var``, the graph is walked backwards through each node's
    ``next_functions`` and, for nodes that expose it, ``saved_tensors``.
    Every object reached becomes exactly one graph node, and each edge points
    from an input towards the node that consumed it:

    * tensors (``var`` itself and any saved tensors) are drawn orange and
      labelled with their size;
    * nodes with a ``.variable`` attribute are drawn light blue and labelled
      with the matching name from ``params`` plus that variable's size;
    * all other nodes are labelled with their Python type name.

    Args:
        var: Tensor (or autograd node) whose backward graph should be drawn.
        params: Optional mapping of name -> tensor, e.g.
            ``dict(model.named_parameters())``, used to label leaf nodes.
            Leaves that are not found in the mapping get an empty name.

    Returns:
        graphviz.Digraph: the assembled graph (not rendered; call ``.view()``
        or ``.render()`` on it).

    Raises:
        TypeError: if any value in ``params`` is not a tensor.
    """
    param_map = {}
    if params is not None:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        for name, value in params.items():
            if not isinstance(value, Variable):
                raise TypeError(
                    "params values must be tensors, got %s for %r"
                    % (type(value).__name__, name)
                )
        # Leaves are matched to parameter names by object identity.
        param_map = {id(v): k for k, v in params.items()}

    # NOTE(review): Graphviz defines ``ranksep`` as a graph attribute; set on
    # nodes it is presumably ignored. Kept unchanged to preserve the layout.
    node_attr = dict(style="filled", shape="box", align="left", fontsize="12", ranksep="0.1", height="0.2")
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

    def size_to_str(size):
        """Format a size sequence as ``"(d0, d1, ...)"``."""
        return "(" + ", ".join("%d" % v for v in size) + ")"

    # id(obj) -> obj. Keying by id means membership tests never invoke
    # Tensor.__eq__. Holding the object keeps it alive, so its id (also the
    # Graphviz node name) cannot be reused by another object mid-walk.
    seen = {}
    # Explicit stack instead of recursion: deep graphs (e.g. long unrolled
    # loops) would otherwise exceed Python's recursion limit.
    stack = [var]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen[id(node)] = node
        node_id = str(id(node))

        if torch.is_tensor(node):
            dot.node(node_id, size_to_str(node.size()), fillcolor="orange")
            # Leaf tensors have no grad_fn; skip the edge rather than linking
            # the tensor to a node named after id(None).
            if node.grad_fn is not None:
                dot.edge(str(id(node.grad_fn)), node_id)
                stack.append(node.grad_fn)
            continue

        if hasattr(node, "variable"):
            leaf = node.variable
            label = "%s\n %s" % (param_map.get(id(leaf), ""), size_to_str(leaf.size()))
            dot.node(node_id, label, fillcolor="lightblue")
        else:
            dot.node(node_id, str(type(node).__name__))

        # next_functions is a sequence of (function, index) pairs; the
        # function slot is None for inputs that do not require grad.
        for next_fn, _ in getattr(node, "next_functions", ()):
            if next_fn is not None:
                dot.edge(str(id(next_fn)), node_id)
                stack.append(next_fn)
        for saved in getattr(node, "saved_tensors", ()):
            dot.edge(str(id(saved)), node_id)
            stack.append(saved)

    return dot
if __name__ == "__main__":
    # Demo: run one dummy batch through an untrained ResNet-18 and open the
    # resulting backward graph in the system's default viewer.
    import torchvision.models as models

    dummy_batch = torch.randn(1, 3, 224, 224)  # (N, C, H, W) = (1, 3, 224, 224)
    model = models.resnet18()
    output = model(dummy_batch)
    graph = make_dot(output)
    graph.view()
@wangg12
Copy link
Author

wangg12 commented Jan 20, 2021

Thanks. I updated the script according to yours.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment