custom_ts_lstm_torchscript_to_onnx_compat_model.py
from typing import List, Tuple, Optional
import torch
from torch import nn, Tensor
from torch.nn import ModuleList, Sequential
import torch.nn.functional as F

try:
    from typing_extensions import Final
except ImportError:
    # If you don't have `typing_extensions` installed, you can use a
    # polyfill from `torch.jit`.
    from torch.jit import Final
""" | |
Using torch.nn.LSTM module as a single layer, 1 seq length LSTMCell | |
Pytorch -> Torchscript -> ONNX export does not exist for LSTMCell, | |
however, LSTM module is supported | |
https://github.com/onnx/onnx/issues/3597#issuecomment-883835426 | |
""" | |
class AutoRegressiveLSTMCell(nn.Module):
    __constants__ = ['custom_lstmcell']

    def __init__(self, embd_size: int, hidden_size: int,
                 device: str = "cpu"):
        super(AutoRegressiveLSTMCell, self).__init__()
        self.layer_num: int = 1
        self.seq_len: int = 1
        self.embd_size: int = embd_size
        self.hidden_size: int = hidden_size
        self.device: str = device
        # A single-layer nn.LSTM stands in for nn.LSTMCell so that the
        # TorchScript graph can be exported to ONNX.
        self.custom_lstmcell = nn.LSTM(embd_size, hidden_size, self.layer_num)

    def forward(self, x_in_t: Tensor, h_x: Optional[Tensor], c_x: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
        if h_x is None:
            h_x = torch.zeros((x_in_t.size(0), self.hidden_size), dtype=torch.float, device=self.device)
        if c_x is None:
            c_x = torch.zeros((x_in_t.size(0), self.hidden_size), dtype=torch.float, device=self.device)
        # One recurrent step; only the final hidden and cell states are
        # returned, mirroring nn.LSTMCell's interface.
        output, (h_x_3d, c_x_3d) = self.custom_lstmcell(x_in_t, (h_x, c_x))
        return (h_x_3d, c_x_3d)
class AutoRegressiveLSTMJITCompatV24(nn.Module):
    __constants__ = ['embd', 'layers', 'norms_h', 'norms_c', 'pred_class']

    def __init__(self, num_embeddings: int, embd_size: int, hidden_size: int, layer_num: int = 1,
                 device: str = "cpu"):
        super(AutoRegressiveLSTMJITCompatV24, self).__init__()
        self.hidden_size: int = hidden_size
        self.n_layers: int = layer_num
        self.embd_size: int = embd_size
        self.num_embeddings: int = num_embeddings
        self.device: str = device
        self.embd: nn.Embedding = nn.Embedding(num_embeddings, embd_size)
        self.layers: ModuleList = self.initialize_LSTMCell_layers(embd_size, hidden_size, layer_num)
        self.norms_h: ModuleList = self.initialize_lstm_hidden_layer_norm(hidden_size, layer_num)
        self.norms_c: ModuleList = self.initialize_lstm_context_layer_norm(hidden_size, layer_num)
        self.pred_class: Sequential = self.initialize_pred_class_seq_layer(hidden_size, num_embeddings)
    @torch.jit.export
    def set_device(self, device: str) -> None:
        if device is not None and torch.jit.isinstance(device, str):
            self.device = device

    @torch.jit.export
    def initHiddenStates(self, B: int) -> List[Tensor]:
        """
        Creates an initial hidden state list for the RNN layers.
        B: the batch size for the hidden states.
        """
        return self.build_init_layers(B, self.hidden_size, self.device, self.n_layers)

    @torch.jit.export
    def initCellStates(self, B: int) -> List[Tensor]:
        """
        Creates an initial cell state list for the RNN layers.
        B: the batch size for the cell states.
        """
        return self.build_init_layers(B, self.hidden_size, self.device, self.n_layers)

    def build_init_layers(self, batch_size: int, hidden_size: int, device: str, layer_num: int) -> List[Tensor]:
        # One zero-initialized (batch_size, hidden_size) state per layer.
        init_layers = [torch.zeros(batch_size, hidden_size, dtype=torch.float, device=device)
                       for _ in range(layer_num)]
        return init_layers
    def initialize_LSTMCell_layers(self, embd_size: int, hidden_size: int, layer_num: int):
        # The first layer consumes embeddings; subsequent layers consume the
        # previous layer's hidden state.
        return self.init_stacked_lstm(layer_num, AutoRegressiveLSTMCell,
                                      [embd_size, hidden_size, self.device],
                                      [hidden_size, hidden_size, self.device])

    def initialize_lstm_hidden_layer_norm(self, hidden_size: int, layer_num: int):
        return self.init_stacked_lstm(layer_num, nn.LayerNorm,
                                      [hidden_size],
                                      [hidden_size])

    def initialize_lstm_context_layer_norm(self, hidden_size: int, layer_num: int):
        return self.init_stacked_lstm(layer_num, nn.LayerNorm,
                                      [hidden_size],
                                      [hidden_size])

    def initialize_pred_class_seq_layer(self, hidden_size: int, num_embeddings: int):
        seq_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_embeddings)  # (B, *, D) -> (B, *, VocabSize)
        )
        return seq_layer

    def init_stacked_lstm(self, num_layers, layer_type, first_layer_args, other_layer_args):
        layers = [layer_type(*first_layer_args)] + [layer_type(*other_layer_args)
                                                    for _ in range(num_layers - 1)]
        return nn.ModuleList(layers)
    def step(self, x_in: Tensor, h_prevs: Optional[List[Tensor]],
             c_prevs: Optional[List[Tensor]]) -> Tensor:
        # Accept either raw token indices (1-D) or pre-embedded inputs (2-D).
        if len(x_in.shape) == 1:
            embd_x_in = self.embd(x_in)  # now (B, D)
        else:
            embd_x_in = x_in
        if h_prevs is None:
            h_prevs = self.initHiddenStates(x_in.shape[0])
        if c_prevs is None:
            c_prevs = self.initCellStates(x_in.shape[0])
        assert h_prevs is not None and c_prevs is not None, "h_prevs or c_prevs should not be None when building the h, c layers for the seed"
        h_each_step_n = embd_x_in
        for i, (cur_layer, cur_norm_h, cur_norm_c) in enumerate(zip(self.layers, self.norms_h, self.norms_c)):
            h_prev = h_prevs[i]
            c_prev = c_prevs[i]
            # One recurrent step, followed by layer normalization of both states.
            h_each_step, c_each_step = cur_layer(embd_x_in, h_prev, c_prev)
            h_each_step_n = cur_norm_h(h_each_step)
            c_each_step_n = cur_norm_c(c_each_step)
            h_prevs[i] = h_each_step_n
            c_prevs[i] = c_each_step_n
            # The normalized hidden state feeds the next layer.
            embd_x_in = h_each_step_n
        return self.pred_class(h_each_step_n)
    def forward(self,
                sampling_with_seed: Tensor, seed_len: Tensor = torch.tensor([5], dtype=torch.int64),
                temperature: Tensor = torch.tensor([0.76], dtype=torch.float)) -> Tuple[Tensor, Tensor]:
        h_prevs: List[Tensor] = self.initHiddenStates(1)
        c_prevs: List[Tensor] = self.initCellStates(1)
        predict_ts_seq_tensor: Tensor = torch.clone(sampling_with_seed)
        h_out: Tensor = torch.zeros((1, self.num_embeddings), dtype=torch.float)
        max_pred_len: int = sampling_with_seed.size(1)
        cur_index: int = 0
        # Warm-up: feed the seed tokens one step at a time.
        while cur_index < seed_len.select(0, 0):
            embd_with_seed_time_t: Tensor = self.embd(torch.select(predict_ts_seq_tensor, 1, cur_index))
            h_out = self.step(embd_with_seed_time_t, h_prevs=h_prevs, c_prevs=c_prevs)
            cur_index = cur_index + 1
        # Autoregressive sampling: draw the next token from the
        # temperature-scaled softmax and feed it back in.
        while cur_index < max_pred_len:
            h_out_modulated = torch.div(h_out, temperature)
            h_max = F.softmax(h_out_modulated, dim=-1)
            reshaped = torch.reshape(h_max, (-1, self.num_embeddings))
            next_tokens = torch.multinomial(reshaped, num_samples=1)
            next_tokens_max: Tensor = torch.max(next_tokens)
            predict_ts_seq_tensor[:, cur_index] = next_tokens_max
            embd_predict_ts_time_t = self.embd(torch.select(predict_ts_seq_tensor, 1, cur_index))
            h_out = self.step(embd_predict_ts_time_t, h_prevs=h_prevs, c_prevs=c_prevs)
            cur_index = cur_index + 1
        return (predict_ts_seq_tensor, torch.tensor([cur_index], dtype=torch.int64))
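
The sketch below shows one way this model could be scripted and exported to ONNX, which is the workflow the gist targets. It is a minimal illustration, not part of the original file: the vocabulary size, embedding and hidden dimensions, sequence length, opset version, and output file name are assumptions and may need adjusting for your PyTorch/ONNX versions.

if __name__ == "__main__":
    # Hypothetical configuration; adjust to your vocabulary and model size.
    model = AutoRegressiveLSTMJITCompatV24(num_embeddings=128, embd_size=32,
                                           hidden_size=64, layer_num=2)
    model.eval()

    # Script the module; because the cells wrap nn.LSTM rather than
    # nn.LSTMCell, the TorchScript graph can be lowered to ONNX.
    scripted = torch.jit.script(model)

    # One sequence of 20 positions: the first `seed_len` tokens are the seed,
    # the remaining positions are overwritten during autoregressive sampling.
    seed = torch.zeros((1, 20), dtype=torch.int64)
    seed_len = torch.tensor([5], dtype=torch.int64)
    temperature = torch.tensor([0.76], dtype=torch.float)

    with torch.no_grad():
        generated, length = scripted(seed, seed_len, temperature)

    # Export the scripted module; file name and opset are illustrative.
    torch.onnx.export(
        scripted,
        (seed, seed_len, temperature),
        "autoregressive_lstm.onnx",
        input_names=["sampling_with_seed", "seed_len", "temperature"],
        output_names=["predicted_sequence", "predicted_length"],
        opset_version=14,
    )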