Skip to content

Instantly share code, notes, and snippets.

@prateekjoshi565
Created July 28, 2020 07:33
Show Gist options
  • Save prateekjoshi565/25a278448d1df9588b57b87fa00d03cf to your computer and use it in GitHub Desktop.
Save prateekjoshi565/25a278448d1df9588b57b87fa00d03cf to your computer and use it in GitHub Desktop.
class WordLSTM(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> dropout -> linear.

    The network maps a batch of token-index sequences to unnormalized scores
    over the vocabulary for every position of every sequence.

    Args:
        n_hidden: Number of features in the LSTM hidden state.
        n_layers: Number of stacked LSTM layers.
        drop_prob: Dropout probability, used between LSTM layers and on the
            LSTM output before the fully-connected layer.
        lr: Learning rate. Only stored on the instance; the model itself
            does not use it.
        n_vocab: Vocabulary size (number of embedding rows and of output
            scores). When ``None``, the module-level ``vocab_size`` is used,
            which keeps existing ``WordLSTM()`` call sites working.
        emb_dim: Width of the token embedding vectors.
    """

    def __init__(self, n_hidden=256, n_layers=4, drop_prob=0.3, lr=0.001,
                 n_vocab=None, emb_dim=200):
        super().__init__()

        if n_vocab is None:
            # Backward compatibility: earlier versions always read the
            # vocabulary size from a module-level ``vocab_size`` variable.
            n_vocab = vocab_size

        self.drop_prob = drop_prob
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.lr = lr
        self.n_vocab = n_vocab
        self.emb_dim = emb_dim

        self.emb_layer = nn.Embedding(n_vocab, emb_dim)

        # nn.LSTM applies dropout only *between* stacked layers, so a
        # non-zero value with a single layer does nothing except emit a
        # warning; pass 0 in that case.
        self.lstm = nn.LSTM(
            emb_dim,
            n_hidden,
            n_layers,
            dropout=drop_prob if n_layers > 1 else 0.0,
            batch_first=True,
        )

        # Dropout applied to the LSTM output.
        self.dropout = nn.Dropout(drop_prob)

        # Projects every hidden vector to one score per vocabulary entry.
        self.fc = nn.Linear(n_hidden, n_vocab)

    def forward(self, x, hidden):
        """Run a forward pass.

        Args:
            x: Integer token indices, batch first (the LSTM is built with
                ``batch_first=True``), i.e. ``(batch, seq_len)``.
            hidden: Tuple ``(h, c)`` of LSTM hidden and cell states, e.g. as
                returned by :meth:`init_hidden` or by a previous call.

        Returns:
            A tuple ``(out, hidden)`` where ``out`` holds the vocabulary
            scores with all leading dimensions flattened, shape
            ``(batch * seq_len, n_vocab)``, and ``hidden`` is the updated
            ``(h, c)`` state tuple.
        """
        embedded = self.emb_layer(x)
        lstm_output, hidden = self.lstm(embedded, hidden)
        out = self.dropout(lstm_output)

        # Flatten (batch, seq_len, n_hidden) to (batch * seq_len, n_hidden)
        # so the linear layer scores every time step independently.
        out = out.reshape(-1, self.n_hidden)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        """Create zero-initialized hidden and cell states for the LSTM.

        The tensors are allocated with the same device and dtype as the
        model's own parameters, so they are always compatible with the
        weights regardless of whether CUDA happens to be available on the
        machine.

        Args:
            batch_size: Number of sequences in the batch.

        Returns:
            Tuple ``(h, c)`` of zero tensors, each of shape
            ``(n_layers, batch_size, n_hidden)``.
        """
        weight = next(self.parameters())
        shape = (self.n_layers, batch_size, self.n_hidden)
        # new_zeros inherits device and dtype from ``weight`` and returns a
        # tensor that does not require gradients.
        return (weight.new_zeros(shape), weight.new_zeros(shape))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment