Skip to content

Instantly share code, notes, and snippets.

@heffo42
Created July 19, 2019 21:59
Show Gist options
  • Save heffo42/d26ed4844cdd4eb8886433f2d5c4e071 to your computer and use it in GitHub Desktop.
Save heffo42/d26ed4844cdd4eb8886433f2d5c4e071 to your computer and use it in GitHub Desktop.
class ModelRNN(nn.Module):
    """Token-level RNN model with batch normalization on the RNN outputs.

    Pipeline: embedding -> single-layer batch-first ``nn.RNN`` ->
    ``nn.BatchNorm1d`` over the hidden features -> linear projection back
    to vocabulary size.

    Args:
        vocab_size: Number of embedding rows and output logits. When
            omitted, the module-level ``nv`` is used.
        hidden_size: Embedding width and RNN hidden width. When omitted,
            the module-level ``nh`` is used.

    Attributes:
        h: Detached final hidden state of the most recent ``forward`` call,
            shaped ``(1, batch, hidden_size)``. Only set after the first call.
    """

    def __init__(self, vocab_size=None, hidden_size=None):
        super().__init__()
        # Fall back to the file's globals so existing ``ModelRNN()`` callers
        # keep working unchanged.
        if vocab_size is None:
            vocab_size = nv
        if hidden_size is None:
            hidden_size = nh
        self.i_h = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.h_o = nn.Linear(hidden_size, vocab_size)
        self.bn = nn.BatchNorm1d(num_features=hidden_size)

    def forward(self, x):
        """Return per-position logits for a batch of token indices.

        Args:
            x: Integer index tensor; ``x.shape[0]`` is used as the batch
                size (the RNN is batch-first). Presumably shaped
                ``(batch, seq_len)`` -- confirm against the caller.

        Returns:
            The output of ``h_o`` applied to the batch-normalized RNN
            outputs; its last dimension is the vocabulary size.
        """
        # The hidden state starts from zeros on every call. Allocate it on
        # the input's device rather than hard-coding CUDA so the model also
        # runs on CPU and on non-default GPUs.
        self.h = torch.zeros(
            1, x.shape[0], self.rnn.hidden_size, device=x.device
        )
        res, h = self.rnn(self.i_h(x), self.h)
        # Keep the final state without its autograd history.
        self.h = h.detach()
        # BatchNorm1d normalizes over dim 1, so move the hidden features
        # there (batch, hidden, seq) and restore the layout afterwards.
        normed = self.bn(res.permute(0, 2, 1)).permute(0, 2, 1)
        return self.h_o(normed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment