custom_ts_lstm_torchscript_to_onnx_compat_model.py
from typing import List, Tuple, Optional
import torch
from torch import nn, Tensor
from torch.nn import ModuleList, Sequential
import torch.nn.functional as F
try:
    from typing_extensions import Final
except ImportError:
    # If you don't have `typing_extensions` installed, you can use a
    # polyfill from `torch.jit`.
    from torch.jit import Final
"""
Using torch.nn.LSTM module as a single layer, 1 seq length LSTMCell
Pytorch -> Torchscript -> ONNX export does not exist for LSTMCell,
however, LSTM module is supported
https://github.com/onnx/onnx/issues/3597#issuecomment-883835426
"""
class AutoRegressiveLSTMCell(nn.Module):
    __constants__ = ['custom_lstmcell']

    def __init__(self, embd_size: int, hidden_size: int,
                 device: str = "cpu"):
        super(AutoRegressiveLSTMCell, self).__init__()
        self.layer_num: int = 1
        self.seq_len: int = 1
        self.embd_size: int = embd_size
        self.hidden_size: int = hidden_size
        self.device: str = device
        # A single-layer nn.LSTM stands in for nn.LSTMCell so the scripted graph
        # stays ONNX-exportable.
        self.custom_lstmcell = nn.LSTM(embd_size, hidden_size, self.layer_num)

    def forward(self, x_in_t: Tensor, h_x: Optional[Tensor], c_x: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
        if h_x is None:
            h_x = torch.zeros((x_in_t.size(0), self.hidden_size), dtype=torch.float, device=self.device)
        if c_x is None:
            c_x = torch.zeros((x_in_t.size(0), self.hidden_size), dtype=torch.float, device=self.device)
        output, (h_x_3d, c_x_3d) = self.custom_lstmcell(x_in_t, (h_x, c_x))
        return (h_x_3d, c_x_3d)
class AutoRegressiveLSTMJITCompatV24(nn.Module):
    __constants__ = ['embd', 'layers', 'norms_h', 'norms_c', 'pred_class']

    def __init__(self, num_embeddings: int, embd_size: int, hidden_size: int, layer_num: int = 1,
                 device: str = "cpu"):
        super(AutoRegressiveLSTMJITCompatV24, self).__init__()
        self.hidden_size: int = hidden_size
        self.n_layers: int = layer_num
        self.embd_size: int = embd_size
        self.num_embeddings: int = num_embeddings
        self.device: str = device
        self.embd: nn.Embedding = nn.Embedding(num_embeddings, embd_size)
        self.layers: ModuleList = self.initialize_LSTMCell_layers(embd_size, hidden_size, layer_num)
        self.norms_h: ModuleList = self.initialize_lstm_hidden_layer_norm(hidden_size, layer_num)
        self.norms_c: ModuleList = self.initialize_lstm_context_layer_norm(hidden_size, layer_num)
        self.pred_class: Sequential = self.initialize_pred_class_seq_layer(hidden_size, num_embeddings)
    @torch.jit.export
    def set_device(self, device: str) -> None:
        if device is not None and torch.jit.isinstance(device, str):
            self.device = device

    @torch.jit.export
    def initHiddenStates(self, B: int) -> List[Tensor]:
        """
        Creates an initial hidden state list for the RNN layers.
        B: the batch size for the hidden states.
        """
        return self.build_init_layers(B, self.hidden_size, self.device, self.n_layers)

    @torch.jit.export
    def initCellStates(self, B: int) -> List[Tensor]:
        """
        Creates an initial cell state list for the RNN layers.
        B: the batch size for the cell states.
        """
        return self.build_init_layers(B, self.hidden_size, self.device, self.n_layers)
    def build_init_layers(self, batch_size: int, hidden_size: int, device: str, layer_num: int):
        init_layers = [(torch.zeros(batch_size, hidden_size, dtype=torch.float, device=device))
                       for _ in range(layer_num)]
        return init_layers

    def initialize_LSTMCell_layers(self, embd_size: int, hidden_size: int, layer_num: int):
        return self.init_stacked_lstm(layer_num, AutoRegressiveLSTMCell,
                                      [embd_size, hidden_size, self.device],
                                      [hidden_size, hidden_size, self.device])

    def initialize_lstm_hidden_layer_norm(self, hidden_size: int, layer_num: int):
        return self.init_stacked_lstm(layer_num, nn.LayerNorm,
                                      [hidden_size],
                                      [hidden_size])

    def initialize_lstm_context_layer_norm(self, hidden_size: int, layer_num: int):
        return self.init_stacked_lstm(layer_num, nn.LayerNorm,
                                      [hidden_size],
                                      [hidden_size])
    def initialize_pred_class_seq_layer(self, hidden_size: int, num_embeddings: int):
        seq_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_embeddings)  # (B, *, D) -> (B, *, VocabSize)
        )
        return seq_layer

    def init_stacked_lstm(self, num_layers, layer_type, first_layer_args, other_layer_args):
        # The first layer consumes embeddings; the remaining layers consume hidden states.
        layers = [layer_type(*first_layer_args)] + [layer_type(*other_layer_args)
                                                    for _ in range(num_layers - 1)]
        return nn.ModuleList(layers)
    def step(self, x_in: Tensor, h_prevs: Optional[List[Tensor]],
             c_prevs: Optional[List[Tensor]]) -> Tensor:
        # Accept either raw token ids (B,) or pre-embedded inputs (B, D).
        if len(x_in.shape) == 1:
            embd_x_in = self.embd(x_in)  # now (B, D)
        else:
            embd_x_in = x_in
        if h_prevs is None:
            h_prevs = self.initHiddenStates(x_in.shape[0])
        if c_prevs is None:
            c_prevs = self.initCellStates(x_in.shape[0])
        assert h_prevs is not None and c_prevs is not None, "h_prevs or c_prevs should not be None for building h, c layers for seed"
        for i, (cur_layer, cur_norm_h, cur_norm_c) in enumerate(zip(self.layers, self.norms_h, self.norms_c)):
            if h_prevs is not None:
                h_prev = h_prevs[i]
            if c_prevs is not None:
                c_prev = c_prevs[i]
            assert h_prev is not None and c_prev is not None, "h_prev or c_prev is None"
            # One LSTM step per stacked layer, with layer norm on the new hidden and cell states.
            h_each_step, c_each_step = cur_layer(embd_x_in, h_prev, c_prev)
            h_each_step_n = cur_norm_h(h_each_step)
            c_each_step_n = cur_norm_c(c_each_step)
            h_prevs[i] = h_each_step_n
            c_prevs[i] = c_each_step_n
            # The normalized hidden state is the input to the next stacked layer.
            embd_x_in = h_each_step_n
        return self.pred_class(h_each_step_n)
    def forward(self,
                sampling_with_seed: Tensor, seed_len: Tensor = torch.tensor([5], dtype=torch.int64),
                temperature: Tensor = torch.tensor([0.76], dtype=torch.float)) -> Tuple[Tensor, Tensor]:
        h_prevs: List[Tensor] = self.initHiddenStates(1)
        c_prevs: List[Tensor] = self.initCellStates(1)
        predict_ts_seq_tensor: Tensor = torch.clone(sampling_with_seed)
        h_out: Tensor = torch.zeros((1, self.num_embeddings), dtype=torch.float)
        max_pred_len: int = sampling_with_seed.size(1)
        cur_index: int = 0
        # Teacher-force the seed tokens to warm up the hidden and cell states.
        while cur_index < seed_len.select(0, 0):
            embd_with_seed_time_t: Tensor = self.embd(torch.select(predict_ts_seq_tensor, 1, cur_index))
            h_out = self.step(embd_with_seed_time_t, h_prevs=h_prevs, c_prevs=c_prevs)
            cur_index = cur_index + 1
        # Autoregressively sample the remaining positions with temperature scaling.
        while cur_index < max_pred_len:
            h_out_modulated = torch.div(h_out, temperature)
            h_max = F.softmax(h_out_modulated, dim=-1)
            reshaped = torch.reshape(h_max, (-1, self.num_embeddings))
            next_tokens = torch.multinomial(reshaped, num_samples=1)
            next_tokens_max: Tensor = torch.max(next_tokens)
            predict_ts_seq_tensor[:, cur_index] = next_tokens_max
            embd_predict_ts_time_t = self.embd(torch.select(predict_ts_seq_tensor, 1, cur_index))
            h_out = self.step(embd_predict_ts_time_t, h_prevs=h_prevs, c_prevs=c_prevs)
            cur_index = cur_index + 1
        return (predict_ts_seq_tensor, torch.tensor([cur_index], dtype=torch.int64))
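
Below is a minimal, untested export sketch showing how a model like this can be scripted and converted to ONNX. The vocabulary, embedding, and hidden sizes, the sequence length, the output file name, and the opset version are illustrative placeholders rather than values from the gist, and the exporter arguments may need adjusting for your PyTorch version.

# export_sketch.py -- hypothetical usage example, not part of the original gist.
import torch

# Placeholder hyperparameters; substitute your own.
NUM_EMBEDDINGS, EMBD_SIZE, HIDDEN_SIZE, LAYER_NUM = 128, 64, 256, 2

model = AutoRegressiveLSTMJITCompatV24(NUM_EMBEDDINGS, EMBD_SIZE, HIDDEN_SIZE, LAYER_NUM)
model.eval()

# Script (rather than trace) so the data-dependent while-loops in forward() are preserved.
scripted = torch.jit.script(model)

# Example inputs: a (1, max_len) token sequence whose first seed_len entries hold the seed.
sampling_with_seed = torch.zeros((1, 20), dtype=torch.int64)
seed_len = torch.tensor([5], dtype=torch.int64)
temperature = torch.tensor([0.76], dtype=torch.float)

torch.onnx.export(
    scripted,
    (sampling_with_seed, seed_len, temperature),
    "autoregressive_lstm.onnx",              # placeholder output path
    input_names=["sampling_with_seed", "seed_len", "temperature"],
    output_names=["predicted_sequence", "generated_len"],
    opset_version=14,                        # pick an opset your runtime supports
)

The exported file can then be sanity-checked by loading it in onnxruntime and comparing the sampled output shape against the scripted module; exact token values will differ between runs because torch.multinomial is stochastic.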