Skip to content

Instantly share code, notes, and snippets.

@motokimura
Created March 21, 2023 10:51
Show Gist options
  • Save motokimura/f0caa5a6f6606a91fa18070cf7d29eb5 to your computer and use it in GitHub Desktop.
Save motokimura/f0caa5a6f6606a91fa18070cf7d29eb5 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.0.0+cpu\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"print(torch.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# LSTM parameters\n",
"input_size = 10\n",
"hidden_size = 20\n",
"num_layers = 2\n",
"\n",
"# input parameters\n",
"sequence_length = 5\n",
"batch_size = 3"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"input = torch.randn(sequence_length, batch_size, input_size)\n",
"h0 = torch.randn(num_layers, batch_size, hidden_size)\n",
"c0 = torch.randn(num_layers, batch_size, hidden_size)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"lstm = nn.LSTM(input_size, hidden_size, num_layers)\n",
"\n",
"output, (hn, cn) = lstm(input, (h0, c0))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============== Diagnostic Run torch.onnx.export version 2.0.0+cpu ==============\n",
"verbose: False, log level: Level.ERROR\n",
"======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/motoki_kimura/.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/onnx/symbolic_opset9.py:4476: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. \n",
" warnings.warn(\n"
]
}
],
"source": [
"input_names = [\"input\", \"h_0\", \"c_0\"]\n",
"output_names = [\"output\", \"h_n\", \"c_n\"]\n",
"torch.onnx.export(lstm, (input, (h0, c0)), \"lstm.onnx\", input_names=input_names, output_names=output_names)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class LSTMCells(nn.Module):\n",
" def __init__(self, input_size, hidden_size, num_layers):\n",
" super().__init__()\n",
" self.cells = [nn.LSTMCell(input_size, hidden_size)]\n",
" for i in range(1, num_layers):\n",
" self.cells.append(nn.LSTMCell(hidden_size, hidden_size))\n",
" self.cells = nn.ModuleList(self.cells)\n",
" self.num_layers = num_layers\n",
"\n",
" def forward(self, x, h_prev, c_prev):\n",
" # x: [batch_size, input_size]\n",
" # h_prev: [num_layers, batch_size, hidden_size]\n",
" # c_prev: [num_layers, batch_size, hidden_size]\n",
"\n",
" h_next, c_next = [], []\n",
"\n",
" for i in range(self.num_layers):\n",
" h = x if (i == 0) else h\n",
" h, c = self.cells[i](h, (h_prev[i], c_prev[i]))\n",
" h_next.append(h)\n",
" c_next.append(c)\n",
"\n",
" h_next = torch.stack(h_next, dim=0)\n",
" c_next = torch.stack(c_next, dim=0)\n",
"\n",
" return h_next, c_next\n",
"\n",
"\n",
"class CustomLSTM(nn.Module):\n",
" def __init__(self, input_size, hidden_size, num_layers):\n",
" super().__init__()\n",
" self.m = LSTMCells(input_size, hidden_size, num_layers)\n",
"\n",
" def load_params(self, lstm: nn.LSTM):\n",
" num_layers = lstm.num_layers\n",
" cells = self.m.cells\n",
" assert len(cells) == num_layers\n",
" for i in range(num_layers):\n",
" cells[i].weight_ih.data = lstm.__getattr__(f'weight_ih_l{i}').data\n",
" cells[i].weight_hh.data = lstm.__getattr__(f'weight_hh_l{i}').data\n",
" cells[i].bias_ih.data = lstm.__getattr__(f'bias_ih_l{i}').data\n",
" cells[i].bias_hh.data = lstm.__getattr__(f'bias_hh_l{i}').data\n",
"\n",
" def forward(self, x, hx):\n",
" # x: [sequence_length, batch_size, input_size]\n",
" # hx: [[num_layers, batch_size, hidden_size], [num_layers, batch_size, hidden_size]]\n",
" h, c = hx\n",
" output = []\n",
" for t in range(x.size()[0]):\n",
" h, c = self.m(x[t], h, c)\n",
" output.append(h[-1])\n",
"\n",
" return torch.stack(output, dim=0), (h, c)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============== Diagnostic Run torch.onnx.export version 2.0.0+cpu ==============\n",
"verbose: False, log level: Level.ERROR\n",
"======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n",
"\n"
]
}
],
"source": [
"\n",
"custom_lstm = CustomLSTM(input_size, hidden_size, num_layers)\n",
"\n",
"custom_lstm.load_params(lstm)\n",
"\n",
"torch.onnx.export(custom_lstm, (input, (h0, c0)), \"lstm_custom.onnx\", input_names=input_names, output_names=output_names)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1.0431e-07, grad_fn=<MaxBackward1>)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"o1 = lstm(input, (h0, c0))\n",
"\n",
"o2 = custom_lstm(input, (h0, c0))\n",
"\n",
"(o2[0] - o1[0]).abs().max()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment