The first model keeps its hidden state between batches, resetting it only via its reset method; the second improves on this by adding more signal, predicting the next word after every token in the sequence rather than only after every third.
import torch
from torch import nn
import torch.nn.functional as F
from fastai.text.all import Module  # fastai's Module calls super().__init__() for you

class LanguageModelRecurrentState(Module):
    """
    Hidden state is preserved between batches by moving the reset out of
    forward and into __init__ (plus a reset method, called by a callback).
    The state is detached after each forward pass, so backpropagation is
    truncated to the three most recent tokens (truncated BPTT).
    """
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)  # input -> hidden (token embedding)
        self.h_h = nn.Linear(n_hidden, n_hidden)     # hidden -> hidden (recurrent step)
        self.h_o = nn.Linear(n_hidden, vocab_sz)     # hidden -> output (next-word logits)
        self.h = 0                                   # hidden state; broadcasts on first add

    def forward(self, x):
        for i in range(3):
            self.h = self.h + self.i_h(x[:, i])
            self.h = F.relu(self.h_h(self.h))
        out = self.h_o(self.h)    # gradients kept for the output
        self.h = self.h.detach()  # remaining gradient history dropped
        return out

    def reset(self):
        self.h = 0
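Because the state now survives between batches, something has to call reset at the start of training and validation. A minimal training sketch, assuming a language-model DataLoaders `dls` and a `vocab` built as in the fastai notebooks; ModelResetter is fastai's callback that calls model.reset() at those boundaries:

from fastai.text.all import Learner, CrossEntropyLossFlat, accuracy, ModelResetter

learn = Learner(dls, LanguageModelRecurrentState(len(vocab), 64),
                loss_func=CrossEntropyLossFlat(), metrics=accuracy,
                cbs=ModelResetter)
learn.fit_one_cycle(10, 3e-3)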
class LanguageModelRecurrentStateSignal(Module):
    """
    Instead of predicting a single word after every three inputs, this model
    emits a prediction after every token, so the next word is predicted at
    each of the sl steps. This gives the loss far more signal per batch.
    """
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.h_h = nn.Linear(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = 0

    def forward(self, x):
        outs = []
        for i in range(sl):  # sl: sequence length, assumed defined globally
            self.h = self.h + self.i_h(x[:, i])
            self.h = F.relu(self.h_h(self.h))
            outs.append(self.h_o(self.h))  # one prediction per input token
        self.h = self.h.detach()           # truncate BPTT at the batch boundary
        return torch.stack(outs, dim=1)    # shape: bs x sl x vocab_sz

    def reset(self):
        self.h = 0
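The stacked output has shape bs x sl x vocab_sz while the targets are bs x sl (one target per input token), so the loss must flatten both before calling cross-entropy. A sketch under the same assumptions as above (`dls` here must yield per-token targets, and `sl` and `vocab` must be defined):

def loss_func(inp, targ):
    return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))

learn = Learner(dls, LanguageModelRecurrentStateSignal(len(vocab), 64),
                loss_func=loss_func, metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 3e-3)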