The first model keeps its hidden state between batches (clearing it only through reset()), while the second improves on this by lengthening the sequence and predicting the next word after every token rather than only after the third, which gives each batch more signal to learn from.
import torch
from torch import nn
import torch.nn.functional as F
from fastai.text.all import Module  # fastai's Module: an nn.Module that skips the explicit super().__init__() call

class LanguageModelRecurrentState(Module):
    """
    The hidden state is kept between batches: it is initialised once in
    __init__ and only cleared by reset(), rather than at every forward pass.
    The state is detached after each batch, so gradients only flow through
    the three steps of the current batch.
    """
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)  # input -> hidden
        self.h_h = nn.Linear(n_hidden, n_hidden)     # hidden -> hidden
        self.h_o = nn.Linear(n_hidden, vocab_sz)     # hidden -> output
        self.h = 0                                   # persistent hidden state

    def forward(self, x):
        for i in range(3):
            self.h = self.h + self.i_h(x[:, i])
            self.h = F.relu(self.h_h(self.h))
        out = self.h_o(self.h)    # gradients kept for the output
        self.h = self.h.detach()  # drop the remaining gradient history
        return out

    def reset(self):
        self.h = 0
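
A minimal training sketch for the stateful model, assuming a fastai v2 setup: dls is a DataLoaders of (three-token input, next-token target) pairs built from an ordered corpus, vocab is its vocabulary, and ModelResetter is fastai's callback that calls reset() at the start of training and validation. The hidden size and schedule are illustrative only.

from fastai.text.all import *

model = LanguageModelRecurrentState(len(vocab), 64)
learn = Learner(dls, model, loss_func=F.cross_entropy,
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(10, 3e-3)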
class LanguageModelRecurrentStateSignal(Module):
    """
    Instead of predicting only the word that follows each three-token chunk,
    this model emits a prediction after every token in the sequence, so each
    batch provides more signal for the loss.
    """
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.h_h = nn.Linear(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = 0

    def forward(self, x):
        outs = []
        for i in range(sl):  # sl: sequence length, defined globally alongside the DataLoaders
            self.h = self.h + self.i_h(x[:, i])
            self.h = F.relu(self.h_h(self.h))
            outs.append(self.h_o(self.h))  # one prediction per position
        self.h = self.h.detach()           # drop gradient history between batches
        return torch.stack(outs, dim=1)    # shape: bs x sl x vocab_sz

    def reset(self):
        self.h = 0
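
Because this model returns one prediction per position (bs x sl x vocab_sz), the loss must flatten predictions and targets before applying cross-entropy. A sketch under the same assumptions as above, where dls now yields length-sl sequences with targets offset by one token and vocab, the hidden size, and the schedule are illustrative:

def seq_loss(inp, targ):
    # flatten bs x sl x vocab_sz predictions and bs x sl targets
    return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))

learn = Learner(dls, LanguageModelRecurrentStateSignal(len(vocab), 64),
                loss_func=seq_loss, metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 3e-3)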