@jamesr66a
Last active December 11, 2018 23:15
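import torch


class SomeMod(torch.jit.ScriptModule):
    # A ScriptModule with an extra script method besides forward().
    @torch.jit.script_method
    def _unpack(self):
        return torch.zeros(3, 4)

    @torch.jit.script_method
    def forward(self):
        return torch.zeros(3, 4)


class Mod(torch.nn.Module):
    # Plain nn.Module that owns a SomeMod instance as a submodule.
    def __init__(self):
        super().__init__()
        self.sm = SomeMod()

    def forward(self, x):
        return x + self.sm()


# Trace the outer module with one example input of shape (3, 4), then check
# whether the traced copy of the submodule still has the _unpack method.
traced = torch.jit.trace(Mod(), (torch.rand(3, 4),))
print(traced)
print(traced.sm)
print(traced.sm._has_method('_unpack'))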
===
TracedModule[Mod](
(sm): TracedModule[SomeMod]()
)
TracedModule[SomeMod]()
False
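The output shows that after tracing, `traced.sm` is a `TracedModule[SomeMod]` and `_has_method('_unpack')` returns `False`, so the extra script method was not carried over to the traced copy. One plausible reading is that tracing captures only what actually executes during the example `forward` call, and `_unpack` is never called there. The sketch below is a hypothetical, pure-Python toy of that idea: every name in it (`ToyModule`, `ToyTraced`, `toy_trace`) is made up for illustration, and it neither uses nor reproduces PyTorch's tracer.

```python
from typing import Dict, List, Optional, Set, Tuple

# Hypothetical toy model of trace-based capture. This is NOT how
# torch.jit.trace is implemented; it only mimics the idea that a tracer
# records what actually runs during one forward() call.

Call = Tuple[str, str]  # (module name, method name)


class ToyModule:
    """Stand-in for an nn.Module: a name plus named submodules."""

    def __init__(self, name: str, children: Optional[Dict[str, "ToyModule"]] = None):
        self.name = name
        self.children = children or {}


class SomeMod(ToyModule):
    def _unpack(self) -> List[float]:
        # Never called from forward(), so the toy tracer never sees it.
        return [0.0] * 4

    def forward(self, calls: List[Call]) -> List[float]:
        calls.append((self.name, "forward"))
        return [0.0] * 4


class Mod(ToyModule):
    def __init__(self) -> None:
        super().__init__("Mod", {"sm": SomeMod("sm")})

    def forward(self, calls: List[Call], x: List[float]) -> List[float]:
        calls.append((self.name, "forward"))
        y = self.children["sm"].forward(calls)
        return [a + b for a, b in zip(x, y)]


class ToyTraced:
    """Traced copy of a module: keeps only the methods observed while tracing."""

    def __init__(self, name: str, methods: Set[str], children: Dict[str, "ToyTraced"]):
        self.name = name
        self.methods = methods
        self.children = children

    def has_method(self, method_name: str) -> bool:
        return method_name in self.methods


def toy_trace(root: Mod, example_input: List[float]) -> ToyTraced:
    calls: List[Call] = []
    root.forward(calls, example_input)  # run once on the example input

    def build(module: ToyModule) -> ToyTraced:
        # A method survives only if the example run actually invoked it.
        # (Matching on name assumes module names are unique in this toy.)
        methods = {m for owner, m in calls if owner == module.name}
        kids = {k: build(c) for k, c in module.children.items()}
        return ToyTraced(module.name, methods, kids)

    return build(root)


traced = toy_trace(Mod(), [1.0, 2.0, 3.0, 4.0])
print(traced.children["sm"].has_method("forward"))  # True
print(traced.children["sm"].has_method("_unpack"))  # False
```

In this toy, the original `SomeMod` object still has `_unpack`; only the traced copy lacks it, because nothing in the example run invoked it.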