Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Created February 8, 2019 20:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesr66a/c7035bc369f11d1c6f015624c826649a to your computer and use it in GitHub Desktop.
Save jamesr66a/c7035bc369f11d1c6f015624c826649a to your computer and use it in GitHub Desktop.
import torch
class MyDecisionGate(torch.jit.ScriptModule):
    """Gate that passes its input through or negates it, based on its sum.

    The module is written as a ``ScriptModule`` with a ``script_method`` so
    that the ``if``/``else`` on the input's value is compiled as control flow.
    """

    @torch.jit.script_method
    def forward(self, x):
        """Return ``x`` if the sum of its elements is positive, else ``-x``.

        Args:
            x: Input tensor.

        Returns:
            ``x`` unchanged when ``x.sum() > 0``; otherwise ``-x``.
        """
        # bool(...) turns the single-element comparison result into a
        # boolean usable as an ``if`` condition.
        if bool(x.sum() > 0):
            return x
        else:
            return -x
class MyCell(torch.nn.Module):
    """One recurrent step computing ``tanh(gate(x) + h)``.

    The gate is a ``MyDecisionGate`` instance created in ``__init__``.
    """

    def __init__(self):
        super(MyCell, self).__init__()
        self.gate = MyDecisionGate()

    def forward(self, x, h):
        """Compute the next hidden state from input ``x`` and state ``h``.

        Args:
            x: Input for this step; it is passed through ``self.gate``.
            h: Previous hidden state; added to the gated input.

        Returns:
            A 2-tuple ``(new_h, new_h)``: the same tensor is returned both
            as the step output and as the new hidden state.
        """
        new_h = torch.tanh(self.gate(x) + h)
        return new_h, new_h
# Example (x, h) inputs used for tracing: two random 3x4 tensors,
# drawn in the same order as writing the two torch.rand calls out by hand.
examples = tuple(torch.rand(3, 4) for _ in range(2))

# Trace a fresh MyCell with the example inputs, then show both the graph
# and the generated code, one after the other.
traced = torch.jit.trace(MyCell(), examples)
print(traced.graph, traced.code, sep="\n")
class MyRNNLoop(torch.nn.Module):
    """Apply a ``MyCell`` to every slice of a sequence with a Python loop."""

    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = MyCell()

    def forward(self, xs):
        """Run the cell over ``xs[0], xs[1], ...`` along dimension 0.

        Args:
            xs: Input sequence; it is indexed along its first dimension and
                each slice is fed to the cell together with the running
                hidden state.

        Returns:
            A 2-tuple ``(y, h)`` holding the output and hidden state of the
            last step. If ``xs`` has no slices, both are the initial zero
            state.
        """
        # The hidden state starts as a fixed 3x4 tensor of zeros.
        h = torch.zeros(3, 4)
        # Bind y up front so the return is defined even when the loop body
        # never runs (xs.size(0) == 0).
        y = h
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h
# Instantiate the loop module and trace it with a single example input
# of shape (5, 3, 4).
rnn_loop = MyRNNLoop()
# Pass the example inputs as an explicit one-element tuple; the bare
# parentheses used before did not create a tuple.
traced_rnn_loop = torch.jit.trace(rnn_loop, (torch.rand(5, 3, 4),))
print(traced_rnn_loop.code)
class MyRNNLoop(torch.jit.ScriptModule):
    """Scripted loop that applies a traced ``MyCell`` to each slice of ``xs``.

    The cell is obtained by tracing ``MyCell`` with the module-level
    ``examples`` inputs; the loop itself is compiled via ``script_method``.
    """

    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(), examples)

    @torch.jit.script_method
    def forward(self, xs):
        """Run the traced cell over ``xs[0], xs[1], ...`` along dimension 0.

        Args:
            xs: Input sequence; it is indexed along its first dimension.

        Returns:
            A 2-tuple ``(y, h)`` holding the output and hidden state of the
            last step. If ``xs`` has no slices, both are the initial zero
            state.
        """
        # The hidden state starts as a fixed 3x4 tensor of zeros.
        h = torch.zeros(3, 4)
        # y is bound before the loop so it is defined on every path.
        y = h
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h
# Build the scripted loop module, then show its generated code followed
# by its graph.
script_loop = MyRNNLoop()
print(script_loop.code, script_loop.graph, sep="\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment