Gist by @jamesr66a, created June 28, 2018 22:37

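```python
import torch

# Assumed to be a method of a PyTorch JIT test class whose base class provides
# assertExpected (a golden-file comparison helper in PyTorch's test suite).
def test_call_python_fn_from_script_module(self):
    # Plain Python function called from inside a script method.
    def python_fn(x):
        return torch.neg(x)

    class ScriptMod(torch.jit.ScriptModule):
        def __init__(self):
            super(ScriptMod, self).__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 3))

        @torch.jit.script_method
        def forward(self, x):
            return python_fn(torch.mm(x, self.param))

    sm = ScriptMod()
    # Compare the compiled graph of `forward` against the stored expect file.
    self.assertExpected(str(sm.__getattr__('forward').graph))
```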
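For comparison, here is a minimal standalone sketch of the same pattern (not part of the original gist). It assumes a recent PyTorch in which `torch.jit.script` accepts an `nn.Module` instance and recursively compiles plain Python functions called from `forward`, and in which a scripted module exposes an `inlined_graph` property. The names `Mod` and `eager` are illustrative only, and instead of the golden-file check via `assertExpected` it simply prints the inlined graph and compares numerics with eager mode.

```python
import torch


def python_fn(x):
    # A plain Python function; torch.jit.script compiles it recursively
    # when it is called from scripted code.
    return torch.neg(x)


class Mod(torch.nn.Module):
    # Hypothetical eager-mode counterpart of the gist's ScriptMod.
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(4, 3))

    def forward(self, x):
        return python_fn(torch.mm(x, self.param))


eager = Mod()
sm = torch.jit.script(eager)  # compiles forward and python_fn

x = torch.rand(2, 4)
out = sm(x)

# inlined_graph shows python_fn's body inlined into forward's graph.
print(sm.inlined_graph)
```

The scripted module should produce the same values as the eager module, and the printed graph should contain both the matrix multiply and the negation.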