@qfgaohao
Last active November 24, 2019 14:09
Demonstrates how to trace a PyTorch model and walk the nodes of the resulting graph.
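import torch
from torchvision import models


def parse(net, inputs=torch.randn(1, 3, 224, 224)):
    # Switch the network to eval mode for the duration of the trace.
    # _trace is a private torch.onnx helper; its signature may differ between PyTorch versions.
    with torch.onnx.set_training(net, False):
        trace = torch.onnx.utils._trace(net, inputs)
    graph = trace.graph()
    for n in graph.nodes():
        # Scope (module path) and operator kind of this node.
        print(n.scopeName(), n.kind())
        attrs = str({k: n[k] for k in n.attributeNames()})
        print("attrs", attrs)
        # Names of the values flowing into and out of this node.
        input_names = [i.uniqueName() for i in n.inputs()]
        print("inputs", input_names)
        output_names = [o.uniqueName() for o in n.outputs()]
        print("outputs", output_names)
        print('---------------\n\n')


parse(models.densenet121(pretrained=False))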
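The per-node dump above gets long for a network the size of densenet121, so a compact summary can be easier to read. The following is a minimal sketch, not part of the original gist: `summarize_kinds` and `FakeNode` are hypothetical names, and the scope and kind strings are illustrative values rather than real trace output. The helper assumes only that each node exposes `.kind()`, as the nodes iterated above do.

```python
from collections import Counter


def summarize_kinds(nodes):
    """Count how often each operator kind appears in an iterable of graph nodes.

    Relies only on each node exposing .kind().
    """
    return Counter(n.kind() for n in nodes)


class FakeNode:
    """Hypothetical stand-in for a traced graph node, used only for this demo."""

    def __init__(self, scope, kind):
        self._scope = scope
        self._kind = kind

    def scopeName(self):
        return self._scope

    def kind(self):
        return self._kind


# Illustrative values, not real output from tracing densenet121.
demo_nodes = [
    FakeNode("features/conv0", "aten::_convolution"),
    FakeNode("features/norm0", "aten::batch_norm"),
    FakeNode("features/relu0", "aten::relu"),
    FakeNode("features/denseblock1/conv1", "aten::_convolution"),
]

counts = summarize_kinds(demo_nodes)
for kind, count in counts.most_common():
    print(kind, count)
```

With a real trace, the same helper could presumably be called as `summarize_kinds(graph.nodes())` inside `parse`, in place of or alongside the per-node prints.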