Created
September 4, 2020 14:57
-
-
Save CookieBox26/31a1247c0e31d6109067229a151ead66 to your computer and use it in GitHub Desktop.
LSTM で足し算する
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# LSTM で足し算する\n", | |
"\n", | |
"### 参考文献\n", | |
"\n", | |
"1. [[1803.01271]An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling](https://arxiv.org/abs/1803.01271) ;TCNの原論文。\n", | |
"1. [locuslab/TCN: Sequence modeling benchmarks and temporal convolutional networks](https://github.com/locuslab/TCN) ;TCNの原論文のリポジトリ。\n", | |
"1. https://pytorch.org/docs/master/generated/torch.nn.LSTM.html ;torch.nn.LSTM のリファレンス。\n", | |
"1. [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf) ;[Kerasのドキュメント](https://keras.io/ja/layers/recurrent/)の「初期化時に忘却ゲートのバイアスに1を加えます.」の箇所からリンクがある論文。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"<table width=\"100%\">\n", | |
"<tr>\n", | |
"<td width=\"80px\" style=\"vertical-align:top;\"><img src=\"https://cookiebox26.github.io/ToyBox/20200407_100ML/1.png\"></td>\n", | |
"<td style=\"vertical-align:top;text-align:left;\">参考文献 1. の13ページ目の Table 3. に、この論文の提案手法である TCN のライバル手法にされている LSTM のコンフィギュレーションがありますよね。例えば1番最初の行は Adding Problem タスクの T=200 の場合ですが、n=2, Hidden=77, Dropout=0.0, Grad Clip=50, bias=5.0 とありますが、Dropout=0.0 はいいとして、それ以外は何を指しているのでしょう…?</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
"<td width=\"80px\" style=\"vertical-align:top;\"><img src=\"https://cookiebox26.github.io/ToyBox/20200407_100ML/2.png\"></td>\n", | |
"<td style=\"vertical-align:top;text-align:left;\">最初の n=2 についてだけど、論文の6ページ目に the depth of the network n とあるよね。ただこれは TCN の TemporalBlock の積み重ね数だけど、LSTM でも同じ文字 n を使っているのは、LSTM ブロックの積み重ね数だと考えていいんじゃないかな。Hidden=77 は各 LSTM ブロックの入出力の次元数でいいと思うよ(但し1つ目の LSTM ブロックの入力次元数については入力データの次元数)。だって、著者のリポジトリで nhid とかかれているパラメータ、各 TemporalBlock の入出力次元数にセットされているからね。\n", | |
"<ul style=\"margin:0.3em 0\">\n", | |
"<li><a href=\"https://github.com/locuslab/TCN/blob/master/TCN/adding_problem/add_test.py#L33-L34\">https://github.com/locuslab/TCN/blob/master/TCN/adding_problem/add_test.py#L33-L34</a></li>\n", | |
"<li><a href=\"https://github.com/locuslab/TCN/blob/master/TCN/tcn.py#L55-L56\">https://github.com/locuslab/TCN/blob/master/TCN/tcn.py#L55-L56</a></li>\n", | |
"</ul>Grad Clip=50 は訓練時に勾配ベクトルのノルムが50をはみ出したら50になるように縮めるという意味だからちょっと置いておこう。bias=5.0 は論文の6ページ目に initial forget-gate bias とあるからこのことだと思う。忘却ゲートのバイアスの初期値が重要だという話は調べると結構出てくるね。きちんと話を追えていないけど、でも、仮にもし忘却ゲートの重みやバイアスの初期値がゼロだったら記憶セルの最適化って全く進まないよね。記憶セルなんてなかったんだって感じで学習が進んじゃう。だから、「記憶セルがあることが前提で学習しなさい」と伝えるためにわざと最初に値を入れておくという感じだと思う。…まあそれで、今回の LSTM を実装すると以下のようになるかな。torch.nn.LSTM は引数 num_layers に指定した数だけ積み重ねられるからこれを使おう。</td>\n", | |
"</tr>\n", | |
"</table>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"◆ モデル\n", | |
"myLSTM(\n", | |
" (network): Sequential(\n", | |
" (lstm): LSTM(2, 77, num_layers=2)\n", | |
" (linear): Linear(in_features=77, out_features=1, bias=True)\n", | |
" )\n", | |
")\n", | |
"◆ 学習対象パラメータ\n", | |
"network.lstm.weight_ih_l0 torch.Size([308, 2])\n", | |
"network.lstm.weight_hh_l0 torch.Size([308, 77])\n", | |
"network.lstm.bias_ih_l0 torch.Size([308])\n", | |
"network.lstm.bias_hh_l0 torch.Size([308])\n", | |
"network.lstm.weight_ih_l1 torch.Size([308, 77])\n", | |
"network.lstm.weight_hh_l1 torch.Size([308, 77])\n", | |
"network.lstm.bias_ih_l1 torch.Size([308])\n", | |
"network.lstm.bias_hh_l1 torch.Size([308])\n", | |
"network.linear.weight torch.Size([1, 77])\n", | |
"network.linear.bias torch.Size([1])\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch.nn as nn\n", | |
"from collections import OrderedDict\n", | |
"\n", | |
"def _debug_print(debug, *content):\n", | |
" if debug:\n", | |
" print(*content)\n", | |
"\n", | |
"class myLSTM(nn.Module):\n", | |
" def __init__(self,\n", | |
" input_size=2, # 足し算タスクなので入力は2次元\n", | |
" output_size=1, # 足し算タスクなので出力は1次元\n", | |
" num_layers=2, # LSTM ブロックの積み重ね数が2\n", | |
" d_hidden=77, # 各 LSTM ブロックの出力次元数が77\n", | |
" initial_forget_gate_bias=5.0, # 忘却ゲートのバイアスの初期値\n", | |
" dropout=0.0):\n", | |
" super(myLSTM, self).__init__()\n", | |
" self.num_layers = num_layers\n", | |
" self.d_hidden = d_hidden\n", | |
" self.layers = OrderedDict()\n", | |
" self.layers['lstm'] = nn.LSTM(input_size, d_hidden,\n", | |
" num_layers=num_layers,\n", | |
" dropout=dropout)\n", | |
" self.layers['linear'] = nn.Linear(d_hidden, output_size)\n", | |
" self.network = nn.Sequential(self.layers)\n", | |
" self.init_weights(initial_forget_gate_bias)\n", | |
"\n", | |
" def init_weights(self, initial_forget_gate_bias):\n", | |
" # 忘却ゲートのバイアスの初期値をセット\n", | |
" for i_layer in range(self.num_layers):\n", | |
" bias = getattr(self.layers['lstm'], f'bias_ih_l{i_layer}')\n", | |
" bias.data[self.d_hidden:(2*self.d_hidden)] = initial_forget_gate_bias\n", | |
" bias = getattr(self.layers['lstm'], f'bias_hh_l{i_layer}')\n", | |
" bias.data[self.d_hidden:(2*self.d_hidden)] = initial_forget_gate_bias\n", | |
" self.layers['linear'].weight.data.normal_(0, 0.01)\n", | |
"\n", | |
" def forward(self, x, hidden, debug=False):\n", | |
" _debug_print(debug, '========== forward ==========')\n", | |
" _debug_print(debug, x.size())\n", | |
" out, hidden = self.layers['lstm'](x, hidden)\n", | |
" _debug_print(debug, out.size())\n", | |
" _debug_print(debug, hidden[0].size())\n", | |
" _debug_print(debug, hidden[1].size())\n", | |
" x = self.layers['linear'](hidden[0][-1,:,:])\n", | |
" _debug_print(debug, x.size())\n", | |
" _debug_print(debug, '=============================')\n", | |
" return x, hidden\n", | |
"\n", | |
"model_lstm = myLSTM()\n", | |
"print('◆ モデル')\n", | |
"print(model_lstm)\n", | |
"print('◆ 学習対象パラメータ')\n", | |
"for name, param in model_lstm.named_parameters():\n", | |
" print(name.ljust(14), param.size())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"<table width=\"100%\">\n", | |
"<tr>\n", | |
"<td width=\"80px\" style=\"vertical-align:top;\"><img src=\"https://cookiebox26.github.io/ToyBox/20200407_100ML/1.png\"></td>\n", | |
"<td style=\"vertical-align:top;text-align:left;\">えっと、パラメータの次元数が、 308?</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
"<td width=\"80px\" style=\"vertical-align:top;\"><img src=\"https://cookiebox26.github.io/ToyBox/20200407_100ML/2.png\"></td>\n", | |
"<td style=\"vertical-align:top;text-align:left;\">77 × 4 = 308 だね。2次元の入力を77次元にするには右からサイズ [77, 2] の行列をかければいいけど、LSTM は「入力ゲート」「忘却ゲート」「通常のRNNの重み」「出力ゲート」があるから 77 が4倍になる。あと LSTM のバイアスを表示してみるね。忘却ゲートのバイアスに該当する 77~153 次元目に 5.0 を代入しただけだけどこんなんでいいのかな…Keras のドキュメントには忘却ゲートのバイアスを1にして、逆に忘却ゲートのバイアス以外は0にするといいというようにあったけど、とりあえず論文に言及があるのは忘却ゲートのバイアスだけだから他は放っておいた。</td>\n", | |
"</tr>\n", | |
"</table>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Parameter containing:\n", | |
"tensor([ 4.9165e-02, 5.3896e-02, -4.8629e-02, 9.2515e-02, 2.1954e-02,\n", | |
" 9.4397e-02, 9.4592e-02, -6.4292e-02, 3.2555e-02, -2.1142e-02,\n", | |
" -8.1746e-02, 9.1270e-02, 5.8425e-02, 3.5523e-02, 8.8397e-02,\n", | |
" -7.2806e-02, -9.3471e-02, 4.9092e-02, 7.1229e-03, -5.1445e-02,\n", | |
" -8.1698e-02, -3.8696e-02, -2.5925e-02, -9.2030e-02, 1.0211e-01,\n", | |
" -9.0567e-02, -1.0435e-01, 1.0762e-01, 5.9898e-02, 3.2932e-02,\n", | |
" -5.2855e-02, 5.6225e-02, 4.4851e-02, -1.0331e-01, -7.9663e-02,\n", | |
" -9.0007e-02, 3.9228e-02, -4.3425e-02, -7.9913e-02, 9.2957e-02,\n", | |
" -2.9188e-02, -6.8715e-02, 6.6197e-02, -5.7450e-02, -5.3279e-02,\n", | |
" 4.1563e-02, -8.3667e-02, 7.3850e-02, -2.8193e-02, 6.4358e-02,\n", | |
" -4.1299e-02, -8.7524e-02, 7.3115e-02, -6.8227e-02, -1.9827e-02,\n", | |
" 9.8330e-02, -1.0648e-01, -4.2002e-02, 3.3780e-02, -3.0554e-02,\n", | |
" 4.5411e-02, 6.9738e-02, -4.2859e-02, -6.1650e-02, 4.4743e-02,\n", | |
" -2.4768e-02, -8.5655e-02, -2.8372e-02, -8.7389e-02, 1.0232e-01,\n", | |
" 3.9674e-02, -5.2236e-02, 1.0560e-01, -9.7303e-02, -7.2981e-02,\n", | |
" -9.2142e-02, -7.0825e-02, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00,\n", | |
" 5.0000e+00, 5.0000e+00, 5.0000e+00, 5.0000e+00, -6.2721e-02,\n", | |
" -5.8241e-02, -1.6867e-02, -4.4569e-02, -2.7921e-02, 7.4363e-02,\n", | |
" -4.2447e-02, -5.5075e-02, 1.6161e-02, -2.4220e-02, 7.7675e-02,\n", | |
" -6.0809e-02, 5.4841e-02, 8.6750e-02, 3.6424e-02, 2.8048e-02,\n", | |
" -5.3524e-02, 3.5246e-02, -1.0170e-01, 1.7603e-02, 1.8143e-02,\n", | |
" -8.9160e-02, 1.0085e-01, 1.0683e-01, -1.3512e-02, -5.2527e-02,\n", | |
" 6.8354e-02, -1.0891e-02, 2.8305e-02, 2.8064e-02, 1.7084e-02,\n", | |
" 7.6924e-02, -4.3729e-02, 1.1172e-01, -1.1196e-01, 4.4700e-02,\n", | |
" 8.4991e-02, -8.6713e-02, -1.3968e-02, 4.3707e-02, -8.8109e-02,\n", | |
" 6.2691e-02, 8.2722e-02, -6.4417e-02, -9.9627e-03, 4.9382e-02,\n", | |
" -3.7137e-03, -3.1076e-02, 9.7192e-02, -4.7260e-02, 7.8700e-02,\n", | |
" 9.8256e-02, -2.5228e-02, 8.1686e-02, 5.6391e-02, 3.2331e-02,\n", | |
" 8.2114e-02, -3.0936e-02, -8.8276e-02, 1.8504e-02, -9.9111e-02,\n", | |
" -1.2504e-02, 5.6677e-02, -2.6727e-02, 5.9095e-02, -9.7887e-02,\n", | |
" 7.5615e-02, 2.5467e-02, -2.1051e-02, -2.2644e-03, 3.6444e-03,\n", | |
" 1.1376e-01, 1.0347e-02, 8.6061e-02, -1.5429e-02, 9.6989e-02,\n", | |
" 1.3893e-02, -8.8932e-03, -9.1217e-02, -1.8536e-03, -1.1309e-01,\n", | |
" -5.9436e-02, 6.8963e-02, -7.6549e-02, -4.3315e-02, -1.0767e-01,\n", | |
" 1.3827e-02, 3.8573e-02, -3.2789e-02, -4.6074e-02, 4.1241e-02,\n", | |
" -7.4975e-02, -2.8728e-02, 3.2466e-02, -5.4221e-02, -2.2455e-02,\n", | |
" -8.3136e-02, 7.5368e-02, -6.2805e-02, -7.6103e-02, 1.8733e-03,\n", | |
" -7.1987e-02, -1.0046e-01, -4.7820e-02, 6.5614e-02, -6.4605e-03,\n", | |
" 1.1242e-01, 8.9668e-02, -5.1174e-02, -5.9855e-02, -8.3304e-02,\n", | |
" 1.6023e-03, 1.7028e-03, 1.0380e-03, 2.0233e-02, -8.6717e-02,\n", | |
" 4.5389e-02, -8.3116e-02, 3.1499e-03, 1.8729e-02, -8.1673e-02,\n", | |
" -1.0322e-01, 2.6273e-02, 8.2414e-02, -9.9420e-02, 1.1180e-01,\n", | |
" 1.0178e-01, 3.9553e-02, 3.0256e-02, -3.9064e-02, -4.5343e-02,\n", | |
" -8.0867e-02, 3.5570e-02, 1.0058e-01, 1.0549e-01, 7.8596e-02,\n", | |
" -9.6782e-02, 7.9042e-02, -3.2222e-02, -7.1313e-02, -1.9761e-02,\n", | |
" 5.7132e-03, 7.6312e-02, -9.1598e-02, -8.6944e-02, 6.7028e-03,\n", | |
" -6.1358e-02, 9.3485e-04, 2.1562e-02, 2.2943e-02, 6.1359e-02,\n", | |
" -9.1821e-02, -2.2169e-03, -8.0288e-02], requires_grad=True)\n" | |
] | |
} | |
], | |
"source": [ | |
"print(model_lstm.layers['lstm'].bias_ih_l0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"<table width=\"100%\">\n", | |
"<tr>\n", | |
"<td width=\"80px\" style=\"vertical-align:top;\"><img src=\"https://cookiebox26.github.io/ToyBox/20200407_100ML/1.png\"></td>\n", | |
"<td style=\"vertical-align:top;text-align:left;\">それで、実際に学習できるんですか?</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
"<td width=\"80px\" style=\"vertical-align:top;\"><img src=\"https://cookiebox26.github.io/ToyBox/20200407_100ML/2.png\"></td>\n", | |
"<td style=\"vertical-align:top;text-align:left;\">たぶん以下のようになると思うけど(雑だけど)。忘却ゲートのバイアスが5の方が0のときよりロスは小さいみたい?</td>\n", | |
"</tr>\n", | |
"</table>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import torch\n", | |
"\n", | |
"from torch.autograd import Variable\n", | |
"torch.manual_seed(1)\n", | |
"\n", | |
"# TCN向けの足し算データを生成する関数\n", | |
"# ソース https://github.com/locuslab/TCN/blob/master/TCN/adding_problem/utils.py\n", | |
"# out: torch.Size([N, 2, seq_length]), torch.Size([N, 1])\n", | |
"def data_generator(N, seq_length):\n", | |
" X_num = torch.rand([N, 1, seq_length])\n", | |
" X_mask = torch.zeros([N, 1, seq_length])\n", | |
" Y = torch.zeros([N, 1])\n", | |
" for i in range(N):\n", | |
" positions = np.random.choice(seq_length, size=2, replace=False)\n", | |
" X_mask[i, 0, positions[0]] = 1\n", | |
" X_mask[i, 0, positions[1]] = 1\n", | |
" Y[i,0] = X_num[i, 0, positions[0]] + X_num[i, 0, positions[1]]\n", | |
" X = torch.cat((X_num, X_mask), dim=1)\n", | |
" return Variable(X), Variable(Y)\n", | |
"\n", | |
"# TCN向けのデータをLSTM向けに転置する関数\n", | |
"# out: torch.Size([seq_length, N, 2]), torch.Size([N, 1])\n", | |
"def copy_for_lstm(x):\n", | |
" x_ = x.clone().detach()\n", | |
" x_ = x_.transpose(0, 2).transpose(1, 2).contiguous()\n", | |
" return x_" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"========== forward ==========\n", | |
"torch.Size([200, 100, 2])\n", | |
"torch.Size([200, 100, 77])\n", | |
"torch.Size([2, 100, 77])\n", | |
"torch.Size([2, 100, 77])\n", | |
"torch.Size([100, 1])\n", | |
"=============================\n", | |
"\n", | |
"◇ out = 足し算結果(ネットワークを学習していないので足し算にはなっていない、次元だけ確認)\n", | |
"torch.Size([100, 1])\n", | |
"tensor([[-0.0025],\n", | |
" [-0.0088],\n", | |
" [-0.0166],\n", | |
" [-0.0055],\n", | |
" [-0.0239],\n", | |
" [-0.0173],\n", | |
" [ 0.0040],\n", | |
" [-0.0158],\n", | |
" [-0.0021],\n", | |
" [-0.0226]], grad_fn=<SliceBackward>)\n", | |
"◇ hidden = 流した後の隠れ状態と記憶セル\n", | |
"torch.Size([2, 100, 77]) torch.Size([2, 100, 77])\n", | |
"(tensor([[[ 0.4744, -0.3614, -0.5132, ..., -0.4230, 0.4121, 0.4119],\n", | |
" [ 0.4572, -0.3583, -0.5197, ..., -0.4263, 0.4100, 0.3948],\n", | |
" [ 0.4590, -0.3679, -0.5241, ..., -0.4239, 0.4007, 0.3892],\n", | |
" ...,\n", | |
" [ 0.4665, -0.3798, -0.4826, ..., -0.4065, 0.3867, 0.4315],\n", | |
" [ 0.4696, -0.3296, -0.5443, ..., -0.4045, 0.3817, 0.4176],\n", | |
" [ 0.4625, -0.3428, -0.5251, ..., -0.4157, 0.4224, 0.4039]],\n", | |
"\n", | |
" [[ 0.5115, -0.5246, -0.5442, ..., -0.3999, -0.4695, 0.3655],\n", | |
" [ 0.5027, -0.5164, -0.5433, ..., -0.4205, -0.4633, 0.3725],\n", | |
" [ 0.5146, -0.4881, -0.5898, ..., -0.3399, -0.5100, 0.3577],\n", | |
" ...,\n", | |
" [ 0.4788, -0.5288, -0.5811, ..., -0.4078, -0.5186, 0.4486],\n", | |
" [ 0.4448, -0.5565, -0.5365, ..., -0.3549, -0.4890, 0.3322],\n", | |
" [ 0.4976, -0.5203, -0.5409, ..., -0.3922, -0.4817, 0.3674]]],\n", | |
" grad_fn=<StackBackward>), tensor([[[ 25.3851, -19.5185, -2.2148, ..., -7.0757, 21.8752, 21.2672],\n", | |
" [ 24.9314, -18.4482, -2.1745, ..., -7.2815, 21.8010, 21.4574],\n", | |
" [ 26.8274, -21.9711, -2.5544, ..., -6.7963, 23.2875, 20.9006],\n", | |
" ...,\n", | |
" [ 29.8533, -22.5231, -2.4401, ..., -3.5235, 24.2307, 21.8897],\n", | |
" [ 24.6162, -15.4336, -4.5720, ..., -5.4434, 18.8222, 23.8761],\n", | |
" [ 24.6132, -18.0139, -2.4366, ..., -7.3646, 21.3561, 21.7261]],\n", | |
"\n", | |
" [[ 17.5962, -28.9121, -15.6031, ..., -8.2359, -13.6865, 4.7207],\n", | |
" [ 17.1507, -29.2552, -15.5106, ..., -7.4294, -13.2211, 5.2530],\n", | |
" [ 17.0677, -22.5658, -15.9626, ..., -9.6145, -23.2147, 3.4630],\n", | |
" ...,\n", | |
" [ 28.1819, -33.5668, -22.2365, ..., -8.5320, -28.6113, 3.7873],\n", | |
" [ 6.8183, -14.3391, -26.0844, ..., -7.9752, -11.1114, 12.1419],\n", | |
" [ 17.1735, -25.4016, -15.4087, ..., -7.9439, -11.6768, 5.0952]]],\n", | |
" grad_fn=<StackBackward>))\n" | |
] | |
} | |
], | |
"source": [ | |
"# LSTMの隠れ状態と記憶セルの初期値を作成\n", | |
"num_layers = 2\n", | |
"batch_size = 100\n", | |
"d_hidden = 77\n", | |
"hidden0 = torch.zeros([num_layers, batch_size, d_hidden])\n", | |
"cell0 = torch.zeros([num_layers, batch_size, d_hidden])\n", | |
"\n", | |
"seq_length = 200\n", | |
"data = data_generator(batch_size, seq_length)\n", | |
"x = copy_for_lstm(data[0])\n", | |
"\n", | |
"hidden = (hidden0.clone().detach(), cell0.clone().detach())\n", | |
"out, hidden = model_lstm.forward(x, hidden, debug=True)\n", | |
"print('\\n◇ out = 足し算結果(ネットワークを学習していないので足し算にはなっていない、次元だけ確認)')\n", | |
"print(out.size())\n", | |
"print(out[:10])\n", | |
"print('◇ hidden = 流した後の隠れ状態と記憶セル')\n", | |
"print(hidden[0].size(), hidden[1].size())\n", | |
"print(hidden)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch.optim as optim\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"\n", | |
"def train(model_lstm):\n", | |
" optimizer = optim.SGD(model_lstm.parameters(), lr=0.001)\n", | |
" grad_clip = 50.0\n", | |
" total_loss = 0\n", | |
"\n", | |
" for epoch in range(100):\n", | |
" optimizer.zero_grad()\n", | |
" # data = data_generator(batch_size, seq_length)\n", | |
" data = data_generator(100, 100)\n", | |
" x = copy_for_lstm(data[0])\n", | |
" \n", | |
" hidden = (hidden0.clone().detach(), cell0.clone().detach())\n", | |
" out, hidden = model_lstm.forward(x, hidden)\n", | |
"\n", | |
" loss = F.mse_loss(out, data[1])\n", | |
" loss.backward()\n", | |
" if grad_clip > 0:\n", | |
" torch.nn.utils.clip_grad_norm_(model_lstm.parameters(), grad_clip)\n", | |
" optimizer.step()\n", | |
" \n", | |
" total_loss += loss.item()\n", | |
"\n", | |
" print(total_loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"101.7260605096817\n", | |
"26.77291202545166\n" | |
] | |
} | |
], | |
"source": [ | |
"model_lstm = myLSTM(initial_forget_gate_bias=0.0)\n", | |
"train(model_lstm)\n", | |
"\n", | |
"model_lstm = myLSTM(initial_forget_gate_bias=5.0)\n", | |
"train(model_lstm)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment