autoregressivelstm_model.py
import torch
from torch import nn


# Based on the examples in the book "Inside Deep Learning" by Edward Raff
class AutoRegressiveLSTM(nn.Module):
    def __init__(self, num_embeddings, embd_size, hidden_size, layers=1, device="cpu"):
        super(AutoRegressiveLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embd = nn.Embedding(num_embeddings, embd_size)
        self.device = device
        self.layers = nn.ModuleList(
            [nn.LSTMCell(embd_size, hidden_size)] +
            [nn.LSTMCell(hidden_size, hidden_size) for i in range(layers - 1)])
        self.norms_h = nn.ModuleList([nn.LayerNorm(hidden_size) for i in range(layers)])
        self.norms_c = nn.ModuleList([nn.LayerNorm(hidden_size) for i in range(layers)])
        self.pred_class = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),    # (batch_size, *, h_dims)
            nn.LayerNorm(hidden_size),              # (batch_size, *, h_dims)
            nn.GELU(),
            nn.Linear(hidden_size, num_embeddings)  # (batch_size, *, h_dims) -> (batch_size, *, vocab_size)
        )

    def initHiddenStates(self, B):
        return [torch.nn.init.xavier_normal_(torch.zeros(B, self.hidden_size, device=self.device))
                for _ in range(len(self.layers))]

    def initCellStates(self, B):
        return [torch.nn.init.xavier_normal_(torch.zeros(B, self.hidden_size, device=self.device))
                for _ in range(len(self.layers))]
    def step(self, x_in, h_prevs=None, c_prevs=None):
        """
        x_in: the input for the current time step; has shape (B,) if the values still need
            to be embedded, and (B, D) if they have already been embedded.
        h_prevs: a list of hidden-state tensors, each with shape (B, self.hidden_size), one
            per layer in the network. These hold the current hidden states of the RNN layers
            and are updated in place by this call.
        c_prevs: a list of cell-state tensors, each with shape (B, self.hidden_size), one
            per layer in the network. These hold the current cell states of the RNN layers
            and are updated in place by this call.
        """
        if len(x_in.shape) == 1:    # (batch_size), we need to embed it
            x_in = self.embd(x_in)  # now (batch_size, embd_dims)
        if h_prevs is None:
            h_prevs = self.initHiddenStates(x_in.shape[0])
        if c_prevs is None:
            c_prevs = self.initCellStates(x_in.shape[0])
        # Process the input through each stacked LSTMCell
        for l in range(len(self.layers)):
            h_prev = h_prevs[l]
            c_prev = c_prevs[l]
            h, c = self.layers[l](x_in, (h_prev, c_prev))
            h = self.norms_h[l](h)
            c = self.norms_c[l](c)
            h_prevs[l] = h
            c_prevs[l] = c
            x_in = h  # after the loop, this holds the last layer's hidden state
        # Make predictions about the next token
        return self.pred_class(x_in)
    def forward(self, input):
        batch_size = input.size(0)
        time_steps = input.size(1)
        x = self.embd(input)
        # Initialize hidden and cell states
        h_prevs = self.initHiddenStates(batch_size)
        c_prevs = self.initCellStates(batch_size)
        last_activations = []
        for t in range(time_steps):
            x_in = x[:, t, :]  # (batch_size, embd_dims)
            last_activations.append(self.step(x_in, h_prevs, c_prevs))
        last_activations = torch.stack(last_activations, dim=1)  # (batch_size, time_steps, vocab_size)
        return last_activations
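
# --- A minimal usage sketch, not part of the original gist. It shows the
# teacher-forced forward pass and a shifted next-token cross-entropy loss;
# the vocabulary size, batch shape, and optimizer settings are assumptions
# chosen only for illustration.

vocab_size, batch_size, seq_len = 100, 8, 21  # hypothetical sizes

model = AutoRegressiveLSTM(num_embeddings=vocab_size, embd_size=32, hidden_size=64, layers=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

tokens = torch.randint(0, vocab_size, (batch_size, seq_len))  # dummy token ids

# Predict token t+1 from tokens up to t: feed all but the last token,
# score against all but the first.
logits = model(tokens[:, :-1])  # (batch_size, seq_len - 1, vocab_size)
loss = loss_fn(logits.reshape(-1, vocab_size), tokens[:, 1:].reshape(-1))
loss.backward()
optimizer.step()
optimizer.zero_grad()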
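
# --- A greedy-decoding sketch, also not from the gist, showing how step()
# threads the hidden/cell state lists (which it mutates in place) across
# time steps. The start-token id and generation length are made up for
# illustration.

start_token, gen_len = 0, 10  # hypothetical start id and length
model.eval()
with torch.no_grad():
    x_in = torch.full((1,), start_token, dtype=torch.long)  # (1,) token id, embedded inside step()
    h_prevs = model.initHiddenStates(1)
    c_prevs = model.initCellStates(1)
    generated = [start_token]
    for _ in range(gen_len):
        logits = model.step(x_in, h_prevs, c_prevs)  # (1, vocab_size); state lists updated in place
        x_in = logits.argmax(dim=-1)                 # greedy pick of the next token id, shape (1,)
        generated.append(x_in.item())
print(generated)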