Skip to content

Instantly share code, notes, and snippets.

@heffo42
Last active July 25, 2019 02:40
Show Gist options
Save heffo42/4bcaa79f096a299fc5854e4e85b79ee1 to your computer and use it in GitHub Desktop.
class ModelGRU(nn.Module):
    """Sequence model: embedding -> stacked GRU -> BatchNorm1d -> linear head.

    Every call to ``forward`` starts the GRU from an all-zero hidden state,
    so no recurrent state is carried from one batch to the next. The final
    hidden state of the most recent call is stored, detached, on ``self.h``.

    Args:
        vocab_size: number of embedding rows and of output features per
            position. When omitted, falls back to the module-level global
            ``nv`` (the original hard-coded behaviour).
        hidden_size: embedding width and GRU hidden size. When omitted,
            falls back to the module-level global ``nh``.
        num_layers: number of stacked GRU layers (default 2, as before).
    """

    def __init__(self, vocab_size=None, hidden_size=None, num_layers=2):
        super().__init__()
        # Fall back to the globals the original class hard-coded, so that
        # existing ``ModelGRU()`` calls keep working unchanged.
        if vocab_size is None:
            vocab_size = nv
        if hidden_size is None:
            hidden_size = nh
        # Attribute names and registration order are unchanged so existing
        # state_dicts (i_h.*, rnn.*, h_o.*, bn1.*) still load.
        self.i_h = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)
        self.h_o = nn.Linear(hidden_size, vocab_size)
        self.bn1 = nn.BatchNorm1d(num_features=hidden_size)

    def forward(self, x):
        """Return per-position output scores for a batch of index sequences.

        Args:
            x: integer index tensor laid out batch-first, (batch, seq).

        Returns:
            Tensor of shape (batch, seq, vocab_size), taken directly from the
            linear head (no softmax applied).
        """
        # Zero initial state sized from the GRU's own configuration and
        # created on x's device / in the model's dtype. The old hard-coded
        # ``.cuda()`` failed on CPU-only runs.
        self.h = torch.zeros(
            self.rnn.num_layers,
            x.shape[0],
            self.rnn.hidden_size,
            device=x.device,
            dtype=self.i_h.weight.dtype,
        )
        res, h = self.rnn(self.i_h(x), self.h)
        # Keep the final hidden state without its autograd history.
        self.h = h.detach()
        # BatchNorm1d normalises dim 1, so move the hidden features there,
        # giving (batch, hidden, seq). Normalise, then restore (batch, seq, hidden).
        res = self.bn1(res.permute(0, 2, 1)).permute(0, 2, 1)
        return self.h_o(res)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment