custom_ts_lstm_torchscript_jit_compat_model.py
from typing import List, Optional

import torch
from torch import nn, Tensor
from torch.nn import ModuleList, Sequential
import torch.nn.functional as F
from torch.types import _device

try:
    from typing_extensions import Final
except ImportError:
    # If `typing_extensions` is not installed, fall back to the
    # polyfill shipped with `torch.jit`.
    from torch.jit import Final


class AutoRegressiveLSTMJITCompatV8(nn.Module):
    __constants__ = ['embd', 'layers', 'norms_h', 'norms_c', 'pred_class']

    def __init__(self, num_embeddings: int, embd_size: int, hidden_size: int, layer_num: int = 1,
                 device: str = "cpu"):
        super(AutoRegressiveLSTMJITCompatV8, self).__init__()
        print(f"{self.__class__.__name__} - Calling python init")
        self.hidden_size: int = hidden_size
        self.n_layers: int = layer_num
        self.embd_size: int = embd_size
        self.num_embeddings: int = num_embeddings
        self.device: _device = device
        self.embd: nn.Embedding = nn.Embedding(num_embeddings, embd_size)
        self.layers: ModuleList = self.initialize_LSTMCell_layers(embd_size, hidden_size, layer_num)
        self.norms_h: ModuleList = self.initialize_lstm_hidden_layer_norm(hidden_size, layer_num)
        self.norms_c: ModuleList = self.initialize_lstm_context_layer_norm(hidden_size, layer_num)
        self.pred_class: Sequential = self.initialize_pred_class_seq_layer(hidden_size, num_embeddings)
        print(f"{self.__class__.__name__} - python init - Completed")

    @torch.jit.export
    def set_device(self, device: str) -> None:
        if device is not None:
            self.device = device

    @torch.jit.export
    def initHiddenStates(self, B: int) -> List[Tensor]:
        """
        Creates an initial hidden state list for the RNN layers.
        B: the batch size for the hidden states.
        """
        return self.build_init_layers(B, self.hidden_size, self.device, self.n_layers)

    @torch.jit.export
    def initCellStates(self, B: int) -> List[Tensor]:
        """
        Creates an initial cell state list for the RNN layers.
        B: the batch size for the cell states.
        """
        return self.build_init_layers(B, self.hidden_size, self.device, self.n_layers)

    def build_init_layers(self, batch_size: int, hidden_size: int, device: _device, layer_num: int) -> List[Tensor]:
        # One zero tensor of shape (batch_size, hidden_size) per stacked LSTMCell layer.
        init_layers = [torch.zeros(batch_size, hidden_size, device=device) for _ in range(layer_num)]
        return init_layers

    def initialize_LSTMCell_layers(self, embd_size: int, hidden_size: int, layer_num: int):
        # First cell maps the embedding to the hidden size; the rest are hidden -> hidden.
        return self.init_stacked_lstm(layer_num, nn.LSTMCell,
                                      [embd_size, hidden_size],
                                      [hidden_size, hidden_size])

    def initialize_lstm_hidden_layer_norm(self, hidden_size: int, layer_num: int):
        return self.init_stacked_lstm(layer_num, nn.LayerNorm,
                                      [hidden_size],
                                      [hidden_size])

    def initialize_lstm_context_layer_norm(self, hidden_size: int, layer_num: int):
        return self.init_stacked_lstm(layer_num, nn.LayerNorm,
                                      [hidden_size],
                                      [hidden_size])

    def initialize_pred_class_seq_layer(self, hidden_size: int, num_embeddings: int):
        seq_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),    # (B, *, D) -> (B, *, D)
            nn.LayerNorm(hidden_size),              # (B, *, D)
            nn.GELU(),
            nn.Linear(hidden_size, num_embeddings)  # (B, *, D) -> (B, *, VocabSize)
        )
        return seq_layer

    def init_stacked_lstm(self, num_layers, layer_type, first_layer_args, other_layer_args):
        layers = [layer_type(*first_layer_args)] + [layer_type(*other_layer_args)
                                                    for _ in range(num_layers - 1)]
        return nn.ModuleList(layers)

    @torch.jit.export
    def step(self, x_in: Tensor, h_prevs: Optional[List[Tensor]],
             c_prevs: Optional[List[Tensor]]) -> Tensor:
        # A 1-D input is a batch of token ids that still needs to be embedded.
        if len(x_in.shape) == 1:
            x_in = self.embd(x_in)
        if h_prevs is None:
            h_prevs = self.initHiddenStates(x_in.shape[0])
        if c_prevs is None:
            c_prevs = self.initCellStates(x_in.shape[0])
        for i, (cur_layer, cur_norm_h, cur_norm_c) in enumerate(zip(self.layers, self.norms_h, self.norms_c)):
            h_prev = h_prevs[i]
            c_prev = c_prevs[i]
            h, c = cur_layer(x_in, (h_prev, c_prev))
            h = cur_norm_h(h)
            c = cur_norm_c(c)
            # Update the state lists in place so the caller carries state across timesteps.
            h_prevs[i] = h
            c_prevs[i] = c
            x_in = h
        return self.pred_class(x_in)

    @torch.jit.export
    def predict_all_steps(self, sampling_with_seed: Tensor, h_prevs: Optional[List[Tensor]],
                          c_prevs: Optional[List[Tensor]], seed_len: int,
                          temperature: float = 0.76) -> Tensor:
        assert sampling_with_seed is not None and seed_len is not None, "For the predictions, sampling_with_seed and seed_len are required"
        h: Tensor = torch.empty(size=(1, self.num_embeddings))
        if h_prevs is None and c_prevs is None:
            h_prevs = self.initHiddenStates(1)
            c_prevs = self.initCellStates(1)
        # Warm up the hidden/cell states on the seed tokens.
        for i in range(0, seed_len):
            assert h_prevs is not None and c_prevs is not None, "h_prevs and c_prevs must not be None when building the seed state"
            h = self.step(sampling_with_seed[:, i], h_prevs=h_prevs, c_prevs=c_prevs)
        # Sample the remaining positions autoregressively with temperature scaling.
        for i in range(seed_len, sampling_with_seed.size(1)):
            h_max = F.softmax(h / temperature, dim=1)
            next_tokens: Tensor = torch.multinomial(h_max, 1)
            if len(next_tokens.shape) > 1:
                next_tokens = next_tokens.squeeze(dim=1)
            sampling_with_seed[:, i] = next_tokens
            seed_len = seed_len + 1
            h = self.step(sampling_with_seed[:, i], h_prevs=h_prevs, c_prevs=c_prevs)
        return sampling_with_seed

    def forward(self, input: Tensor) -> Tensor:
        batch_size = input.size(0)
        time_steps = input.size(1)
        x = self.embd(input)  # (B, T, D)
        # Initial hidden and cell states
        h_prevs = self.initHiddenStates(batch_size)
        c_prevs = self.initCellStates(batch_size)
        last_activations: List[Tensor] = []
        for t in range(time_steps):
            x_in = x[:, t, :]  # (batch_size, embd_dims)
            # step() mutates h_prevs/c_prevs in place, so state flows across timesteps.
            last_activations.append(self.step(x_in, h_prevs, c_prevs))
        # Stack into a separate variable so the list keeps a single type (friendlier to TorchScript).
        logits = torch.stack(last_activations, dim=1)  # (batch_size, time_steps, VocabSize)
        return logits
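

A minimal usage sketch, not part of the original gist: the vocabulary size, layer sizes, batch shape, and seed length below are illustrative assumptions. It runs a teacher-forced forward() pass, samples from a short seed with predict_all_steps(), and then scripts the module with torch.jit.script, which is the stated purpose of this "JIT compat" variant.

if __name__ == "__main__":
    vocab_size, embd_dim, hidden_dim = 128, 32, 64  # assumed toy sizes
    model = AutoRegressiveLSTMJITCompatV8(vocab_size, embd_dim, hidden_dim, layer_num=2)

    # Teacher-forced pass: (B, T) token ids -> (B, T, vocab_size) logits.
    tokens = torch.randint(0, vocab_size, (4, 16))
    logits = model(tokens)
    print(logits.shape)  # torch.Size([4, 16, 128])

    # Autoregressive sampling: the first 5 positions are the seed, the rest are filled in.
    sample = torch.zeros(1, 20, dtype=torch.long)
    sample[:, :5] = torch.randint(0, vocab_size, (1, 5))
    generated = model.predict_all_steps(sample, None, None, seed_len=5)
    print(generated)

    # TorchScript compilation is the point of this class; this assumes the module
    # scripts cleanly in your PyTorch version.
    scripted = torch.jit.script(model)
    print(scripted(tokens).shape)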