# Source: GitHub Gist eleganceinsimplicity/90cb789017272a30073ac11acc70eac7
from typing import List, Tuple, Optional
import torch
from torch import nn, Tensor
from torch.nn import ModuleList, Sequential
import torch.nn.functional as F
from typing_extensions import Final
# If you don't have `typing_extensions` installed, you can use a
# polyfill from `torch.jit`.
from torch.jit import Final
# Uses a single-layer torch.nn.LSTM run over a length-1 sequence as a stand-in
# for torch.nn.LSTMCell: the PyTorch -> TorchScript -> ONNX export path does
# not support LSTMCell, but it does support the LSTM module.
class AutoRegressiveLSTMCell(nn.Module):
    """One-step LSTM cell implemented with a single-layer ``nn.LSTM``.

    Mirrors the ``nn.LSTMCell`` call pattern (one time step in, new hidden and
    cell state out) while using ``nn.LSTM``, which survives the
    TorchScript -> ONNX export path.

    Args:
        embd_size: number of input features per step.
        hidden_size: number of features in the hidden and cell states.
        device: device string stored on the module for callers that read it;
            the zero states created in ``forward`` follow the input's device.
    """

    def __init__(self, embd_size: int, hidden_size: int,
                 device: str = "cpu") -> None:
        super().__init__()
        self.layer_num: int = 1
        self.seq_len: int = 1
        self.embd_size: int = embd_size
        self.hidden_size: int = hidden_size
        self.device: str = device
        self.custom_lstmcell = nn.LSTM(embd_size, hidden_size, self.layer_num)

    def forward(self, x_in_t: Tensor, h_x: Optional[Tensor] = None,
                c_x: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Advance the cell by one step.

        Args:
            x_in_t: input of shape ``(B, embd_size)`` or unbatched ``(embd_size,)``.
            h_x: previous hidden state, shape ``(B, hidden_size)`` /
                ``(hidden_size,)``; zeros when ``None``.
            c_x: previous cell state, same shape as ``h_x``; zeros when ``None``.

        Returns:
            ``(h_new, c_new)``, each shaped like the state inputs.
        """
        # State shape = input shape with the feature axis replaced by hidden_size.
        state_shape: List[int] = list(x_in_t.shape[:-1]) + [self.hidden_size]
        if h_x is None:
            h_x = torch.zeros(state_shape, dtype=x_in_t.dtype, device=x_in_t.device)
        if c_x is None:
            c_x = torch.zeros(state_shape, dtype=x_in_t.dtype, device=x_in_t.device)
        # nn.LSTM (batch_first=False) expects input (seq_len, B, feat) and states
        # (num_layers, B, hidden). Prepending a length-1 axis makes the batch axis
        # an actual batch instead of being read as the sequence / layer axis.
        _, (h_n, c_n) = self.custom_lstmcell(
            x_in_t.unsqueeze(0), (h_x.unsqueeze(0), c_x.unsqueeze(0)))
        return h_n.squeeze(0), c_n.squeeze(0)
class AutoRegressiveLSTMJITCompatV24(nn.Module):
    """Autoregressive token sampler over stacked LSTM cells, written for TorchScript.

    Token ids are embedded with ``embd`` and passed through ``layer_num``
    stacked ``AutoRegressiveLSTMCell`` layers. After each layer the new hidden
    and cell states go through a ``LayerNorm``. The top layer's normalised
    hidden state is projected to ``num_embeddings`` logits by ``pred_class``.
    ``forward`` feeds a seed prefix and then samples the remaining positions.

    Args:
        num_embeddings: vocabulary size (embedding rows and number of logits).
        embd_size: embedding width (input width of the first LSTM layer).
        hidden_size: hidden/cell width of every LSTM layer.
        layer_num: number of stacked LSTM layers; must be >= 1.
        device: device string used when allocating fresh hidden/cell states.

    Raises:
        ValueError: if ``layer_num < 1``.
    """

    def __init__(self, num_embeddings: int, embd_size: int, hidden_size: int,
                 layer_num: int = 1, device: str = "cpu") -> None:
        super().__init__()
        if layer_num < 1:
            # init_stacked_lstm always builds at least one layer, but the state
            # lists would have layer_num entries, so step() would fail later.
            raise ValueError(f"layer_num must be >= 1, got {layer_num}")
        self.hidden_size: int = hidden_size
        self.n_layers: int = layer_num
        self.embd_size: int = embd_size
        self.num_embeddings: int = num_embeddings
        self.device: str = device  # must be set before the layers below read it
        self.embd: nn.Embedding = nn.Embedding(num_embeddings, embd_size)
        self.layers: ModuleList = self.initialize_LSTMCell_layers(embd_size, hidden_size, layer_num)
        self.norms_h: ModuleList = self.initialize_lstm_hidden_layer_norm(hidden_size, layer_num)
        self.norms_c: ModuleList = self.initialize_lstm_context_layer_norm(hidden_size, layer_num)
        self.pred_class: Sequential = self.initialize_pred_class_seq_layer(hidden_size, num_embeddings)

    def set_device(self, device: str) -> None:
        """Record the device string used for newly allocated hidden/cell states.

        Only updates ``self.device``; it does not move parameters (use ``.to()``).
        """
        if device is not None and torch.jit.isinstance(device, str):
            self.device = device

    def initHiddenStates(self, B: int) -> List[Tensor]:
        """Create the initial hidden-state list for the LSTM layers.

        B: the batch size for the hidden states.
        Returns one zero tensor of shape ``(B, hidden_size)`` per layer.
        """
        return self.build_init_layers(B, self.hidden_size, self.device, self.n_layers)

    def initCellStates(self, B: int) -> List[Tensor]:
        """Create the initial cell-state list for the LSTM layers.

        B: the batch size for the cell states.
        Returns one zero tensor of shape ``(B, hidden_size)`` per layer.
        """
        return self.build_init_layers(B, self.hidden_size, self.device, self.n_layers)

    def build_init_layers(self, batch_size: int, hidden_size: int, device: str,
                          layer_num: int) -> List[Tensor]:
        """Return ``layer_num`` float32 zero tensors of shape ``(batch_size, hidden_size)``."""
        return [torch.zeros(batch_size, hidden_size, dtype=torch.float, device=device)
                for _ in range(layer_num)]

    def initialize_LSTMCell_layers(self, embd_size: int, hidden_size: int,
                                   layer_num: int) -> ModuleList:
        """Build the LSTM stack: first layer embd->hidden, the rest hidden->hidden."""
        return self.init_stacked_lstm(layer_num, AutoRegressiveLSTMCell,
                                      [embd_size, hidden_size, self.device],
                                      [hidden_size, hidden_size, self.device])

    def initialize_lstm_hidden_layer_norm(self, hidden_size: int,
                                          layer_num: int) -> ModuleList:
        """Build one ``LayerNorm(hidden_size)`` per layer for the hidden states."""
        # NOTE(review): the argument lists were cut off in the source; reconstructed
        # as LayerNorm(hidden_size) for every layer — confirm.
        return self.init_stacked_lstm(layer_num, nn.LayerNorm, [hidden_size], [hidden_size])

    def initialize_lstm_context_layer_norm(self, hidden_size: int,
                                           layer_num: int) -> ModuleList:
        """Build one ``LayerNorm(hidden_size)`` per layer for the cell states."""
        # NOTE(review): reconstructed the same way as the hidden-state norms — confirm.
        return self.init_stacked_lstm(layer_num, nn.LayerNorm, [hidden_size], [hidden_size])

    def initialize_pred_class_seq_layer(self, hidden_size: int,
                                        num_embeddings: int) -> Sequential:
        """Build the head that maps the top hidden state to vocabulary logits."""
        # NOTE(review): there is no nonlinearity between the two Linear layers, so
        # together they compose to a single affine map — confirm this is intended.
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Linear(hidden_size, num_embeddings),  # (B, *, D) -> (B, *, VocabSize)
        )

    def init_stacked_lstm(self, num_layers: int, layer_type,
                          first_layer_args: list, other_layer_args: list) -> ModuleList:
        """Instantiate ``layer_type`` ``num_layers`` times (at least once).

        The first instance gets ``first_layer_args``; the rest get ``other_layer_args``.
        """
        first = layer_type(*first_layer_args)
        rest = [layer_type(*other_layer_args) for _ in range(num_layers - 1)]
        return nn.ModuleList([first] + rest)

    def step(self, x_in: Tensor, h_prevs: Optional[List[Tensor]],
             c_prevs: Optional[List[Tensor]]) -> Tensor:
        """Run one time step through every layer and return class logits.

        Args:
            x_in: token ids of shape ``(B,)`` (embedded here), or an
                already-embedded tensor of any other rank, e.g. ``(B, embd_size)``.
            h_prevs: per-layer hidden states, or ``None`` for zeros.
            c_prevs: per-layer cell states, or ``None`` for zeros.

        Entry ``i`` of each state list is replaced with layer ``i``'s new
        (layer-normalised) state. Lists passed in by the caller are therefore
        updated in place and carry state into the next call.

        Returns:
            ``pred_class`` applied to the top layer's normalised hidden state.
        """
        if x_in.dim() == 1:
            layer_input = self.embd(x_in)  # token ids (B,) -> (B, D)
        else:
            layer_input = x_in  # already embedded
        batch = x_in.size(0)
        # Bind refined, non-Optional locals; they alias the caller's lists when given.
        if h_prevs is None:
            h_states: List[Tensor] = self.initHiddenStates(batch)
        else:
            h_states = h_prevs
        if c_prevs is None:
            c_states: List[Tensor] = self.initCellStates(batch)
        else:
            c_states = c_prevs
        for i, (cur_layer, cur_norm_h, cur_norm_c) in enumerate(
                zip(self.layers, self.norms_h, self.norms_c)):
            h_new, c_new = cur_layer(layer_input, h_states[i], c_states[i])
            h_norm = cur_norm_h(h_new)
            h_states[i] = h_norm
            c_states[i] = cur_norm_c(c_new)
            layer_input = h_norm  # the next layer consumes this layer's hidden state
        # After the loop layer_input holds the top layer's normalised hidden state.
        return self.pred_class(layer_input)

    def forward(self,
                sampling_with_seed: Tensor,
                seed_len: Tensor = torch.tensor([5], dtype=torch.int64),
                temperature: Tensor = torch.tensor([0.76], dtype=torch.float)) -> Tuple[Tensor, Tensor]:
        """Feed a seed prefix, then sample tokens for the remaining positions.

        Args:
            sampling_with_seed: token-id tensor of shape ``(1, T)``. The states
                are allocated for a batch of one. Positions from the seed
                length onward are overwritten with sampled tokens.
            seed_len: 1-element tensor giving how many leading tokens are seed;
                clamped to ``T``.
            temperature: divisor applied to the logits before softmax.

        Returns:
            ``(sequence, tensor([T]))``: a copy of the input with the sampled
            tokens filled in, and the final position index.
        """
        h_prevs: List[Tensor] = self.initHiddenStates(1)
        c_prevs: List[Tensor] = self.initCellStates(1)
        predict_ts_seq_tensor: Tensor = torch.clone(sampling_with_seed)
        h_out: Tensor = torch.zeros((1, self.num_embeddings), dtype=torch.float,
                                    device=sampling_with_seed.device)
        max_pred_len: int = sampling_with_seed.size(1)
        # NOTE(review): the seed-loop bound and the token selection below were
        # garbled in the source; reconstructed as "feed the first seed_len tokens,
        # one position at a time" — confirm against the original.
        n_seed: int = min(int(seed_len[0]), max_pred_len)
        cur_index: int = 0
        while cur_index < n_seed:
            seed_tokens = torch.select(sampling_with_seed, 1, cur_index)  # (1,)
            h_out = self.step(self.embd(seed_tokens), h_prevs=h_prevs, c_prevs=c_prevs)
            cur_index = cur_index + 1
        while cur_index < max_pred_len:
            probs = F.softmax(torch.div(h_out, temperature), dim=-1)
            probs = torch.reshape(probs, (-1, self.num_embeddings))
            next_tokens = torch.multinomial(probs, num_samples=1)  # (1, 1) for batch one
            # With a single row and a single sample, max() just extracts that token.
            predict_ts_seq_tensor[:, cur_index] = torch.max(next_tokens)
            next_embd = self.embd(torch.select(predict_ts_seq_tensor, 1, cur_index))
            h_out = self.step(next_embd, h_prevs=h_prevs, c_prevs=c_prevs)
            cur_index = cur_index + 1
        return (predict_ts_seq_tensor, torch.tensor([cur_index], dtype=torch.int64))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment