Skip to content

Instantly share code, notes, and snippets.

@lantiga
Created September 11, 2017 22:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lantiga/15ba60f6dbdbc99873f0af94761e9630 to your computer and use it in GitHub Desktop.
Save lantiga/15ba60f6dbdbc99873f0af94761e9630 to your computer and use it in GitHub Desktop.
PyTorch namespaces tests
import torch
import torch.nn as nn
from torch.autograd import Variable
from graphviz import Digraph
def name(node, annotation=None):
    """Return a display label for a traced-graph ``node``.

    'PythonOp' and 'CppOp' nodes are labelled ``blockName() + name()``.
    Every other node is labelled ``blockName() + '<ann>:<type>'``, where
    ``<ann>`` is ``annotation`` (or ``node.kind()`` when ``annotation`` is
    falsy) and ``<type>`` is ``str(node.type())``; for a 'Return' node it is
    the comma-joined types of its inputs instead.

    Args:
        node: graph node exposing ``kind()``, ``blockName()``, ``name()``,
            ``type()`` and ``inputs()``.
        annotation: optional label used in place of the node kind; ``None``
            and ``''`` both fall back to ``node.kind()``.

    Returns:
        str: the label. If querying the type raises an ``Exception``, the
        ``:<type>`` suffix is omitted (best effort).
    """
    kind = node.kind()
    if kind in ('PythonOp', 'CppOp'):
        return node.blockName() + node.name()
    ann = annotation or kind
    try:
        if kind == 'Return':
            content = ','.join(str(n.type()) for n in node.inputs())
        else:
            content = str(node.type())
    except Exception:
        # type() may raise for some nodes; fall back to the bare annotation.
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit propagate.
        node_name = ann
    else:
        node_name = '%s:%s' % (ann, content)
    return node.blockName() + node_name
def searchOpsRecFwd(ops, visited, nodes):
    """Collect nodes reachable from ``nodes`` into ``ops``, looking through 'Select' nodes.

    Each node in ``nodes`` not already in ``visited`` is marked visited. A
    'Select' node is not collected; its users are walked recursively instead.
    Every other node kind is added to ``ops``. Both ``ops`` and ``visited``
    are mutated in place; nothing is returned.
    """
    for node in nodes:
        if node in visited:
            continue
        visited.add(node)
        if node.kind() != 'Select':
            ops.add(node)
            continue
        # Look through the Select: continue from whoever consumes it.
        consumers = [use.user for use in node.uses()]
        searchOpsRecFwd(ops, visited, consumers)
def searchDownstreamOps(node):
    """Return the set of nodes that consume ``node``, looking through 'Select' nodes.

    Starts from the direct users of ``node`` and delegates the walk to
    ``searchOpsRecFwd`` with fresh result/visited sets.
    """
    found, already_seen = set(), set()
    direct_users = [use.user for use in node.uses()]
    searchOpsRecFwd(found, already_seen, direct_users)
    return found
def addNode(dot, node, node_name=''):
    """Declare ``node`` in ``dot``, keyed by ``str(id(node))`` and labelled via ``name``.

    ``node_name`` is forwarded as ``name``'s annotation; the default ``''``
    makes the label fall back to the node's kind.
    """
    label = name(node, node_name)
    dot.node(str(id(node)), label)
def addEdge(dot, n1, n2):
    """Add the directed edge ``n1 -> n2`` to ``dot``, using ``str(id(...))`` as node keys."""
    tail, head = (str(id(endpoint)) for endpoint in (n1, n2))
    dot.edge(tail, head)
def make_dot(g, input_names, show_params=False):
    """Build a graphviz ``Digraph`` of the op-level structure of traced graph ``g``.

    Every 'PythonOp'/'CppOp' node becomes a box. For each op:

    * edges go from the op to the nodes downstream of it (found with
      ``searchDownstreamOps``, which looks through 'Select' nodes);
    * edges come into the op from its inputs, where
      - inputs among the last ``len(input_names)`` graph inputs are labelled
        ``'Input$' + <name>``, pairing names with inputs in order;
      - with ``show_params=True`` every other non-'Select' input is drawn;
      - otherwise only 'Constant' inputs are drawn.

    Each distinct (source, destination) edge is emitted once. Without this,
    an edge reached both as a downstream edge of one op and as an input
    edge of another would be drawn twice.

    Args:
        g: traced graph exposing ``nodes()`` and ``inputs()``.
        input_names: display names for the trailing graph inputs.
        show_params: draw all non-'Select' inputs, not only constants.

    Returns:
        Digraph: the (unrendered) diagram.
    """
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    # (src, dst) node pairs already drawn, so each edge is emitted once.
    seen = set()

    def _edge(src, dst):
        # Emit src -> dst unless this exact edge has already been drawn.
        if (src, dst) not in seen:
            seen.add((src, dst))
            addEdge(dot, src, dst)

    ops = [node for node in g.nodes() if node.kind() in ('PythonOp', 'CppOp')]
    for node in ops:
        addNode(dot, node)

    graph_inputs = list(g.inputs())
    input_dict = dict(zip(graph_inputs[-len(input_names):], input_names))

    for node in ops:
        for n in searchDownstreamOps(node):
            addNode(dot, n)
            _edge(node, n)
        for n in node.inputs():
            if n in input_dict:
                label = 'Input$' + input_dict[n]
            elif show_params and n.kind() != 'Select':
                label = ''
            elif not show_params and n.kind() == 'Constant':
                label = ''
            else:
                continue
            addNode(dot, n, label)
            _edge(n, node)
    return dot
