Skip to content

Instantly share code, notes, and snippets.

@heffo42
Created July 19, 2019 21:35
Show Gist options
  • Save heffo42/ebb8698d0d791f5c505252bc685a8b0e to your computer and use it in GitHub Desktop.
Save heffo42/ebb8698d0d791f5c505252bc685a8b0e to your computer and use it in GitHub Desktop.
# Model dimensions read by ModelRNNBasic:
#   nv -- number of token ids (embedding rows and output logits); assumes a
#         `vocab` sequence is defined earlier in the notebook/script -- TODO confirm
#   nh -- width of the recurrent hidden state
nv, nh = len(vocab), 64
class ModelRNNBasic(nn.Module):
    """Single-layer recurrent network over token-id sequences.

    At each time step the embedding of the current token is added to the
    hidden state, which is then passed through a linear layer and a ReLU.
    The batch-normalised hidden state is projected to one logit per token id.

    Args:
        n_vocab: Number of token ids (embedding rows and output logits).
            When omitted, the module-level ``nv`` is read at construction
            time (the original behaviour).
        n_hidden: Width of the hidden state. When omitted, the module-level
            ``nh`` is read at construction time (the original behaviour).
    """

    def __init__(self, n_vocab=None, n_hidden=None):
        super().__init__()
        # Resolve the globals lazily (not as default values) so that a
        # no-argument call sees whatever ``nv``/``nh`` hold at construction.
        if n_vocab is None:
            n_vocab = nv
        if n_hidden is None:
            n_hidden = nh
        self.i_h = nn.Embedding(n_vocab, n_hidden)  # input -> hidden
        self.h_h = nn.Linear(n_hidden, n_hidden)    # hidden -> hidden
        self.h_o = nn.Linear(n_hidden, n_vocab)     # hidden -> output logits
        self.bn = nn.BatchNorm1d(n_hidden)

    def forward(self, x):
        """Run the recurrence over every time step of ``x``.

        Args:
            x: Integer tensor of token ids, indexed as ``x[:, i]``. It is rank
                2, presumably ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, n_vocab)``, one slice per time
            step.

        Note:
            In training mode ``nn.BatchNorm1d`` needs more than one sample
            per batch.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        n_hidden = self.h_h.in_features
        n_vocab = self.h_o.out_features
        # Build the initial state on the input's device (instead of forcing
        # CUDA) and in the embedding's dtype, so the model runs on CPU too.
        h = torch.zeros(batch, n_hidden, device=x.device, dtype=self.i_h.weight.dtype)
        if seq_len == 0:
            # torch.stack rejects an empty list; return empty logits instead.
            return h.new_empty(batch, 0, n_vocab)
        res = []
        for i in range(seq_len):
            h = h + self.i_h(x[:, i])
            h = F.relu(self.h_h(h))
            res.append(self.h_o(self.bn(h)))
        return torch.stack(res, dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment