@jacKlinc
Created March 16, 2021 18:29
Multilayer RNN language models. The first uses the built-in nn.RNN class; the second implements an LSTM cell from scratch to address exploding/vanishing gradients.
import torch
from torch import nn
from fastai.text.all import *  # provides Module (an nn.Module that skips the explicit super().__init__() call)


class LanguageModelMulti(Module):
    """
    Deepens the model with the built-in RNN class for more accuracy.
    """
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        # Stacks n_layers RNNs; batch_first=True expects input of shape (bs, seq_len, n_hidden)
        self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # Initial hidden state: zeros for every layer (`bs` is the global batch size)
        self.h = torch.zeros(n_layers, bs, n_hidden)

    def forward(self, x):
        res, h = self.rnn(self.i_h(x), self.h)
        self.h = h.detach()  # truncate backpropagation through time
        return self.h_o(res)

    def reset(self):
        self.h.zero_()  # zeroes the hidden state between epochs/sequences
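
A minimal training sketch (not part of the gist): it assumes a fastai DataLoaders `dls` whose batch size equals the global `bs` used above, and a token vocabulary `vocab`, as in the fastbook human-numbers example.

# Hypothetical usage; `dls`, `vocab`, and `bs` are assumed to be defined elsewhere.
learn = Learner(dls, LanguageModelMulti(len(vocab), 64, n_layers=2),
                loss_func=CrossEntropyLossFlat(), metrics=accuracy,
                cbs=ModelResetter)  # ModelResetter calls reset() around training/validation
learn.fit_one_cycle(10, 3e-3)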
class LanguageModelMultiLSTM(Module):
    """
    Implements an LSTM cell from scratch.
    Mitigates the exploding/vanishing gradients seen in the multilayer RNN model above.
    """
    def __init__(self, ni, nh):
        self.forget_g = nn.Linear(ni + nh, nh)
        self.input_g = nn.Linear(ni + nh, nh)
        self.cell_g = nn.Linear(ni + nh, nh)
        self.output_g = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h, c = state
        h = torch.cat([h, input], dim=1)           # concatenate hidden state and input along the feature dimension
        forget = torch.sigmoid(self.forget_g(h))   # forget gate: sigmoid of the combined hidden state/input
        c = c * forget                             # scale the cell state by the forget gate (no in-place op, to keep autograd happy)
        inp = torch.sigmoid(self.input_g(h))       # input gate
        cell = torch.tanh(self.cell_g(h))          # candidate cell state, scaled to (-1, 1)
        c = c + inp * cell                         # add the gated candidate to the cell state
        out = torch.sigmoid(self.output_g(h))      # output gate
        h = out * torch.tanh(c)                    # new hidden state
        return h, (h, c)
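
A minimal sketch (not from the gist) of stepping the cell over a batch of embedded tokens; the shapes and names below are assumptions for illustration.

# Hypothetical usage of the LSTM cell over a sequence.
ni, nh, bs, seq_len = 64, 64, 16, 10
cell = LanguageModelMultiLSTM(ni, nh)
x = torch.randn(bs, seq_len, ni)     # a batch of already-embedded tokens
h = torch.zeros(bs, nh)              # initial hidden state
c = torch.zeros(bs, nh)              # initial cell state
outs = []
for t in range(seq_len):
    out, (h, c) = cell(x[:, t], (h, c))
    outs.append(out)
res = torch.stack(outs, dim=1)       # (bs, seq_len, nh), one hidden state per time step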