@jamesr66a
Created November 30, 2018 18:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesr66a/b0e12ec2f3ba1dc77546323946928d4d to your computer and use it in GitHub Desktop.
Save jamesr66a/b0e12ec2f3ba1dc77546323946928d4d to your computer and use it in GitHub Desktop.
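
import torch


class FooMod(torch.nn.Module):
    def forward(self, x, tup):
        # tup is a tuple of tensors; only its first element is used
        return x + tup[0]


# Example inputs: a (3, 4) tensor and a nested tuple of a (3, 4) and a (4,) tensor
traced = torch.jit.trace(FooMod(), (torch.rand(3, 4), (torch.rand(3, 4), torch.rand(4))))
print(traced.graph)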
===
graph(%x : Float(3, 4)
      %1 : Tuple) {
  %2 : Float(3, 4), %3 : Float(4) = prim::TupleUnpack(%1)
  %4 : int = prim::Constant[value=1](), scope: FooMod
  %5 : Float(3, 4) = aten::add(%x, %2, %4), scope: FooMod
  return (%5);
}
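Rather than reading the printed graph by eye, you can also inspect it programmatically. The sketch below assumes a recent PyTorch release. It walks `traced.graph.nodes()`, collects each node's `kind()` string, and then checks that the traced module still computes `x + tup[0]` on fresh inputs with the same nested-tuple structure. The helper name `trace_and_inspect` is hypothetical. The exact set of node kinds and the printed graph format differ between PyTorch versions, so a current release may not print exactly the graph shown above.

```python
import torch


class FooMod(torch.nn.Module):
    def forward(self, x, tup):
        # Only the first tuple element is used; the second is ignored.
        return x + tup[0]


def trace_and_inspect():
    """Hypothetical helper: trace FooMod and list the node kinds in its graph."""
    example = (torch.rand(3, 4), (torch.rand(3, 4), torch.rand(4)))
    traced = torch.jit.trace(FooMod(), example)
    # Each graph node reports its operator name, e.g. "aten::add".
    kinds = [node.kind() for node in traced.graph.nodes()]
    return traced, kinds


if __name__ == "__main__":
    traced, kinds = trace_and_inspect()
    print(kinds)
    # Re-run the traced module on new tensors with the same shapes and nesting.
    x = torch.rand(3, 4)
    tup = (torch.rand(3, 4), torch.rand(4))
    print(torch.allclose(traced(x, tup), x + tup[0]))
```

Checking node kinds in this way is less brittle than comparing printed graph text, because value names such as `%2` and the printed type annotations can change between releases.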