custom_ts_lstm_torchscript_jit_compat_model.py
from typing import List, Optional

import torch
from torch import nn, Tensor
from torch.nn import ModuleList, Sequential
import torch.nn.functional as F
from torch.types import _device

try:
    from typing_extensions import Final
except ImportError:
    # If you don't have `typing_extensions` installed, you can use a
    # polyfill from `torch.jit`.
    from torch.jit import Final


class AutoRegressiveLSTMJITCompatV8(nn.Module):
    __constants__ = ['embd', 'layers', 'norms_h', 'norms_c', 'pred_class']

    def __init__(self, num_embeddings: int, embd_size: int, hidden_size: int, layer_num: int = 1,
                 device: str = "cpu"):
        super(AutoRegressiveLSTMJITCompatV8, self).__init__()
        print(f"{self.__class__.__name__} - Calling python init")
        self.hidden_size: int = hidden_size
        self.n_layers: int = layer_num
        self.embd_size: int = embd_size
        self.num_embeddings: int = num_embeddings
        self.device: _device = device
        self.embd: nn.Embedding = nn.Embedding(num_embeddings, embd_size)
        self.layers: ModuleList = self.initialize_LSTMCell_layers(embd_size, hidden_size, layer_num)
        self.norms_h: ModuleList = self.initialize_lstm_hidden_layer_norm(hidden_size, layer_num)
        self.norms_c: ModuleList = self.initialize_lstm_context_layer_norm(hidden_size, layer_num)
        self.pred_class: Sequential = self.initialize_pred_class_seq_layer(hidden_size, num_embeddings)
        print(f" {self.__class__.__name__} - python init - Completed")

    @torch.jit.export
    def set_device(self, device: str) -> None:
        if device is not None:
            self.device = device

    @torch.jit.export
    def initHiddenStates(self, B: int) -> List[Tensor]:
        """
        Creates an initial hidden state list for the RNN layers.
        B: the batch size for the hidden states.
        """
        return self.build_init_layers(B, self.hidden_size, self.device, self.n_layers)

    @torch.jit.export
    def initCellStates(self, B: int) -> List[Tensor]:
        """
        Creates an initial cell state list for the RNN layers.
        B: the batch size for the cell states.
        """
        return self.build_init_layers(B, self.hidden_size, self.device, self.n_layers)

    def build_init_layers(self, batch_size: int, hidden_size: int, device: _device, layer_num: int):
        # One zero-initialized (batch_size, hidden_size) tensor per stacked layer.
        init_layers = [torch.zeros(batch_size, hidden_size, device=device) for _ in range(
            layer_num)]
        return init_layers

    def initialize_LSTMCell_layers(self, embd_size: int, hidden_size: int, layer_num: int):
        return self.init_stacked_lstm(layer_num, nn.LSTMCell,
                                      [embd_size, hidden_size],
                                      [hidden_size, hidden_size])

    def initialize_lstm_hidden_layer_norm(self, hidden_size: int, layer_num: int):
        return self.init_stacked_lstm(layer_num, nn.LayerNorm,
                                      [hidden_size],
                                      [hidden_size])

    def initialize_lstm_context_layer_norm(self, hidden_size: int, layer_num: int):
        return self.init_stacked_lstm(layer_num, nn.LayerNorm,
                                      [hidden_size],
                                      [hidden_size])

    def initialize_pred_class_seq_layer(self, hidden_size: int, num_embeddings: int):
        seq_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),    # (B, *, D)
            nn.LayerNorm(hidden_size),              # (B, *, D)
            nn.GELU(),
            nn.Linear(hidden_size, num_embeddings)  # (B, *, D) -> (B, *, VocabSize)
        )
        return seq_layer

    def init_stacked_lstm(self, num_layers, layer_type, first_layer_args, other_layer_args):
        # The first layer maps the embedding size to the hidden size; the
        # remaining layers map hidden size to hidden size.
        layers = [layer_type(*first_layer_args)] + [layer_type(*other_layer_args)
                                                    for _ in range(num_layers - 1)]
        return nn.ModuleList(layers)

    @torch.jit.export
    def step(self, x_in: Tensor, h_prevs: Optional[List[Tensor]],
             c_prevs: Optional[List[Tensor]]) -> Tensor:
        # Runs one time step through the stacked LSTMCells. `h_prevs` and
        # `c_prevs` are updated in place so the caller can reuse them on the
        # next step. Raw token indices (1-D input) are embedded first.
        if len(x_in.shape) == 1:
            x_in = self.embd(x_in)
        if h_prevs is None:
            h_prevs = self.initHiddenStates(x_in.shape[0])
        if c_prevs is None:
            c_prevs = self.initCellStates(x_in.shape[0])
        for i, (cur_layer, cur_norm_h, cur_norm_c) in enumerate(zip(self.layers, self.norms_h, self.norms_c)):
            h_prev = h_prevs[i]
            c_prev = c_prevs[i]
            h, c = cur_layer(x_in, (h_prev, c_prev))
            h = cur_norm_h(h)
            c = cur_norm_c(c)
            h_prevs[i] = h
            c_prevs[i] = c
            x_in = h
        return self.pred_class(x_in)

    @torch.jit.export
    def predict_all_steps(self, sampling_with_seed: Tensor, h_prevs: Optional[List[Tensor]],
                          c_prevs: Optional[List[Tensor]], seed_len: int,
                          temperature: float = 0.76) -> Tensor:
        assert sampling_with_seed is not None and seed_len is not None, \
            "sampling_with_seed and seed_len are required for prediction"
        h: Tensor = torch.empty(size=(1, self.num_embeddings))
        if h_prevs is None and c_prevs is None:
            h_prevs = self.initHiddenStates(1)
            c_prevs = self.initCellStates(1)
        # Warm up the recurrent state on the seed tokens.
        for i in range(0, seed_len):
            assert h_prevs is not None and c_prevs is not None, \
                "h_prevs and c_prevs must not be None when consuming the seed"
            h = self.step(sampling_with_seed[:, i], h_prevs=h_prevs, c_prevs=c_prevs)
        # Sample the remaining positions autoregressively, writing tokens in place.
        for i in range(seed_len, sampling_with_seed.size(1)):
            h_max = F.softmax(h / temperature, dim=1)
            next_tokens: Tensor = torch.multinomial(h_max, 1)
            if len(next_tokens.shape) > 1:
                next_tokens = next_tokens.squeeze(dim=1)
            sampling_with_seed[:, i] = next_tokens
            seed_len = seed_len + 1
            h = self.step(sampling_with_seed[:, i], h_prevs=h_prevs, c_prevs=c_prevs)
        return sampling_with_seed

    def forward(self, input: Tensor) -> Tensor:
        batch_size = input.size(0)
        time_steps = input.size(1)
        x = self.embd(input)  # (B, T, D)
        # Initial hidden states
        h_prevs = self.initHiddenStates(batch_size)
        c_prevs = self.initCellStates(batch_size)
        last_activations = []
        for t in range(time_steps):
            x_in = x[:, t, :]  # (batch_size, embd_dims)
            last_activations.append(self.step(x_in, h_prevs, c_prevs))
        # Return the stacked tensor directly instead of re-binding the list
        # variable, which keeps the types TorchScript-friendly.
        return torch.stack(last_activations, dim=1)  # (batch_size, time_steps, VocabSize)
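
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist): a minimal example of how the
# module above might be compiled with torch.jit.script and used for a forward
# pass and for autoregressive sampling. The vocabulary size, layer sizes,
# batch shapes, and seed length below are illustrative assumptions, not values
# from the original training setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size = 128  # assumed vocabulary size
    model = AutoRegressiveLSTMJITCompatV8(num_embeddings=vocab_size, embd_size=32,
                                          hidden_size=64, layer_num=2, device="cpu")
    scripted = torch.jit.script(model)  # TorchScript compilation

    # Teacher-forced pass over random tokens: (B, T) -> (B, T, vocab_size) logits.
    tokens = torch.randint(0, vocab_size, (4, 16))
    logits = scripted(tokens)
    print(logits.shape)

    # Autoregressive sampling: the first `seed_len` columns hold the seed tokens;
    # the remaining columns are filled in place by predict_all_steps (which
    # initializes its states for a batch of 1).
    seed_len = 8
    sample = torch.zeros(1, 32, dtype=torch.long)
    sample[:, :seed_len] = torch.randint(0, vocab_size, (1, seed_len))
    generated = scripted.predict_all_steps(sample, None, None, seed_len)
    print(generated)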