Created
March 16, 2021 18:29
-
-
Save jacKlinc/39d4281bc9e4ff7e6882c70d9ba93f55 to your computer and use it in GitHub Desktop.
Multilayer RNN models. The first uses PyTorch's built-in nn.RNN class. The second implements an LSTM cell to mitigate exploding/vanishing gradients.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LanguageModelMulti(Module):
    """
    Multi-layer language model built on PyTorch's ``nn.RNN``.

    Token indices are embedded, passed through a stacked RNN and projected
    to vocabulary-sized scores. The hidden state is kept on the instance and
    carried over from one ``forward`` call to the next.
    """

    def __init__(self, vocab_sz, n_hidden, n_layers, batch_size=None):
        """
        Args:
            vocab_sz: number of embedding rows and of output scores per step.
            n_hidden: embedding width and RNN hidden size.
            n_layers: number of stacked RNN layers.
            batch_size: batch dimension of the stored hidden state. When
                omitted, the module-level ``bs`` is used (the original
                behaviour), so existing three-argument calls are unchanged.
        """
        # Required when `Module` is plain torch.nn.Module.
        # NOTE(review): presumably harmless if `Module` is a wrapper that
        # already initialises itself (e.g. fastai's) -- confirm.
        super().__init__()
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        # batch_first=True: the RNN takes and returns (batch, seq, feature).
        self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        if batch_size is None:
            batch_size = bs
        # One zero-initialised hidden vector per layer and per batch item.
        self.h = torch.zeros(n_layers, batch_size, n_hidden)

    def forward(self, x):
        """
        Return output scores for every position of ``x``.

        ``x`` is assumed to hold integer token indices shaped
        (batch, seq_len), with batch equal to the stored hidden state's
        batch dimension -- TODO confirm against the caller.
        """
        res, h = self.rnn(self.i_h(x), self.h)
        # Keep the values for the next call but cut the autograd history so
        # backpropagation does not reach into earlier calls.
        self.h = h.detach()
        return self.h_o(res)

    def reset(self):
        """Zero the stored hidden state in place."""
        self.h.zero_()
class LanguageModelMultiLSTM(Module):
    """
    One LSTM step built from four linear gate layers.

    The forget, input, cell and output gates each read the previous hidden
    state concatenated with the current input. The gated cell state is the
    LSTM's answer to the exploding/vanishing gradients seen in multilayer
    language models.
    """

    def __init__(self, ni, nh):
        """
        Args:
            ni: number of input features.
            nh: size of the hidden and cell states.
        """
        # Required when `Module` is plain torch.nn.Module.
        # NOTE(review): presumably harmless if `Module` is a wrapper that
        # already initialises itself (e.g. fastai's) -- confirm.
        super().__init__()
        self.forget_g = nn.Linear(ni + nh, nh)
        self.input_g = nn.Linear(ni + nh, nh)
        self.cell_g = nn.Linear(ni + nh, nh)
        self.output_g = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        """
        Advance the cell by one step.

        Args:
            input: tensor shaped (batch, ni).
            state: ``(h, c)`` pair, each tensor shaped (batch, nh).

        Returns:
            ``(h, (h, c))``: the new hidden state, plus the new state pair.
            The tensors passed in ``state`` are left unmodified.
        """
        h, c = state
        # Join along the feature axis so every gate sees (batch, nh + ni).
        # torch.stack would add a new axis and requires ni == nh.
        h = torch.cat([h, input], dim=1)
        forget = torch.sigmoid(self.forget_g(h))  # in (0, 1): how much of c to keep
        # Out-of-place so the caller's cell-state tensor is not overwritten.
        c = c * forget
        inp = torch.sigmoid(self.input_g(h))  # how much of the candidate to admit
        cell = torch.tanh(self.cell_g(h))  # candidate values in (-1, 1)
        c = c + inp * cell
        out = torch.sigmoid(self.output_g(h))  # how much of the cell to expose
        h = out * torch.tanh(c)
        return h, (h, c)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment