autoregressivelstm_model.py
import torch
from torch import nn

# Based on the examples in the book "Inside Deep Learning" by Edward Raff
class AutoRegressiveLSTM(nn.Module):
    def __init__(self, num_embeddings, embd_size, hidden_size, layers=1, device="cpu"):
        super(AutoRegressiveLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embd = nn.Embedding(num_embeddings, embd_size)
        self.device = device
        # The first cell maps embeddings to the hidden size; any further cells are hidden-to-hidden.
        self.layers = nn.ModuleList([nn.LSTMCell(embd_size, hidden_size)] +
                                    [nn.LSTMCell(hidden_size, hidden_size) for _ in range(layers - 1)])
        self.norms_h = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(layers)])
        self.norms_c = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(layers)])
        self.pred_class = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),    # (batch_size, *, h_dims)
            nn.LayerNorm(hidden_size),              # (batch_size, *, h_dims)
            nn.GELU(),
            nn.Linear(hidden_size, num_embeddings)  # (batch_size, *, h_dims) -> (batch_size, *, vocab_size)
        )
    def initHiddenStates(self, B):
        # xavier_normal_ overwrites the zeros, so each layer starts from a Xavier-normal random state.
        return [torch.nn.init.xavier_normal_(torch.zeros(B, self.hidden_size, device=self.device))
                for _ in range(len(self.layers))]

    def initCellStates(self, B):
        return [torch.nn.init.xavier_normal_(torch.zeros(B, self.hidden_size, device=self.device))
                for _ in range(len(self.layers))]
    def step(self, x_in, h_prevs=None, c_prevs=None):
        """
        x_in: the input for the current time step, with shape (B) if the values still need
            to be embedded, and (B, D) if they have already been embedded.
        h_prevs: a list of hidden-state tensors, each with shape (B, self.hidden_size), one per
            layer in the network. These hold the current hidden states of the RNN layers and
            are updated in place by this call.
        c_prevs: a list of cell-state tensors, each with shape (B, self.hidden_size), one per
            layer in the network. These hold the current cell states of the RNN layers and
            are updated in place by this call.
        """
        if len(x_in.shape) == 1:  # (batch_size); the token ids still need to be embedded
            x_in = self.embd(x_in)  # now (batch_size, embd_dims)
        if h_prevs is None:
            h_prevs = self.initHiddenStates(x_in.shape[0])
        if c_prevs is None:
            c_prevs = self.initCellStates(x_in.shape[0])
        # Process the input through each LSTM layer, layer-normalizing the new states
        for l in range(len(self.layers)):
            h_prev = h_prevs[l]
            c_prev = c_prevs[l]
            h, c = self.layers[l](x_in, (h_prev, c_prev))
            h = self.norms_h[l](h)
            c = self.norms_c[l](c)
            h_prevs[l] = h
            c_prevs[l] = c
            x_in = h  # after the loop, x_in holds the last layer's hidden state
        # Make a prediction about the next token
        return self.pred_class(x_in)
    def forward(self, input):
        batch_size = input.size(0)
        time_steps = input.size(1)
        x = self.embd(input)  # (batch_size, time_steps, embd_dims)
        # Initialize the hidden and cell states
        h_prevs = self.initHiddenStates(batch_size)
        c_prevs = self.initCellStates(batch_size)
        last_activations = []
        for t in range(time_steps):
            x_in = x[:, t, :]  # (batch_size, embd_dims)
            last_activations.append(self.step(x_in, h_prevs, c_prevs))
        last_activations = torch.stack(last_activations, dim=1)  # (batch_size, time_steps, vocab_size)
        return last_activations
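
A minimal usage sketch (not part of the gist; the vocabulary size, dimensions, and sequence length below are illustrative assumptions): a forward pass over a batch of token ids yields per-step logits over the vocabulary.

import torch

# Illustrative hyperparameters, chosen for the example only
model = AutoRegressiveLSTM(num_embeddings=1000, embd_size=64, hidden_size=128, layers=2)
tokens = torch.randint(0, 1000, (8, 20))  # (batch_size=8, time_steps=20) of token ids
logits = model(tokens)                    # (8, 20, 1000) logits over the vocabulary
assert logits.shape == (8, 20, 1000)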
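
Because step updates the hidden- and cell-state lists in place and returns logits for the next token, it can also drive token-by-token generation. A greedy-decoding sketch under assumed values (the start-token id 0 and the 16-step length are assumptions, not defined in the gist):

# Greedy generation sketch: feed each predicted token back in as the next input.
model.eval()
with torch.no_grad():
    B = 4
    h_prevs = model.initHiddenStates(B)
    c_prevs = model.initCellStates(B)
    cur = torch.zeros(B, dtype=torch.long)          # assumed start-token id 0
    generated = []
    for _ in range(16):                             # assumed generation length
        logits = model.step(cur, h_prevs, c_prevs)  # (B, vocab_size); states updated in place
        cur = logits.argmax(dim=-1)                 # greedy choice of the next token
        generated.append(cur)
    generated = torch.stack(generated, dim=1)       # (B, 16) generated token ids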