def test1():
    """Trace a net of two ``MyModule`` instances with an explicit 'Foo' block, then render it.

    The addition between the two submodules is wrapped in a
    ``push_block('Foo')`` / ``pop_block()`` pair. The traced graph is printed,
    and the diagram (drawn with ``show_params=True``) is opened via
    ``Digraph.view()``.
    """
    class MyModule(nn.Module):
        def forward(self, x):
            t = x + 1
            r = t * 2
            return r

    class MyNet(nn.Module):
        def __init__(self):
            super(MyNet, self).__init__()
            self.module1 = MyModule()
            self.module2 = MyModule()

        def forward(self, x):
            y = self.module1(x)
            # Pop in ``finally`` so an exception inside the block cannot leave
            # the tracing state's block stack unbalanced.
            torch._tracing_state.push_block('Foo')
            try:
                t = y + x
            finally:
                torch._tracing_state.pop_block()
            return self.module2(t)

    def doit(x):
        a = MyNet()
        return a(x)

    t = Variable(torch.ones(1), requires_grad=True)
    traced, _ = torch.jit.record_trace(doit, t)
    g = torch._C._jit_get_graph(traced)
    print(g)
    d = make_dot(g, ['t'], show_params=True)
    d.view()
def test2():
    """Trace a module wrapping ``nn.Sequential(Linear(2, 2), ReLU())`` and render it.

    Prints the traced graph and opens the diagram, drawn with
    ``show_params=True``, via ``Digraph.view()``.
    """
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.layer1 = nn.Sequential(nn.Linear(2, 2), nn.ReLU())

        def forward(self, x):
            return self.layer1(x)

    model = Net()
    sample = Variable(torch.ones(2), requires_grad=True)
    trace, _ = torch.jit.record_trace(model, sample)
    graph = torch._C._jit_get_graph(trace)
    print(graph)
    make_dot(graph, ['t'], show_params=True).view()
def test3():
    """Trace a freshly constructed torchvision ResNet-18 on a random 1x3x224x224 input and render it.

    Prints the traced graph and opens the diagram via ``Digraph.view()``;
    only the graph input is labelled (``show_params`` left at its default).
    """
    from torchvision import models

    image = Variable(torch.randn(1, 3, 224, 224))
    model = models.resnet18()
    trace, _ = torch.jit.record_trace(model, image)
    graph = torch._C._jit_get_graph(trace)
    print(graph)
    make_dot(graph, ['inputs']).view()
def test4():
    """Trace ``sigmoid(tanh(x * (y + z)))`` inside an explicit 'Foo' block, then render it.

    ``x`` and ``y`` are passed to ``record_trace``. ``z`` is created inside
    the traced function, so it is not one of the traced arguments. The
    traced graph is printed and the diagram is opened via ``Digraph.view()``.
    """
    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    def doit(x, y):
        # Pop in ``finally`` so an exception inside the block cannot leave
        # the tracing state's block stack unbalanced.
        torch._tracing_state.push_block('Foo')
        try:
            z = Variable(torch.Tensor([0.7]), requires_grad=True)
            out = torch.sigmoid(torch.tanh(x * (y + z)))
        finally:
            torch._tracing_state.pop_block()
        return out

    traced, _ = torch.jit.record_trace(doit, x, y)
    g = torch._C._jit_get_graph(traced)
    print(g)
    d = make_dot(g, ['x', 'y'])
    d.view()
# Run every visual test in order (each prints its graph and opens a viewer).
for _visual_test in (test1, test2, test3, test4):
    _visual_test()
@lantiga
Copy link
Author

lantiga commented Sep 11, 2017

Test1

test1

Test2

test2

Test3

test3

Test4

test4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment