Skip to content

Instantly share code, notes, and snippets.

@piEsposito
Created April 28, 2020 17:48
Show Gist options
  • Save piEsposito/93014f9159f72e1ccaabc4bcc5b8f37b to your computer and use it in GitHub Desktop.
class Net(nn.Module):
    """Sequence classifier: token embedding -> single-layer LSTM -> linear head.

    Logits for each sequence are computed from the LSTM output at the last
    time step.

    Args:
        vocab_size: Number of rows in the embedding table. When ``None`` (the
            default) it falls back to ``len(encoder.vocab) + 1``, which relies
            on a module-level ``encoder`` defined elsewhere in the project.
            Passing it explicitly removes that hidden global dependency.
        embedding_dim: Size of each token embedding (also the LSTM input size).
        hidden_dim: LSTM hidden size (also the linear head's input size).
        num_classes: Number of output logits per sequence.
    """

    def __init__(self, *, vocab_size=None, embedding_dim: int = 32,
                 hidden_dim: int = 32, num_classes: int = 2):
        super().__init__()
        if vocab_size is None:
            # Original behaviour: size the table from the external encoder.
            # NOTE(review): the +1 presumably reserves one extra index
            # (padding/unknown?) — confirm against the encoder.
            vocab_size = len(encoder.vocab) + 1
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Sizes are derived from the same parameters so the layers stay
        # compatible with each other when the defaults are overridden.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits for a batch of token-index sequences.

        Args:
            x: Integer token indices, assumed ``(batch, seq_len)`` because the
                LSTM is built with ``batch_first=True``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        embedded = self.embedding(x)
        # outputs: (batch, seq_len, hidden_dim); the final (h_n, c_n) states
        # are not needed here.
        outputs, _ = self.lstm(embedded)
        # Use the last time step. NOTE(review): for post-padded batches this
        # may be a padding position rather than each sequence's true last token.
        last_step = outputs[:, -1, :]
        return self.fc1(last_step)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment