Skip to content

Instantly share code, notes, and snippets.

@lanpa
Created May 4, 2018 09:09
Show Gist options
  • Save lanpa/8e614d157a6123b11ed046a71dea1e60 to your computer and use it in GitHub Desktop.
Save lanpa/8e614d157a6123b11ed046a71dea1e60 to your computer and use it in GitHub Desktop.
import torch.nn as nn
import torch
class SimpleModel(nn.Module):
    """Minimal model: reshape the input into rows and apply one linear layer.

    Args:
        in_features: Row width the input is reshaped to before the linear
            layer. Defaults to 6, the value originally hard-coded here.
        out_features: Output width of the linear layer. Defaults to 9, the
            value originally hard-coded here.
    """

    def __init__(self, in_features: int = 6, out_features: int = 9) -> None:
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.fc`` applied to ``x`` viewed as ``(-1, in_features)``."""
        # Read the width back from the layer so the reshape can never drift
        # out of sync with the layer that consumes it.
        return self.fc(x.view(-1, self.fc.in_features))
# Trace SimpleModel on a zero tensor of shape (1, 2, 3) and print every node
# of the resulting graph together with each of its attributes.
# NOTE(review): torch.jit.get_trace_graph and the private
# torch.onnx._optimize_trace are kept exactly as written; they look
# version-specific (this snippet dates from 2018) -- confirm they exist in the
# PyTorch release you run this against.
example_inputs = (torch.zeros(1, 2, 3),)  # renamed: `input` shadowed the builtin
model = SimpleModel()
# The call returns a pair; only the first element (the trace) is used below.
traced, _ = torch.jit.get_trace_graph(model, example_inputs)
# NOTE(review): meaning of the positional False is not visible here --
# presumably an ATen-fallback switch; verify against the installed version.
torch.onnx._optimize_trace(traced, False)
nodes = list(traced.graph().nodes())
for node in nodes:
    print(node, node.attributeNames())
    # Nested on purpose: the attribute dump is per node.
    for attr in node.attributeNames():
        print(attr, ':', node[attr])
class SimpleModel(nn.Module):
    """Parameter-free model that returns its input multiplied by 2.

    Used as a minimal reproduction: the original author noted that this
    ``x * 2`` variant ends in
    ``RuntimeError: VariableType::ID() not implemented`` when the script is
    run, while a forward that simply returns ``x`` unchanged works ("good").
    NOTE(review): the statement that actually raises is not identified in
    this file -- confirm before relying on this description.
    """

    # No __init__ is defined: the original one only delegated to
    # nn.Module.__init__, which is inherited anyway.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * 2``."""
        # The original had an unreachable ``return x`` after this line as the
        # "good" alternative; it is recorded in the class docstring instead
        # of being left as dead code.
        return x * 2
# Trace SimpleModel on a zero tensor of shape (1, 2, 3) and print every node
# of the resulting graph together with each of its attributes.
# NOTE(review): torch.jit.get_trace_graph and the private
# torch.onnx._optimize_trace are kept exactly as written; they look
# version-specific (this snippet dates from 2018) -- confirm they exist in the
# PyTorch release you run this against.
example_inputs = (torch.zeros(1, 2, 3),)  # renamed: `input` shadowed the builtin
model = SimpleModel()
# The call returns a pair; only the first element (the trace) is used below.
traced, _ = torch.jit.get_trace_graph(model, example_inputs)
# NOTE(review): meaning of the positional False is not visible here --
# presumably an ATen-fallback switch; verify against the installed version.
torch.onnx._optimize_trace(traced, False)
nodes = list(traced.graph().nodes())
for node in nodes:
    print(node, node.attributeNames())
    # Nested on purpose: the attribute dump is per node.
    for attr in node.attributeNames():
        print(attr, ':', node[attr])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment