Gist: i4kimura/469436243ec82757d3e98c01d3174c2d
ゼロから作るディープラーニング② (Deep Learning from Scratch 2), Chapter 6: Gated RNN
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Chapter 6: Gated RNN\n",
"\n",
"The RNN from Chapter 5 is a relatively simple one; in practice, the term RNN usually refers to an LSTM or a GRU.\n",
"- The simple RNN from Chapter 5 is called the Elman RNN.\n",
"\n",
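"## 6.1 Problems with RNNs\n",
"\n",
"The main problem is that gradients vanish or explode during BPTT (Backpropagation Through Time).\n",
"\n",
"### 6.1.1 Review of RNNs\n",
"\n",
"$\\mathbf{h}_t$, which stores information about the past, is called the RNN's **hidden state**. It is updated as $\\mathbf{h}_t = \\tanh(\\mathbf{h}_{t-1}\\mathbf{W}_h + \\mathbf{x}_t\\mathbf{W}_x + \\mathbf{b})$.\n",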
"\n",
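"### 6.1.2 Vanishing or Exploding Gradients\n",
"\n",
"For an RNN to learn long-term dependencies, the gradient has to be propagated back to time steps far in the past.\n",
"With a simple RNN, as the gradient travels back in time, it tends to\n",
"- shrink toward zero (vanishing gradients), or\n",
"- grow very large (exploding gradients).\n",
"\n",
"One reason is that the backward pass of an RNN goes through a $\\tanh$ node at every time step. The derivative of $y = \\tanh(x)$ is $1 - y^2$, which is at most 1 and approaches 0 as $|x|$ grows, so the gradient is weakened every time it passes through a $\\tanh$ node.\n",
"In addition, at every MatMul node the gradient is multiplied by $\\mathbf{W}_h^\\top$, so its magnitude changes exponentially with the number of time steps. Roughly speaking, if the largest singular value of $\\mathbf{W}_h$ is below 1 the gradient shrinks exponentially (vanishing), and if it is above 1 the gradient is likely to grow exponentially (exploding)."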
]
},
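{
"cell_type": "markdown",
"metadata": {},
"source": [
"The claim that each $\\tanh$ node weakens the gradient can be checked numerically. The following is a minimal sketch, not part of the book's code (the helper name `tanh_grad` is our own), that evaluates the derivative $1 - \\tanh^2(x)$ on a grid. It never exceeds 1, peaks at $x = 0$, and is already below 0.1 once $|x| > 2$.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def tanh_grad(x):\n",
"    # derivative of tanh written in terms of its output y = tanh(x): 1 - y**2\n",
"    y = np.tanh(x)\n",
"    return 1 - y ** 2\n",
"\n",
"x = np.linspace(-5, 5, 1001)\n",
"dy = tanh_grad(x)\n",
"\n",
"# the derivative peaks at x = 0 and never exceeds 1\n",
"print('max derivative:', dy.max(), 'at x =', x[np.argmax(dy)])\n",
"# far from 0 it is small, so repeated tanh nodes shrink the gradient\n",
"print('max derivative for |x| > 2:', dy[np.abs(x) > 2].max())\n",
"```"
]
},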
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
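"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"N = 2   # mini-batch size\n",
"H = 3   # dimension of the hidden state vector\n",
"T = 20  # length of the time series\n",
"\n",
"dh = np.ones((N, H))\n",
"np.random.seed(3)  # fix the random seed for reproducibility\n",
"Wh = np.random.randn(H, H)\n",
"\n",
"norm_list = []\n",
"for t in range(T):\n",
"    # backward pass through the MatMul node: multiply by Wh transposed\n",
"    dh = np.dot(dh, Wh.T)\n",
"    norm = np.sqrt(np.sum(dh**2)) / N  # L2 norm of dh divided by the batch size\n",
"    norm_list.append(norm)\n",
"\n",
"# plot how the gradient norm evolves as it flows back in time\n",
"plt.plot(norm_list)\n",
"plt.xlabel('time step')\n",
"plt.ylabel('norm')\n",
"plt.show()"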
]
},
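{
"cell_type": "markdown",
"metadata": {},
"source": [
"Whether the norm in the plot above explodes or vanishes depends on $\\mathbf{W}_h$. Each multiplication by $\\mathbf{W}_h^\\top$ changes the norm by a factor between the smallest and the largest singular value of $\\mathbf{W}_h$.\n",
"\n",
"The sketch below uses a hypothetical helper `grad_norms` and hypothetical rescaled matrices `W_small` and `W_large`. It rescales a random matrix generated the same way as in the cell above in two ways. Setting its largest singular value to 0.9 guarantees a vanishing norm. Setting its smallest singular value to 1.1 guarantees an exploding norm.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def grad_norms(Wh, N=2, T=20):\n",
"    # repeat the RNN backward MatMul (dh <- dh @ Wh.T) T times and record\n",
"    # the same norm as in the cell above\n",
"    H = Wh.shape[0]\n",
"    dh = np.ones((N, H))\n",
"    norms = []\n",
"    for _ in range(T):\n",
"        dh = np.dot(dh, Wh.T)\n",
"        norms.append(np.sqrt(np.sum(dh ** 2)) / N)\n",
"    return norms\n",
"\n",
"np.random.seed(3)\n",
"W = np.random.randn(3, 3)\n",
"s = np.linalg.svd(W, compute_uv=False)  # singular values in descending order\n",
"\n",
"W_small = W * (0.9 / s[0])   # largest singular value becomes 0.9\n",
"W_large = W * (1.1 / s[-1])  # smallest singular value becomes 1.1\n",
"\n",
"vanish = grad_norms(W_small)\n",
"explode = grad_norms(W_large)\n",
"print('largest singular value of the original W:', s[0])\n",
"print('vanishing:', vanish[0], '->', vanish[-1])\n",
"print('exploding:', explode[0], '->', explode[-1])\n",
"```"
]
},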
{
"cell_type": "markdown",
"metadata": {},
"source": [
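"### 6.1.4 Countermeasure for Exploding Gradients\n",
"\n",
"The standard remedy is **gradient clipping**: when the L2 norm of all parameter gradients taken together exceeds a threshold, the gradients are scaled down so that their norm equals the threshold.\n",
"\n",
"$$\\text{if } \\|\\hat{\\mathbf{g}}\\| \\ge \\text{threshold}: \\quad \\hat{\\mathbf{g}} \\leftarrow \\frac{\\text{threshold}}{\\|\\hat{\\mathbf{g}}\\|}\\hat{\\mathbf{g}}$$\n",
"\n",
"Here $\\hat{\\mathbf{g}}$ collects the gradients of all parameters into one vector.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def clip_grads(grads, max_norm):\n",
"    # L2 norm of all gradient arrays taken together\n",
"    total_norm = 0\n",
"    for grad in grads:\n",
"        total_norm += np.sum(grad ** 2)\n",
"    total_norm = np.sqrt(total_norm)\n",
"\n",
"    # 1e-6 guards against division by zero\n",
"    rate = max_norm / (total_norm + 1e-6)\n",
"    if rate < 1:\n",
"        for grad in grads:\n",
"            grad *= rate  # scale in place, so the caller's arrays are updated\n",
"```"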
]
},
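{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below shows what `clip_grads` does. The helper `global_norm` and the example arrays are hypothetical. It builds deliberately large gradients, clips them to a norm of 5, and checks that only their length changes while their direction is preserved. Gradients whose norm is already below the threshold are left untouched.\n",
"\n",
"The training cell under 6.4 passes `max_grad = 0.25` to `trainer.fit`, presumably so that the trainer applies this clipping step before each parameter update.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def clip_grads(grads, max_norm):\n",
"    # same logic as above: rescale every array in place if the total norm is too large\n",
"    total_norm = np.sqrt(sum(np.sum(grad ** 2) for grad in grads))\n",
"    rate = max_norm / (total_norm + 1e-6)\n",
"    if rate < 1:\n",
"        for grad in grads:\n",
"            grad *= rate\n",
"\n",
"def global_norm(grads):\n",
"    return np.sqrt(sum(np.sum(grad ** 2) for grad in grads))\n",
"\n",
"np.random.seed(0)\n",
"grads = [np.random.randn(3, 3) * 10, np.random.randn(3) * 10]  # deliberately large\n",
"before = global_norm(grads)\n",
"direction = grads[0] / before  # direction of the first array before clipping\n",
"\n",
"clip_grads(grads, max_norm=5.0)\n",
"after = global_norm(grads)\n",
"print(before, '->', after)                        # after is just below 5.0\n",
"print(np.allclose(grads[0] / after, direction))  # True: direction unchanged\n",
"\n",
"small = [np.full(3, 0.1)]\n",
"clip_grads(small, max_norm=5.0)  # norm is about 0.17, below the threshold\n",
"print(small[0])                  # unchanged\n",
"```"
]
},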
{
"cell_type": "markdown",
"metadata": {},
"source": [
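"## 6.2 Vanishing Gradients and LSTM\n",
"\n",
"Solving the vanishing gradient problem requires changing the RNN architecture itself, which is what gated RNNs do.\n",
"\n",
"### 6.2.1 The LSTM Interface\n",
"\n",
"In the LSTM diagrams, the computation $\\tanh(\\mathbf{h}_{t-1}\\mathbf{W}_h+\\mathbf{x}_t\\mathbf{W}_x+\\mathbf{b})$ is drawn as a single $\\tanh$ node. Unlike a plain RNN, an LSTM also passes a memory cell $\\mathbf{c}_t$ from one time step to the next, and it inserts the following components:\n",
"\n",
"- output gate\n",
"- forget gate\n",
"- new memory cell\n",
"- input gate"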
]
},
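{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the computations inside one LSTM step can be written out as follows. The equations use the row-vector convention of $\\mathbf{h}_{t-1}\\mathbf{W}_h$ above; $\\sigma$ is the sigmoid function and $\\odot$ is the elementwise product. They match the implementation in 6.3, where the four weight matrices of each kind are concatenated into one.\n",
"\n",
"Along the memory-cell path, the backward pass only multiplies the gradient elementwise by the forget gate $\\mathbf{f}$ and involves no MatMul. This is why the LSTM is less prone to vanishing gradients.\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\mathbf{f} &= \\sigma(\\mathbf{x}_t\\mathbf{W}_x^{(f)} + \\mathbf{h}_{t-1}\\mathbf{W}_h^{(f)} + \\mathbf{b}^{(f)}) \\\\\n",
"\\mathbf{g} &= \\tanh(\\mathbf{x}_t\\mathbf{W}_x^{(g)} + \\mathbf{h}_{t-1}\\mathbf{W}_h^{(g)} + \\mathbf{b}^{(g)}) \\\\\n",
"\\mathbf{i} &= \\sigma(\\mathbf{x}_t\\mathbf{W}_x^{(i)} + \\mathbf{h}_{t-1}\\mathbf{W}_h^{(i)} + \\mathbf{b}^{(i)}) \\\\\n",
"\\mathbf{o} &= \\sigma(\\mathbf{x}_t\\mathbf{W}_x^{(o)} + \\mathbf{h}_{t-1}\\mathbf{W}_h^{(o)} + \\mathbf{b}^{(o)}) \\\\\n",
"\\mathbf{c}_t &= \\mathbf{f} \\odot \\mathbf{c}_{t-1} + \\mathbf{g} \\odot \\mathbf{i} \\\\\n",
"\\mathbf{h}_t &= \\mathbf{o} \\odot \\tanh(\\mathbf{c}_t)\n",
"\\end{aligned}\n",
"$$"
]
},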
{
"cell_type": "markdown",
"metadata": {},
"source": [
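"## 6.3 Implementing LSTM\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(x):\n",
"    return 1 / (1 + np.exp(-x))\n",
"\n",
"class LSTM:\n",
"    ...  # other methods omitted (self.params holds Wx, Wh, b)\n",
"\n",
"    def forward(self, x, h_prev, c_prev):\n",
"        Wx, Wh, b = self.params\n",
"        N, H = h_prev.shape\n",
"\n",
"        # one affine transform for all four gates at once: A has shape (N, 4H)\n",
"        A = np.dot(x, Wx) + np.dot(h_prev, Wh) + b\n",
"\n",
"        # slice A into forget gate f, new memory cell g, input gate i, output gate o\n",
"        f = A[:, :H]\n",
"        g = A[:, H:2*H]\n",
"        i = A[:, 2*H:3*H]\n",
"        o = A[:, 3*H:]\n",
"\n",
"        f = sigmoid(f)\n",
"        g = np.tanh(g)\n",
"        i = sigmoid(i)\n",
"        o = sigmoid(o)\n",
"\n",
"        # keep part of the old cell (f) and add the gated new information (g * i)\n",
"        c_next = f * c_prev + g * i\n",
"        # the output gate decides how much of tanh(c_next) becomes the hidden state\n",
"        h_next = o * np.tanh(c_next)\n",
"\n",
"        self.cache = (x, h_prev, c_prev, i, f, g, o, c_next)  # saved for backward\n",
"        return h_next, c_next\n",
"```"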
]
},
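{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is our own self-contained sketch of a backward pass matching the forward code above. The constructor layout, the `grads` list, and the test sizes are assumptions.\n",
"\n",
"It computes the total gradient `ds` reaching the new cell state and derives each gate's gradient with the chain rule. It then reassembles the gate gradients with `np.hstack` in the same f, g, i, o order used for slicing. Note that `dc_prev` is just `ds * f`, an elementwise product with no MatMul.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(x):\n",
"    return 1 / (1 + np.exp(-x))\n",
"\n",
"class LSTM:\n",
"    def __init__(self, Wx, Wh, b):\n",
"        # assumed layout: Wx (D, 4H), Wh (H, 4H), b (4H,), gate order f, g, i, o\n",
"        self.params = [Wx, Wh, b]\n",
"        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]\n",
"        self.cache = None\n",
"\n",
"    def forward(self, x, h_prev, c_prev):\n",
"        # same computation as the forward() above\n",
"        Wx, Wh, b = self.params\n",
"        N, H = h_prev.shape\n",
"        A = np.dot(x, Wx) + np.dot(h_prev, Wh) + b\n",
"        f = sigmoid(A[:, :H])\n",
"        g = np.tanh(A[:, H:2*H])\n",
"        i = sigmoid(A[:, 2*H:3*H])\n",
"        o = sigmoid(A[:, 3*H:])\n",
"        c_next = f * c_prev + g * i\n",
"        h_next = o * np.tanh(c_next)\n",
"        self.cache = (x, h_prev, c_prev, i, f, g, o, c_next)\n",
"        return h_next, c_next\n",
"\n",
"    def backward(self, dh_next, dc_next):\n",
"        Wx, Wh, b = self.params\n",
"        x, h_prev, c_prev, i, f, g, o, c_next = self.cache\n",
"        tanh_c = np.tanh(c_next)\n",
"\n",
"        # total gradient reaching c_next: directly, and through h_next = o * tanh(c_next)\n",
"        ds = dc_next + (dh_next * o) * (1 - tanh_c ** 2)\n",
"        # memory-cell path: an elementwise product with f, no MatMul\n",
"        dc_prev = ds * f\n",
"\n",
"        # gradients of the pre-activation slices (chain rule through sigmoid / tanh)\n",
"        df = ds * c_prev * f * (1 - f)\n",
"        dg = ds * i * (1 - g ** 2)\n",
"        di = ds * g * i * (1 - i)\n",
"        do = dh_next * tanh_c * o * (1 - o)\n",
"        dA = np.hstack((df, dg, di, do))  # same f, g, i, o order as the slicing\n",
"\n",
"        self.grads[0][...] = np.dot(x.T, dA)\n",
"        self.grads[1][...] = np.dot(h_prev.T, dA)\n",
"        self.grads[2][...] = dA.sum(axis=0)\n",
"        dx = np.dot(dA, Wx.T)\n",
"        dh_prev = np.dot(dA, Wh.T)\n",
"        return dx, dh_prev, dc_prev\n",
"\n",
"# shape check with hypothetical sizes: N = 2 samples, D = 4 inputs, H = 3 hidden units\n",
"np.random.seed(0)\n",
"N, D, H = 2, 4, 3\n",
"layer = LSTM(np.random.randn(D, 4 * H), np.random.randn(H, 4 * H), np.zeros(4 * H))\n",
"x = np.random.randn(N, D)\n",
"h0, c0 = np.zeros((N, H)), np.zeros((N, H))\n",
"h1, c1 = layer.forward(x, h0, c0)\n",
"dx, dh0, dc0 = layer.backward(np.ones_like(h1), np.zeros_like(c1))\n",
"print(h1.shape, c1.shape, dx.shape, dh0.shape, dc0.shape)\n",
"```"
]
},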
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6.4 Language Model Using LSTM\n",
"\n",
"The model is the same as before, except that the TimeLSTM layer replaces the TimeRNN layer."
]
},
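{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inside the language model, a TimeLSTM layer applies the LSTM step to all T time steps of a mini-batch. For truncated BPTT it also keeps the last $\\mathbf{h}$ and $\\mathbf{c}$, so that the next mini-batch continues from them. The call to `model.reset_state()` in the training code below clears that state before evaluation.\n",
"\n",
"The following is a forward-only sketch under those assumptions. The class name `TimeLSTMSketch` and the function `lstm_step` are hypothetical, and the backward pass is omitted.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(x):\n",
"    return 1 / (1 + np.exp(-x))\n",
"\n",
"def lstm_step(x, h_prev, c_prev, Wx, Wh, b):\n",
"    # one LSTM step with the same f, g, i, o layout as the LSTM class in 6.3\n",
"    H = h_prev.shape[1]\n",
"    A = np.dot(x, Wx) + np.dot(h_prev, Wh) + b\n",
"    f, g = sigmoid(A[:, :H]), np.tanh(A[:, H:2*H])\n",
"    i, o = sigmoid(A[:, 2*H:3*H]), sigmoid(A[:, 3*H:])\n",
"    c_next = f * c_prev + g * i\n",
"    return o * np.tanh(c_next), c_next\n",
"\n",
"class TimeLSTMSketch:\n",
"    # forward-only sketch: processes T time steps at once and, when stateful,\n",
"    # carries h and c over to the next call (truncated BPTT)\n",
"    def __init__(self, Wx, Wh, b, stateful=True):\n",
"        self.params = [Wx, Wh, b]\n",
"        self.stateful = stateful\n",
"        self.h, self.c = None, None\n",
"\n",
"    def reset_state(self):\n",
"        self.h, self.c = None, None\n",
"\n",
"    def forward(self, xs):\n",
"        Wx, Wh, b = self.params\n",
"        N, T, D = xs.shape\n",
"        H = Wh.shape[0]\n",
"        if not self.stateful or self.h is None:\n",
"            self.h = np.zeros((N, H))\n",
"        if not self.stateful or self.c is None:\n",
"            self.c = np.zeros((N, H))\n",
"        hs = np.empty((N, T, H))\n",
"        for t in range(T):\n",
"            self.h, self.c = lstm_step(xs[:, t, :], self.h, self.c, Wx, Wh, b)\n",
"            hs[:, t, :] = self.h\n",
"        # only the hidden states go to the next layer; c stays inside\n",
"        return hs\n",
"\n",
"np.random.seed(0)\n",
"N, T, D, H = 2, 5, 4, 3\n",
"layer = TimeLSTMSketch(np.random.randn(D, 4 * H), np.random.randn(H, 4 * H), np.zeros(4 * H))\n",
"hs = layer.forward(np.random.randn(N, T, D))\n",
"print(hs.shape)  # (2, 5, 3)\n",
"```"
]
},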
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 1 | iter 1 / 1327 | time 0[s] | perplexity 10001.53\n",
"| epoch 1 | iter 21 / 1327 | time 4[s] | perplexity 3369.27\n",
"| epoch 1 | iter 41 / 1327 | time 8[s] | perplexity 1270.10\n",
"| epoch 1 | iter 61 / 1327 | time 12[s] | perplexity 985.67\n",
"| epoch 1 | iter 81 / 1327 | time 16[s] | perplexity 794.73\n",
"| epoch 1 | iter 101 / 1327 | time 20[s] | perplexity 680.45\n",
"| epoch 1 | iter 121 / 1327 | time 24[s] | perplexity 658.45\n",
"| epoch 1 | iter 141 / 1327 | time 29[s] | perplexity 619.76\n",
"| epoch 1 | iter 161 / 1327 | time 32[s] | perplexity 606.06\n",
"| epoch 1 | iter 181 / 1327 | time 36[s] | perplexity 584.49\n",
"| epoch 1 | iter 201 / 1327 | time 40[s] | perplexity 521.27\n",
"| epoch 1 | iter 221 / 1327 | time 44[s] | perplexity 503.49\n",
"| epoch 1 | iter 241 / 1327 | time 49[s] | perplexity 450.23\n",
"| epoch 1 | iter 261 / 1327 | time 53[s] | perplexity 478.97\n",
"| epoch 1 | iter 281 / 1327 | time 57[s] | perplexity 448.23\n",
"| epoch 1 | iter 301 / 1327 | time 61[s] | perplexity 393.59\n",
"| epoch 1 | iter 321 / 1327 | time 65[s] | perplexity 358.08\n",
"| epoch 1 | iter 341 / 1327 | time 69[s] | perplexity 404.20\n",
"| epoch 1 | iter 361 / 1327 | time 73[s] | perplexity 413.65\n",
"| epoch 1 | iter 381 / 1327 | time 77[s] | perplexity 344.77\n",
"| epoch 1 | iter 401 / 1327 | time 81[s] | perplexity 360.71\n",
"| epoch 1 | iter 421 / 1327 | time 85[s] | perplexity 348.61\n",
"| epoch 1 | iter 441 / 1327 | time 89[s] | perplexity 333.02\n",
"| epoch 1 | iter 461 / 1327 | time 93[s] | perplexity 326.15\n",
"| epoch 1 | iter 481 / 1327 | time 97[s] | perplexity 309.48\n",
"| epoch 1 | iter 501 / 1327 | time 101[s] | perplexity 320.86\n",
"| epoch 1 | iter 521 / 1327 | time 105[s] | perplexity 304.07\n",
"| epoch 1 | iter 541 / 1327 | time 109[s] | perplexity 317.89\n",
"| epoch 1 | iter 561 / 1327 | time 113[s] | perplexity 289.84\n",
"| epoch 1 | iter 581 / 1327 | time 118[s] | perplexity 260.17\n",
"| epoch 1 | iter 601 / 1327 | time 125[s] | perplexity 336.76\n",
"| epoch 1 | iter 621 / 1327 | time 131[s] | perplexity 317.84\n",
"| epoch 1 | iter 641 / 1327 | time 136[s] | perplexity 283.07\n",
"| epoch 1 | iter 661 / 1327 | time 140[s] | perplexity 271.58\n",
"| epoch 1 | iter 681 / 1327 | time 146[s] | perplexity 230.88\n",
"| epoch 1 | iter 701 / 1327 | time 152[s] | perplexity 250.57\n",
"| epoch 1 | iter 721 / 1327 | time 157[s] | perplexity 259.89\n",
"| epoch 1 | iter 741 / 1327 | time 161[s] | perplexity 221.75\n",
"| epoch 1 | iter 761 / 1327 | time 165[s] | perplexity 234.70\n",
"| epoch 1 | iter 781 / 1327 | time 169[s] | perplexity 219.92\n",
"| epoch 1 | iter 801 / 1327 | time 173[s] | perplexity 240.31\n",
"| epoch 1 | iter 821 / 1327 | time 177[s] | perplexity 224.74\n",
"| epoch 1 | iter 841 / 1327 | time 181[s] | perplexity 229.62\n",
"| epoch 1 | iter 861 / 1327 | time 186[s] | perplexity 222.55\n",
"| epoch 1 | iter 881 / 1327 | time 191[s] | perplexity 205.47\n",
"| epoch 1 | iter 901 / 1327 | time 196[s] | perplexity 256.40\n",
"| epoch 1 | iter 921 / 1327 | time 201[s] | perplexity 229.65\n",
"| epoch 1 | iter 941 / 1327 | time 207[s] | perplexity 229.20\n",
"| epoch 1 | iter 961 / 1327 | time 212[s] | perplexity 245.76\n",
"| epoch 1 | iter 981 / 1327 | time 218[s] | perplexity 230.29\n",
"| epoch 1 | iter 1001 / 1327 | time 223[s] | perplexity 193.78\n",
"| epoch 1 | iter 1021 / 1327 | time 228[s] | perplexity 226.22\n",
"| epoch 1 | iter 1041 / 1327 | time 233[s] | perplexity 209.86\n",
"| epoch 1 | iter 1061 / 1327 | time 238[s] | perplexity 198.70\n",
"| epoch 1 | iter 1081 / 1327 | time 243[s] | perplexity 168.97\n",
"| epoch 1 | iter 1101 / 1327 | time 248[s] | perplexity 192.36\n",
"| epoch 1 | iter 1121 / 1327 | time 253[s] | perplexity 229.63\n",
"| epoch 1 | iter 1141 / 1327 | time 258[s] | perplexity 206.56\n",
"| epoch 1 | iter 1161 / 1327 | time 262[s] | perplexity 199.81\n",
"| epoch 1 | iter 1181 / 1327 | time 267[s] | perplexity 191.34\n",
"| epoch 1 | iter 1201 / 1327 | time 272[s] | perplexity 163.47\n",
"| epoch 1 | iter 1221 / 1327 | time 277[s] | perplexity 161.37\n",
"| epoch 1 | iter 1241 / 1327 | time 281[s] | perplexity 187.98\n",
"| epoch 1 | iter 1261 / 1327 | time 286[s] | perplexity 172.95\n",
"| epoch 1 | iter 1281 / 1327 | time 291[s] | perplexity 180.20\n",
"| epoch 1 | iter 1301 / 1327 | time 296[s] | perplexity 222.08\n",
"| epoch 1 | iter 1321 / 1327 | time 301[s] | perplexity 210.05\n",
"| epoch 2 | iter 1 / 1327 | time 302[s] | perplexity 223.75\n",
"| epoch 2 | iter 21 / 1327 | time 307[s] | perplexity 203.74\n",
"| epoch 2 | iter 41 / 1327 | time 312[s] | perplexity 189.99\n",
"| epoch 2 | iter 61 / 1327 | time 317[s] | perplexity 177.51\n",
"| epoch 2 | iter 81 / 1327 | time 322[s] | perplexity 160.12\n",
"| epoch 2 | iter 101 / 1327 | time 327[s] | perplexity 152.91\n",
"| epoch 2 | iter 121 / 1327 | time 332[s] | perplexity 160.44\n",
"| epoch 2 | iter 141 / 1327 | time 337[s] | perplexity 178.88\n",
"| epoch 2 | iter 161 / 1327 | time 341[s] | perplexity 193.38\n",
"| epoch 2 | iter 181 / 1327 | time 346[s] | perplexity 199.88\n",
"| epoch 2 | iter 201 / 1327 | time 350[s] | perplexity 184.36\n",
"| epoch 2 | iter 221 / 1327 | time 354[s] | perplexity 183.61\n",
"| epoch 2 | iter 241 / 1327 | time 359[s] | perplexity 177.71\n",
"| epoch 2 | iter 261 / 1327 | time 364[s] | perplexity 185.46\n",
"| epoch 2 | iter 281 / 1327 | time 368[s] | perplexity 185.07\n",
"| epoch 2 | iter 301 / 1327 | time 373[s] | perplexity 166.49\n",
"| epoch 2 | iter 321 / 1327 | time 377[s] | perplexity 139.45\n",
"| epoch 2 | iter 341 / 1327 | time 382[s] | perplexity 173.59\n",
"| epoch 2 | iter 361 / 1327 | time 386[s] | perplexity 196.56\n",
"| epoch 2 | iter 381 / 1327 | time 390[s] | perplexity 152.65\n",
"| epoch 2 | iter 401 / 1327 | time 395[s] | perplexity 167.52\n",
"| epoch 2 | iter 421 / 1327 | time 399[s] | perplexity 153.49\n",
"| epoch 2 | iter 441 / 1327 | time 404[s] | perplexity 162.79\n",
"| epoch 2 | iter 461 / 1327 | time 408[s] | perplexity 158.78\n",
"| epoch 2 | iter 481 / 1327 | time 412[s] | perplexity 156.82\n",
"| epoch 2 | iter 501 / 1327 | time 417[s] | perplexity 168.73\n",
"| epoch 2 | iter 521 / 1327 | time 421[s] | perplexity 174.60\n",
"| epoch 2 | iter 541 / 1327 | time 426[s] | perplexity 175.13\n",
"| epoch 2 | iter 561 / 1327 | time 430[s] | perplexity 154.33\n",
"| epoch 2 | iter 581 / 1327 | time 435[s] | perplexity 138.94\n",
"| epoch 2 | iter 601 / 1327 | time 439[s] | perplexity 190.57\n",
"| epoch 2 | iter 621 / 1327 | time 444[s] | perplexity 181.79\n",
"| epoch 2 | iter 641 / 1327 | time 448[s] | perplexity 164.16\n",
"| epoch 2 | iter 661 / 1327 | time 452[s] | perplexity 154.69\n",
"| epoch 2 | iter 681 / 1327 | time 457[s] | perplexity 129.25\n",
"| epoch 2 | iter 701 / 1327 | time 461[s] | perplexity 149.68\n",
"| epoch 2 | iter 721 / 1327 | time 466[s] | perplexity 160.68\n",
"| epoch 2 | iter 741 / 1327 | time 470[s] | perplexity 132.86\n",
"| epoch 2 | iter 761 / 1327 | time 475[s] | perplexity 130.31\n",
"| epoch 2 | iter 781 / 1327 | time 479[s] | perplexity 135.17\n",
"| epoch 2 | iter 801 / 1327 | time 484[s] | perplexity 147.12\n",
"| epoch 2 | iter 821 / 1327 | time 488[s] | perplexity 143.79\n",
"| epoch 2 | iter 841 / 1327 | time 493[s] | perplexity 143.31\n",
"| epoch 2 | iter 861 / 1327 | time 498[s] | perplexity 144.79\n",
"| epoch 2 | iter 881 / 1327 | time 503[s] | perplexity 131.04\n",
"| epoch 2 | iter 901 / 1327 | time 508[s] | perplexity 165.02\n",
"| epoch 2 | iter 921 / 1327 | time 513[s] | perplexity 148.06\n",
"| epoch 2 | iter 941 / 1327 | time 518[s] | perplexity 153.83\n",
"| epoch 2 | iter 961 / 1327 | time 523[s] | perplexity 165.04\n",
"| epoch 2 | iter 981 / 1327 | time 528[s] | perplexity 153.31\n",
"| epoch 2 | iter 1001 / 1327 | time 533[s] | perplexity 132.15\n",
"| epoch 2 | iter 1021 / 1327 | time 538[s] | perplexity 156.56\n",
"| epoch 2 | iter 1041 / 1327 | time 543[s] | perplexity 141.93\n",
"| epoch 2 | iter 1061 / 1327 | time 548[s] | perplexity 128.13\n",
"| epoch 2 | iter 1081 / 1327 | time 554[s] | perplexity 110.03\n",
"| epoch 2 | iter 1101 / 1327 | time 559[s] | perplexity 119.79\n",
"| epoch 2 | iter 1121 / 1327 | time 563[s] | perplexity 152.99\n",
"| epoch 2 | iter 1141 / 1327 | time 568[s] | perplexity 141.46\n",
"| epoch 2 | iter 1161 / 1327 | time 572[s] | perplexity 133.02\n",
"| epoch 2 | iter 1181 / 1327 | time 577[s] | perplexity 133.30\n",
"| epoch 2 | iter 1201 / 1327 | time 582[s] | perplexity 112.68\n",
"| epoch 2 | iter 1221 / 1327 | time 587[s] | perplexity 109.20\n",
"| epoch 2 | iter 1241 / 1327 | time 591[s] | perplexity 130.53\n",
"| epoch 2 | iter 1261 / 1327 | time 596[s] | perplexity 124.27\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 2 | iter 1281 / 1327 | time 600[s] | perplexity 122.96\n",
"| epoch 2 | iter 1301 / 1327 | time 605[s] | perplexity 157.10\n",
"| epoch 2 | iter 1321 / 1327 | time 609[s] | perplexity 153.77\n",
"| epoch 3 | iter 1 / 1327 | time 611[s] | perplexity 159.04\n",
"| epoch 3 | iter 21 / 1327 | time 615[s] | perplexity 144.00\n",
"| epoch 3 | iter 41 / 1327 | time 620[s] | perplexity 135.04\n",
"| epoch 3 | iter 61 / 1327 | time 625[s] | perplexity 126.73\n",
"| epoch 3 | iter 81 / 1327 | time 629[s] | perplexity 116.52\n",
"| epoch 3 | iter 101 / 1327 | time 634[s] | perplexity 105.99\n",
"| epoch 3 | iter 121 / 1327 | time 638[s] | perplexity 116.02\n",
"| epoch 3 | iter 141 / 1327 | time 643[s] | perplexity 126.64\n",
"| epoch 3 | iter 161 / 1327 | time 647[s] | perplexity 142.06\n",
"| epoch 3 | iter 181 / 1327 | time 652[s] | perplexity 148.62\n",
"| epoch 3 | iter 201 / 1327 | time 656[s] | perplexity 141.37\n",
"| epoch 3 | iter 221 / 1327 | time 660[s] | perplexity 140.41\n",
"| epoch 3 | iter 241 / 1327 | time 665[s] | perplexity 135.36\n",
"| epoch 3 | iter 261 / 1327 | time 670[s] | perplexity 139.09\n",
"| epoch 3 | iter 281 / 1327 | time 674[s] | perplexity 141.31\n",
"| epoch 3 | iter 301 / 1327 | time 679[s] | perplexity 123.63\n",
"| epoch 3 | iter 321 / 1327 | time 683[s] | perplexity 101.11\n",
"| epoch 3 | iter 341 / 1327 | time 687[s] | perplexity 123.65\n",
"| epoch 3 | iter 361 / 1327 | time 692[s] | perplexity 151.25\n",
"| epoch 3 | iter 381 / 1327 | time 696[s] | perplexity 114.11\n",
"| epoch 3 | iter 401 / 1327 | time 701[s] | perplexity 129.45\n",
"| epoch 3 | iter 421 / 1327 | time 705[s] | perplexity 113.48\n",
"| epoch 3 | iter 441 / 1327 | time 710[s] | perplexity 123.47\n",
"| epoch 3 | iter 461 / 1327 | time 714[s] | perplexity 119.60\n",
"| epoch 3 | iter 481 / 1327 | time 719[s] | perplexity 118.64\n",
"| epoch 3 | iter 501 / 1327 | time 723[s] | perplexity 128.42\n",
"| epoch 3 | iter 521 / 1327 | time 728[s] | perplexity 138.72\n",
"| epoch 3 | iter 541 / 1327 | time 732[s] | perplexity 135.22\n",
"| epoch 3 | iter 561 / 1327 | time 737[s] | perplexity 117.48\n",
"| epoch 3 | iter 581 / 1327 | time 742[s] | perplexity 105.71\n",
"| epoch 3 | iter 601 / 1327 | time 746[s] | perplexity 147.57\n",
"| epoch 3 | iter 621 / 1327 | time 751[s] | perplexity 141.62\n",
"| epoch 3 | iter 641 / 1327 | time 755[s] | perplexity 129.66\n",
"| epoch 3 | iter 661 / 1327 | time 759[s] | perplexity 120.43\n",
"| epoch 3 | iter 681 / 1327 | time 764[s] | perplexity 99.94\n",
"| epoch 3 | iter 701 / 1327 | time 768[s] | perplexity 118.26\n",
"| epoch 3 | iter 721 / 1327 | time 773[s] | perplexity 126.15\n",
"| epoch 3 | iter 741 / 1327 | time 778[s] | perplexity 106.56\n",
"| epoch 3 | iter 761 / 1327 | time 782[s] | perplexity 103.61\n",
"| epoch 3 | iter 781 / 1327 | time 787[s] | perplexity 103.21\n",
"| epoch 3 | iter 801 / 1327 | time 791[s] | perplexity 114.00\n",
"| epoch 3 | iter 821 / 1327 | time 796[s] | perplexity 115.18\n",
"| epoch 3 | iter 841 / 1327 | time 800[s] | perplexity 114.29\n",
"| epoch 3 | iter 861 / 1327 | time 805[s] | perplexity 118.70\n",
"| epoch 3 | iter 881 / 1327 | time 809[s] | perplexity 106.59\n",
"| epoch 3 | iter 901 / 1327 | time 814[s] | perplexity 130.60\n",
"| epoch 3 | iter 921 / 1327 | time 818[s] | perplexity 119.56\n",
"| epoch 3 | iter 941 / 1327 | time 822[s] | perplexity 126.76\n",
"| epoch 3 | iter 961 / 1327 | time 827[s] | perplexity 132.62\n",
"| epoch 3 | iter 981 / 1327 | time 831[s] | perplexity 122.90\n",
"| epoch 3 | iter 1001 / 1327 | time 836[s] | perplexity 109.57\n",
"| epoch 3 | iter 1021 / 1327 | time 842[s] | perplexity 128.46\n",
"| epoch 3 | iter 1041 / 1327 | time 848[s] | perplexity 118.44\n",
"| epoch 3 | iter 1061 / 1327 | time 853[s] | perplexity 102.22\n",
"| epoch 3 | iter 1081 / 1327 | time 857[s] | perplexity 87.85\n",
"| epoch 3 | iter 1101 / 1327 | time 864[s] | perplexity 95.52\n",
"| epoch 3 | iter 1121 / 1327 | time 869[s] | perplexity 120.50\n",
"| epoch 3 | iter 1141 / 1327 | time 874[s] | perplexity 114.41\n",
"| epoch 3 | iter 1161 / 1327 | time 880[s] | perplexity 107.33\n",
"| epoch 3 | iter 1181 / 1327 | time 885[s] | perplexity 110.84\n",
"| epoch 3 | iter 1201 / 1327 | time 890[s] | perplexity 93.96\n",
"| epoch 3 | iter 1221 / 1327 | time 895[s] | perplexity 88.58\n",
"| epoch 3 | iter 1241 / 1327 | time 900[s] | perplexity 105.31\n",
"| epoch 3 | iter 1261 / 1327 | time 904[s] | perplexity 105.53\n",
"| epoch 3 | iter 1281 / 1327 | time 909[s] | perplexity 100.72\n",
"| epoch 3 | iter 1301 / 1327 | time 914[s] | perplexity 130.49\n",
"| epoch 3 | iter 1321 / 1327 | time 919[s] | perplexity 127.11\n",
"| epoch 4 | iter 1 / 1327 | time 920[s] | perplexity 132.73\n",
"| epoch 4 | iter 21 / 1327 | time 925[s] | perplexity 121.07\n",
"| epoch 4 | iter 41 / 1327 | time 929[s] | perplexity 106.88\n",
"| epoch 4 | iter 61 / 1327 | time 934[s] | perplexity 106.00\n",
"| epoch 4 | iter 81 / 1327 | time 939[s] | perplexity 95.54\n",
"| epoch 4 | iter 101 / 1327 | time 944[s] | perplexity 86.23\n",
"| epoch 4 | iter 121 / 1327 | time 949[s] | perplexity 94.56\n",
"| epoch 4 | iter 141 / 1327 | time 953[s] | perplexity 103.08\n",
"| epoch 4 | iter 161 / 1327 | time 958[s] | perplexity 118.39\n",
"| epoch 4 | iter 181 / 1327 | time 963[s] | perplexity 127.76\n",
"| epoch 4 | iter 201 / 1327 | time 967[s] | perplexity 119.95\n",
"| epoch 4 | iter 221 / 1327 | time 971[s] | perplexity 121.68\n",
"| epoch 4 | iter 241 / 1327 | time 976[s] | perplexity 114.74\n",
"| epoch 4 | iter 261 / 1327 | time 981[s] | perplexity 114.72\n",
"| epoch 4 | iter 281 / 1327 | time 985[s] | perplexity 120.67\n",
"| epoch 4 | iter 301 / 1327 | time 990[s] | perplexity 103.55\n",
"| epoch 4 | iter 321 / 1327 | time 995[s] | perplexity 83.56\n",
"| epoch 4 | iter 341 / 1327 | time 999[s] | perplexity 100.04\n",
"| epoch 4 | iter 361 / 1327 | time 1004[s] | perplexity 127.70\n",
"| epoch 4 | iter 381 / 1327 | time 1009[s] | perplexity 96.72\n",
"| epoch 4 | iter 401 / 1327 | time 1013[s] | perplexity 109.85\n",
"| epoch 4 | iter 421 / 1327 | time 1018[s] | perplexity 94.10\n",
"| epoch 4 | iter 441 / 1327 | time 1022[s] | perplexity 102.51\n",
"| epoch 4 | iter 461 / 1327 | time 1027[s] | perplexity 99.95\n",
"| epoch 4 | iter 481 / 1327 | time 1032[s] | perplexity 101.84\n",
"| epoch 4 | iter 501 / 1327 | time 1036[s] | perplexity 108.18\n",
"| epoch 4 | iter 521 / 1327 | time 1041[s] | perplexity 117.74\n",
"| epoch 4 | iter 541 / 1327 | time 1045[s] | perplexity 111.75\n",
"| epoch 4 | iter 561 / 1327 | time 1050[s] | perplexity 101.40\n",
"| epoch 4 | iter 581 / 1327 | time 1055[s] | perplexity 89.76\n",
"| epoch 4 | iter 601 / 1327 | time 1060[s] | perplexity 126.00\n",
"| epoch 4 | iter 621 / 1327 | time 1064[s] | perplexity 120.89\n",
"| epoch 4 | iter 641 / 1327 | time 1069[s] | perplexity 109.96\n",
"| epoch 4 | iter 661 / 1327 | time 1074[s] | perplexity 102.95\n",
"| epoch 4 | iter 681 / 1327 | time 1079[s] | perplexity 84.85\n",
"| epoch 4 | iter 701 / 1327 | time 1084[s] | perplexity 101.62\n",
"| epoch 4 | iter 721 / 1327 | time 1089[s] | perplexity 107.78\n",
"| epoch 4 | iter 741 / 1327 | time 1094[s] | perplexity 95.20\n",
"| epoch 4 | iter 761 / 1327 | time 1098[s] | perplexity 88.72\n",
"| epoch 4 | iter 781 / 1327 | time 1103[s] | perplexity 87.57\n",
"| epoch 4 | iter 801 / 1327 | time 1108[s] | perplexity 97.64\n",
"| epoch 4 | iter 821 / 1327 | time 1112[s] | perplexity 102.00\n",
"| epoch 4 | iter 841 / 1327 | time 1116[s] | perplexity 98.01\n",
"| epoch 4 | iter 861 / 1327 | time 1121[s] | perplexity 103.25\n",
"| epoch 4 | iter 881 / 1327 | time 1125[s] | perplexity 92.42\n",
"| epoch 4 | iter 901 / 1327 | time 1130[s] | perplexity 114.35\n",
"| epoch 4 | iter 921 / 1327 | time 1134[s] | perplexity 104.32\n",
"| epoch 4 | iter 941 / 1327 | time 1139[s] | perplexity 112.19\n",
"| epoch 4 | iter 961 / 1327 | time 1143[s] | perplexity 112.16\n",
"| epoch 4 | iter 981 / 1327 | time 1148[s] | perplexity 106.37\n",
"| epoch 4 | iter 1001 / 1327 | time 1152[s] | perplexity 97.07\n",
"| epoch 4 | iter 1021 / 1327 | time 1157[s] | perplexity 112.89\n",
"| epoch 4 | iter 1041 / 1327 | time 1161[s] | perplexity 103.65\n",
"| epoch 4 | iter 1061 / 1327 | time 1166[s] | perplexity 88.32\n",
"| epoch 4 | iter 1081 / 1327 | time 1170[s] | perplexity 77.75\n",
"| epoch 4 | iter 1101 / 1327 | time 1175[s] | perplexity 79.79\n",
"| epoch 4 | iter 1121 / 1327 | time 1179[s] | perplexity 102.84\n",
"| epoch 4 | iter 1141 / 1327 | time 1183[s] | perplexity 99.15\n",
"| epoch 4 | iter 1161 / 1327 | time 1188[s] | perplexity 91.90\n",
"| epoch 4 | iter 1181 / 1327 | time 1192[s] | perplexity 95.42\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 4 | iter 1201 / 1327 | time 1197[s] | perplexity 83.07\n",
"| epoch 4 | iter 1221 / 1327 | time 1201[s] | perplexity 76.08\n",
"| epoch 4 | iter 1241 / 1327 | time 1206[s] | perplexity 91.78\n",
"| epoch 4 | iter 1261 / 1327 | time 1210[s] | perplexity 94.18\n",
"| epoch 4 | iter 1281 / 1327 | time 1215[s] | perplexity 88.94\n",
"| epoch 4 | iter 1301 / 1327 | time 1219[s] | perplexity 111.88\n",
"| epoch 4 | iter 1321 / 1327 | time 1223[s] | perplexity 110.80\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"evaluating perplexity ...\n",
"234 / 235\n",
"test perplexity: 135.81750561235523\n"
]
}
],
"source": [
"# coding: utf-8\n",
"import sys\n",
"sys.path.append('..')\n",
"from common.optimizer import SGD\n",
"from common.trainer import RnnlmTrainer\n",
"from common.util import eval_perplexity\n",
"from dataset import ptb\n",
"from rnnlm import Rnnlm\n",
"\n",
"\n",
"# hyperparameters\n",
"batch_size = 20\n",
"wordvec_size = 100\n",
"hidden_size = 100  # number of elements in the hidden state vector\n",
"time_size = 35  # number of time steps to unroll the RNN\n",
"lr = 20.0\n",
"max_epoch = 4\n",
"max_grad = 0.25\n",
"\n",
"# load the training and test data\n",
"corpus, word_to_id, id_to_word = ptb.load_data('train')\n",
"corpus_test, _, _ = ptb.load_data('test')\n",
"vocab_size = len(word_to_id)\n",
"xs = corpus[:-1]\n",
"ts = corpus[1:]\n",
"\n",
"# build the model\n",
"model = Rnnlm(vocab_size, wordvec_size, hidden_size)\n",
"optimizer = SGD(lr)\n",
"trainer = RnnlmTrainer(model, optimizer)\n",
"\n",
"# train with gradient clipping\n",
"trainer.fit(xs, ts, max_epoch, batch_size, time_size, max_grad,\n",
" eval_interval=20)\n",
"trainer.plot(ylim=(0, 500))\n",
"\n",
"# evaluate on the test data\n",
"model.reset_state()\n",
"ppl_test = eval_perplexity(model, corpus_test)\n",
"print('test perplexity: ', ppl_test)\n",
"\n",
"# save the parameters\n",
"model.save_params()\n"
]
},
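{
"cell_type": "markdown",
"metadata": {},
"source": [
"The perplexity printed during training is the exponential of the average cross-entropy loss, $\\exp\\left(-\\frac{1}{N}\\sum_n \\log p_n\\right)$. Here $p_n$ is the probability the model assigns to the correct next word.\n",
"\n",
"A model that predicts uniformly over $V$ words has perplexity exactly $V$. This is consistent with the first logged value (about 10001) being close to the PTB vocabulary size of 10,000.\n",
"\n",
"A minimal sketch; the function name `perplexity` and the toy distributions are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def perplexity(probs, targets):\n",
"    # probs: (N, V) predicted distributions; targets: (N,) ids of the correct words\n",
"    p_correct = probs[np.arange(len(targets)), targets]\n",
"    # exp of the mean negative log-likelihood (mean cross-entropy)\n",
"    return float(np.exp(-np.mean(np.log(p_correct))))\n",
"\n",
"V = 10000\n",
"uniform = np.full((3, V), 1.0 / V)  # predicts every word with the same probability\n",
"print(perplexity(uniform, np.array([1, 42, 7])))  # about 10000 (= V)\n",
"\n",
"confident = np.full((2, 4), 0.1)\n",
"confident[[0, 1], [2, 3]] = 0.7  # each row sums to 1, with 0.7 on the correct word\n",
"print(perplexity(confident, np.array([2, 3])))  # about 1.43 (= 1 / 0.7)\n",
"```"
]
},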
{
"cell_type": "markdown",
"metadata": {},
"source": [
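"## 6.5 Further Improvements to the RNNLM\n",
"\n",
"- Stacking multiple LSTM layers\n",
"- Suppressing overfitting with Dropout\n",
"- Weight tying (sharing the weights of the Embedding and Affine layers)"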
]
},
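{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two of these ideas fit in a few lines. The sketch below is an illustration under stated assumptions, not the book's TimeDropout or BetterRnnlm code.\n",
"\n",
"- `DropoutSketch` implements inverted dropout. During training, each element is zeroed with probability `ratio` and the survivors are scaled by $1/(1-\\text{ratio})$, so inference needs no rescaling.\n",
"- Weight tying is shown by using the transpose of a hypothetical Embedding weight `embed_W` as the Affine weight, so both layers share one set of parameters. During training, the gradients of both layers would have to be summed into that shared parameter.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"class DropoutSketch:\n",
"    # inverted dropout: drop elements with probability ratio at training time\n",
"    # and scale the survivors by 1 / (1 - ratio); identity at inference time\n",
"    def __init__(self, ratio=0.5, seed=0):\n",
"        self.ratio = ratio\n",
"        self.train_flg = True\n",
"        self.mask = None\n",
"        self.rng = np.random.default_rng(seed)\n",
"\n",
"    def forward(self, x):\n",
"        if not self.train_flg:\n",
"            return x\n",
"        keep = self.rng.random(x.shape) > self.ratio\n",
"        self.mask = keep / (1.0 - self.ratio)\n",
"        return x * self.mask\n",
"\n",
"    def backward(self, dout):\n",
"        # gradients flow only through the elements that were kept\n",
"        return dout * self.mask\n",
"\n",
"drop = DropoutSketch(ratio=0.5)\n",
"x = np.ones((1000, 100))\n",
"y = drop.forward(x)\n",
"print(y.mean())  # close to 1: scaling keeps the expected activation unchanged\n",
"drop.train_flg = False\n",
"print(np.array_equal(drop.forward(x), x))  # True at inference time\n",
"\n",
"# weight tying: the Affine weight (D, V) is the transpose of the Embedding weight (V, D)\n",
"V, D = 10, 4\n",
"embed_W = np.random.default_rng(1).standard_normal((V, D)) * 0.01\n",
"affine_W = embed_W.T  # a view, so both layers share the same numbers\n",
"h = np.ones((2, D))\n",
"scores = np.dot(h, affine_W)  # (2, V) scores over the vocabulary\n",
"print(scores.shape, np.shares_memory(affine_W, embed_W))\n",
"```"
]
},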
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 1 | iter 1 / 1327 | time 8[s] | perplexity 10000.15\n",
"| epoch 1 | iter 21 / 1327 | time 93[s] | perplexity 4234.82\n",
"| epoch 1 | iter 41 / 1327 | time 178[s] | perplexity 1896.71\n",
"| epoch 1 | iter 61 / 1327 | time 272[s] | perplexity 1280.45\n",
"| epoch 1 | iter 81 / 1327 | time 332[s] | perplexity 1023.16\n",
"| epoch 1 | iter 101 / 1327 | time 394[s] | perplexity 831.45\n",
"| epoch 1 | iter 121 / 1327 | time 458[s] | perplexity 807.89\n",
"| epoch 1 | iter 141 / 1327 | time 522[s] | perplexity 720.43\n",
"| epoch 1 | iter 161 / 1327 | time 587[s] | perplexity 689.12\n",
"| epoch 1 | iter 181 / 1327 | time 651[s] | perplexity 679.70\n",
"| epoch 1 | iter 201 / 1327 | time 716[s] | perplexity 602.49\n",
"| epoch 1 | iter 221 / 1327 | time 781[s] | perplexity 567.63\n",
"| epoch 1 | iter 241 / 1327 | time 844[s] | perplexity 528.17\n",
"| epoch 1 | iter 261 / 1327 | time 909[s] | perplexity 538.42\n",
"| epoch 1 | iter 281 / 1327 | time 975[s] | perplexity 521.44\n",
"| epoch 1 | iter 301 / 1327 | time 1043[s] | perplexity 449.47\n",
"| epoch 1 | iter 321 / 1327 | time 1112[s] | perplexity 399.01\n",
"| epoch 1 | iter 341 / 1327 | time 1180[s] | perplexity 452.80\n",
"| epoch 1 | iter 361 / 1327 | time 1247[s] | perplexity 460.58\n",
"| epoch 1 | iter 381 / 1327 | time 1312[s] | perplexity 383.68\n",
"| epoch 1 | iter 401 / 1327 | time 1379[s] | perplexity 404.59\n",
"| epoch 1 | iter 421 / 1327 | time 1445[s] | perplexity 394.29\n",
"| epoch 1 | iter 441 / 1327 | time 1509[s] | perplexity 375.40\n",
"| epoch 1 | iter 461 / 1327 | time 1574[s] | perplexity 373.20\n",
"| epoch 1 | iter 481 / 1327 | time 1639[s] | perplexity 344.29\n",
"| epoch 1 | iter 501 / 1327 | time 1682[s] | perplexity 355.09\n",
"| epoch 1 | iter 521 / 1327 | time 1722[s] | perplexity 345.81\n",
"| epoch 1 | iter 541 / 1327 | time 1761[s] | perplexity 364.67\n",
"| epoch 1 | iter 561 / 1327 | time 1798[s] | perplexity 323.76\n",
"| epoch 1 | iter 581 / 1327 | time 1836[s] | perplexity 293.38\n",
"| epoch 1 | iter 601 / 1327 | time 1873[s] | perplexity 377.00\n",
"| epoch 1 | iter 621 / 1327 | time 1911[s] | perplexity 346.04\n",
"| epoch 1 | iter 641 / 1327 | time 1948[s] | perplexity 315.85\n",
"| epoch 1 | iter 661 / 1327 | time 1985[s] | perplexity 307.43\n",
"| epoch 1 | iter 681 / 1327 | time 2023[s] | perplexity 257.17\n",
"| epoch 1 | iter 701 / 1327 | time 2062[s] | perplexity 281.79\n",
"| epoch 1 | iter 721 / 1327 | time 2101[s] | perplexity 289.17\n",
"| epoch 1 | iter 741 / 1327 | time 2140[s] | perplexity 249.45\n",
"| epoch 1 | iter 761 / 1327 | time 2179[s] | perplexity 258.53\n",
"| epoch 1 | iter 781 / 1327 | time 2216[s] | perplexity 244.86\n",
"| epoch 1 | iter 801 / 1327 | time 2254[s] | perplexity 269.02\n",
"| epoch 1 | iter 821 / 1327 | time 2291[s] | perplexity 248.95\n",
"| epoch 1 | iter 841 / 1327 | time 2328[s] | perplexity 254.58\n",
"| epoch 1 | iter 861 / 1327 | time 2367[s] | perplexity 249.46\n",
"| epoch 1 | iter 881 / 1327 | time 2405[s] | perplexity 230.01\n",
"| epoch 1 | iter 901 / 1327 | time 2442[s] | perplexity 280.32\n",
"| epoch 1 | iter 921 / 1327 | time 2480[s] | perplexity 253.94\n",
"| epoch 1 | iter 941 / 1327 | time 2516[s] | perplexity 257.12\n",
"| epoch 1 | iter 961 / 1327 | time 2553[s] | perplexity 275.56\n",
"| epoch 1 | iter 981 / 1327 | time 2590[s] | perplexity 256.14\n",
"| epoch 1 | iter 1001 / 1327 | time 2627[s] | perplexity 215.38\n",
"| epoch 1 | iter 1021 / 1327 | time 2664[s] | perplexity 251.41\n",
"| epoch 1 | iter 1041 / 1327 | time 2702[s] | perplexity 228.57\n",
"| epoch 1 | iter 1061 / 1327 | time 2739[s] | perplexity 218.85\n",
"| epoch 1 | iter 1081 / 1327 | time 2776[s] | perplexity 188.14\n",
"| epoch 1 | iter 1101 / 1327 | time 2813[s] | perplexity 215.89\n",
"| epoch 1 | iter 1121 / 1327 | time 2854[s] | perplexity 255.60\n",
"| epoch 1 | iter 1141 / 1327 | time 2897[s] | perplexity 229.44\n",
"| epoch 1 | iter 1161 / 1327 | time 2941[s] | perplexity 221.68\n",
"| epoch 1 | iter 1181 / 1327 | time 2981[s] | perplexity 210.42\n",
"| epoch 1 | iter 1201 / 1327 | time 3020[s] | perplexity 181.10\n",
"| epoch 1 | iter 1221 / 1327 | time 3065[s] | perplexity 177.75\n",
"| epoch 1 | iter 1241 / 1327 | time 3110[s] | perplexity 209.31\n",
"| epoch 1 | iter 1261 / 1327 | time 3154[s] | perplexity 191.01\n",
"| epoch 1 | iter 1281 / 1327 | time 3203[s] | perplexity 199.19\n",
"| epoch 1 | iter 1301 / 1327 | time 3251[s] | perplexity 246.91\n",
"| epoch 1 | iter 1321 / 1327 | time 3299[s] | perplexity 234.73\n",
"evaluating perplexity ...\n",
"209 / 210\n",
"valid perplexity: 196.80691468962846\n",
"--------------------------------------------------\n",
"| epoch 2 | iter 1 / 1327 | time 2[s] | perplexity 291.16\n",
"| epoch 2 | iter 21 / 1327 | time 46[s] | perplexity 230.42\n",
"| epoch 2 | iter 41 / 1327 | time 90[s] | perplexity 210.98\n",
"| epoch 2 | iter 61 / 1327 | time 137[s] | perplexity 195.20\n",
"| epoch 2 | iter 81 / 1327 | time 182[s] | perplexity 179.33\n",
"| epoch 2 | iter 101 / 1327 | time 234[s] | perplexity 168.96\n",
"| epoch 2 | iter 121 / 1327 | time 318[s] | perplexity 179.30\n",
"| epoch 2 | iter 141 / 1327 | time 366[s] | perplexity 198.83\n",
"| epoch 2 | iter 161 / 1327 | time 411[s] | perplexity 215.86\n",
"| epoch 2 | iter 181 / 1327 | time 456[s] | perplexity 222.62\n",
"| epoch 2 | iter 201 / 1327 | time 501[s] | perplexity 206.98\n",
"| epoch 2 | iter 221 / 1327 | time 547[s] | perplexity 204.71\n",
"| epoch 2 | iter 241 / 1327 | time 625[s] | perplexity 197.85\n",
"| epoch 2 | iter 261 / 1327 | time 702[s] | perplexity 213.96\n",
"| epoch 2 | iter 281 / 1327 | time 755[s] | perplexity 205.79\n",
"| epoch 2 | iter 301 / 1327 | time 808[s] | perplexity 186.50\n",
"| epoch 2 | iter 321 / 1327 | time 859[s] | perplexity 152.48\n",
"| epoch 2 | iter 341 / 1327 | time 907[s] | perplexity 198.20\n",
"| epoch 2 | iter 361 / 1327 | time 957[s] | perplexity 215.05\n",
"| epoch 2 | iter 381 / 1327 | time 1006[s] | perplexity 170.60\n",
"| epoch 2 | iter 401 / 1327 | time 1067[s] | perplexity 193.21\n",
"| epoch 2 | iter 421 / 1327 | time 1135[s] | perplexity 176.80\n",
"| epoch 2 | iter 441 / 1327 | time 1201[s] | perplexity 180.86\n",
"| epoch 2 | iter 461 / 1327 | time 1267[s] | perplexity 182.34\n",
"| epoch 2 | iter 481 / 1327 | time 1332[s] | perplexity 174.90\n",
"| epoch 2 | iter 501 / 1327 | time 1381[s] | perplexity 192.26\n",
"| epoch 2 | iter 521 / 1327 | time 1449[s] | perplexity 189.70\n",
"| epoch 2 | iter 541 / 1327 | time 1541[s] | perplexity 202.00\n",
"| epoch 2 | iter 561 / 1327 | time 1641[s] | perplexity 171.94\n",
"| epoch 2 | iter 581 / 1327 | time 1726[s] | perplexity 158.15\n",
"| epoch 2 | iter 601 / 1327 | time 1818[s] | perplexity 215.57\n",
"| epoch 2 | iter 621 / 1327 | time 1908[s] | perplexity 201.54\n",
"| epoch 2 | iter 641 / 1327 | time 2002[s] | perplexity 185.20\n",
"| epoch 2 | iter 661 / 1327 | time 2090[s] | perplexity 174.48\n",
"| epoch 2 | iter 681 / 1327 | time 2174[s] | perplexity 146.61\n",
"| epoch 2 | iter 701 / 1327 | time 2228[s] | perplexity 171.08\n",
"| epoch 2 | iter 721 / 1327 | time 2284[s] | perplexity 176.13\n",
"| epoch 2 | iter 741 / 1327 | time 2349[s] | perplexity 150.93\n",
"| epoch 2 | iter 761 / 1327 | time 2402[s] | perplexity 149.54\n",
"| epoch 2 | iter 781 / 1327 | time 2482[s] | perplexity 149.39\n",
"| epoch 2 | iter 801 / 1327 | time 2536[s] | perplexity 168.56\n",
"| epoch 2 | iter 821 / 1327 | time 2626[s] | perplexity 161.91\n",
"| epoch 2 | iter 841 / 1327 | time 2680[s] | perplexity 165.24\n",
"| epoch 2 | iter 861 / 1327 | time 2735[s] | perplexity 161.05\n",
"| epoch 2 | iter 881 / 1327 | time 2787[s] | perplexity 149.73\n",
"| epoch 2 | iter 901 / 1327 | time 2843[s] | perplexity 188.35\n",
"| epoch 2 | iter 921 / 1327 | time 2922[s] | perplexity 165.84\n",
"| epoch 2 | iter 941 / 1327 | time 2982[s] | perplexity 168.55\n",
"| epoch 2 | iter 961 / 1327 | time 3047[s] | perplexity 186.29\n",
"| epoch 2 | iter 981 / 1327 | time 3108[s] | perplexity 174.66\n",
"| epoch 2 | iter 1001 / 1327 | time 3162[s] | perplexity 149.82\n",
"| epoch 2 | iter 1021 / 1327 | time 3223[s] | perplexity 174.38\n",
"| epoch 2 | iter 1041 / 1327 | time 3286[s] | perplexity 157.34\n",
"| epoch 2 | iter 1061 / 1327 | time 3348[s] | perplexity 148.87\n",
"| epoch 2 | iter 1081 / 1327 | time 3407[s] | perplexity 123.84\n",
"| epoch 2 | iter 1101 / 1327 | time 3458[s] | perplexity 137.86\n",
"| epoch 2 | iter 1121 / 1327 | time 3511[s] | perplexity 172.72\n",
"| epoch 2 | iter 1141 / 1327 | time 3566[s] | perplexity 164.43\n",
"| epoch 2 | iter 1161 / 1327 | time 3616[s] | perplexity 147.04\n",
"| epoch 2 | iter 1181 / 1327 | time 3666[s] | perplexity 148.45\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"| epoch 2 | iter 1201 / 1327 | time 3715[s] | perplexity 127.46\n",
"| epoch 2 | iter 1221 / 1327 | time 3769[s] | perplexity 125.83\n",
"| epoch 2 | iter 1241 / 1327 | time 3822[s] | perplexity 146.89\n",
"| epoch 2 | iter 1261 / 1327 | time 3871[s] | perplexity 137.78\n",
"| epoch 2 | iter 1281 / 1327 | time 3923[s] | perplexity 140.56\n",
"| epoch 2 | iter 1301 / 1327 | time 3973[s] | perplexity 178.40\n",
"| epoch 2 | iter 1321 / 1327 | time 4023[s] | perplexity 171.81\n",
"evaluating perplexity ...\n",
"209 / 210\n",
"valid perplexity: 145.6046890226078\n",
"--------------------------------------------------\n",
"| epoch 3 | iter 1 / 1327 | time 6[s] | perplexity 222.54\n",
"| epoch 3 | iter 21 / 1327 | time 124[s] | perplexity 161.39\n",
"| epoch 3 | iter 41 / 1327 | time 243[s] | perplexity 152.21\n",
"| epoch 3 | iter 61 / 1327 | time 295[s] | perplexity 142.97\n",
"| epoch 3 | iter 81 / 1327 | time 344[s] | perplexity 128.21\n",
"| epoch 3 | iter 101 / 1327 | time 393[s] | perplexity 122.36\n",
"| epoch 3 | iter 121 / 1327 | time 441[s] | perplexity 133.44\n",
"| epoch 3 | iter 141 / 1327 | time 492[s] | perplexity 146.84\n",
"| epoch 3 | iter 161 / 1327 | time 541[s] | perplexity 162.45\n",
"| epoch 3 | iter 181 / 1327 | time 596[s] | perplexity 169.60\n",
"| epoch 3 | iter 201 / 1327 | time 647[s] | perplexity 157.57\n",
"| epoch 3 | iter 221 / 1327 | time 695[s] | perplexity 155.70\n",
"| epoch 3 | iter 241 / 1327 | time 742[s] | perplexity 152.59\n",
"| epoch 3 | iter 261 / 1327 | time 797[s] | perplexity 163.07\n",
"| epoch 3 | iter 281 / 1327 | time 846[s] | perplexity 156.43\n",
"| epoch 3 | iter 301 / 1327 | time 906[s] | perplexity 138.56\n",
"| epoch 3 | iter 321 / 1327 | time 970[s] | perplexity 112.04\n",
"| epoch 3 | iter 341 / 1327 | time 1036[s] | perplexity 151.57\n",
"| epoch 3 | iter 361 / 1327 | time 1098[s] | perplexity 165.21\n",
"| epoch 3 | iter 381 / 1327 | time 1159[s] | perplexity 131.16\n",
"| epoch 3 | iter 401 / 1327 | time 1215[s] | perplexity 149.15\n",
"| epoch 3 | iter 421 / 1327 | time 1264[s] | perplexity 130.95\n",
"| epoch 3 | iter 441 / 1327 | time 1322[s] | perplexity 140.28\n",
"| epoch 3 | iter 461 / 1327 | time 1375[s] | perplexity 138.78\n",
"| epoch 3 | iter 481 / 1327 | time 1429[s] | perplexity 134.81\n",
"| epoch 3 | iter 501 / 1327 | time 1480[s] | perplexity 149.98\n",
"| epoch 3 | iter 521 / 1327 | time 1534[s] | perplexity 152.35\n",
"| epoch 3 | iter 541 / 1327 | time 1583[s] | perplexity 157.25\n",
"| epoch 3 | iter 561 / 1327 | time 1633[s] | perplexity 134.19\n",
"| epoch 3 | iter 581 / 1327 | time 1682[s] | perplexity 123.06\n",
"| epoch 3 | iter 601 / 1327 | time 1735[s] | perplexity 169.27\n",
"| epoch 3 | iter 621 / 1327 | time 1788[s] | perplexity 160.02\n",
"| epoch 3 | iter 641 / 1327 | time 1838[s] | perplexity 146.95\n",
"| epoch 3 | iter 661 / 1327 | time 1888[s] | perplexity 136.56\n",
"| epoch 3 | iter 681 / 1327 | time 1939[s] | perplexity 117.16\n",
"| epoch 3 | iter 701 / 1327 | time 1991[s] | perplexity 136.33\n",
"| epoch 3 | iter 721 / 1327 | time 2043[s] | perplexity 139.97\n",
"| epoch 3 | iter 741 / 1327 | time 2094[s] | perplexity 121.00\n",
"| epoch 3 | iter 761 / 1327 | time 2154[s] | perplexity 115.03\n",
"| epoch 3 | iter 781 / 1327 | time 2216[s] | perplexity 122.38\n",
"| epoch 3 | iter 801 / 1327 | time 2270[s] | perplexity 135.23\n",
"| epoch 3 | iter 821 / 1327 | time 2326[s] | perplexity 133.33\n",
"| epoch 3 | iter 841 / 1327 | time 2405[s] | perplexity 134.35\n",
"| epoch 3 | iter 861 / 1327 | time 2457[s] | perplexity 130.02\n"
]
}
],
"source": [
"# coding: utf-8\n",
"import sys\n",
"sys.path.append('..')\n",
"from common import config\n",
"# To run on a GPU, uncomment the line below (requires cupy)\n",
"# ==============================================\n",
"# config.GPU = True\n",
"# ==============================================\n",
"from common.optimizer import SGD\n",
"from common.trainer import RnnlmTrainer\n",
"from common.util import eval_perplexity, to_gpu\n",
"from dataset import ptb\n",
"from better_rnnlm import BetterRnnlm\n",
"\n",
"\n",
"# Hyperparameter settings\n",
"batch_size = 20\n",
"wordvec_size = 650\n",
"hidden_size = 650\n",
"time_size = 35\n",
"lr = 20.0\n",
"max_epoch = 40\n",
"max_grad = 0.25\n",
"dropout = 0.5\n",
"\n",
"# Load the PTB data (train / validation / test)\n",
"corpus, word_to_id, id_to_word = ptb.load_data('train')\n",
"corpus_val, _, _ = ptb.load_data('val')\n",
"corpus_test, _, _ = ptb.load_data('test')\n",
"\n",
"if config.GPU:\n",
"    corpus = to_gpu(corpus)\n",
"    corpus_val = to_gpu(corpus_val)\n",
"    corpus_test = to_gpu(corpus_test)\n",
"\n",
"vocab_size = len(word_to_id)\n",
"xs = corpus[:-1]\n",
"ts = corpus[1:]\n",
"\n",
"model = BetterRnnlm(vocab_size, wordvec_size, hidden_size, dropout)\n",
"optimizer = SGD(lr)\n",
"trainer = RnnlmTrainer(model, optimizer)\n",
"\n",
"best_ppl = float('inf')\n",
"for epoch in range(max_epoch):\n",
"    trainer.fit(xs, ts, max_epoch=1, batch_size=batch_size,\n",
"                time_size=time_size, max_grad=max_grad)\n",
"\n",
"    model.reset_state()\n",
"    ppl = eval_perplexity(model, corpus_val)\n",
"    print('valid perplexity: ', ppl)\n",
"\n",
"    if best_ppl > ppl:\n",
"        best_ppl = ppl\n",
"        model.save_params()\n",
"    else:\n",
"        lr /= 4.0\n",
"        optimizer.lr = lr\n",
"\n",
"    model.reset_state()\n",
"    print('-' * 50)\n",
"\n",
"\n",
"# Evaluate on the test data\n",
"model.reset_state()\n",
"ppl_test = eval_perplexity(model, corpus_test)\n",
"print('test perplexity: ', ppl_test)\n"
]
},
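{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training loop above checks validation perplexity after every epoch: if it improved, the parameters are saved; otherwise the learning rate is divided by 4. The following is a minimal sketch of just that control flow, with no model involved. It assumes that perplexity is the exponential of the average per-token cross-entropy loss (which is how we understand `eval_perplexity` in `common.util`). The helpers `perplexity_from_losses` and `next_lr`, and the perplexity values in the loop, are hypothetical and for illustration only; they are not part of the book's `common` package.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def perplexity_from_losses(losses):\n",
"    # Perplexity = exp(mean per-token cross-entropy loss)\n",
"    return float(np.exp(np.mean(losses)))\n",
"\n",
"def next_lr(lr, best_ppl, ppl, decay=4.0):\n",
"    # Same rule as the training loop: keep lr on improvement, else divide by decay\n",
"    if ppl < best_ppl:\n",
"        return lr, ppl, True   # new best; the real loop saves params here\n",
"    return lr / decay, best_ppl, False\n",
"\n",
"# Hypothetical validation perplexities (illustration only, not measured)\n",
"lr, best_ppl = 20.0, float('inf')\n",
"for ppl in [200.0, 150.0, 160.0, 140.0]:\n",
"    lr, best_ppl, improved = next_lr(lr, best_ppl, ppl)\n",
"    print(ppl, improved, lr)\n",
"```\n",
"\n",
"With these made-up values the learning rate stays at 20.0 for the first two epochs and drops to 5.0 at the third, because 160.0 is worse than the best value so far (150.0); the fourth epoch improves again, so the rate stays at 5.0.\n"
]
},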
{
"cell_type": "markdown",
"metadata": {},
"source": [
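"## 6.6 Summary\n",
"\n",
"- To address the vanishing and exploding gradient problems of a simple RNN, the following techniques are introduced:\n",
"  - Gradient clipping (against exploding gradients)\n",
"  - Gated RNNs such as LSTM/GRU (against vanishing gradients)\n",
"- Gates used in LSTM\n",
"  - input gate\n",
"  - forget gate\n",
"  - output gate\n",
"  - Each gate has its own weights, and its output passes through a sigmoid function, so it takes values between 0.0 and 1.0.\n",
"- Implementing the language model\n",
"  - Stacking multiple LSTM layers\n",
"  - Dropout\n",
"  - Weight sharing (tying the Embedding and Affine layer weights)\n",
"- Regularization of RNNs is an important topic, and various Dropout-based methods have been proposed.\n"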
]
},
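{
"cell_type": "markdown",
"metadata": {},
"source": [
"The summary above notes that each LSTM gate has its own weights and a sigmoid output between 0.0 and 1.0. The following is a minimal NumPy sketch of a single LSTM time step that makes this concrete. It assumes the four affine transforms are packed into one `Wx` of shape (D, 4H), one `Wh` of shape (H, 4H) and one bias `b` of shape (4H,), sliced in the order f, g, i, o; the name `lstm_step`, this packing and the slice order are assumptions for illustration and may not match the book's `LSTM` class exactly.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(x):\n",
"    return 1.0 / (1.0 + np.exp(-x))\n",
"\n",
"def lstm_step(x, h_prev, c_prev, Wx, Wh, b):\n",
"    # x: (N, D), h_prev / c_prev: (N, H); Wx: (D, 4H), Wh: (H, 4H), b: (4H,)\n",
"    H = h_prev.shape[1]\n",
"    A = x @ Wx + h_prev @ Wh + b   # all four affine transforms at once\n",
"    f = sigmoid(A[:, :H])          # forget gate, values in (0, 1)\n",
"    g = np.tanh(A[:, H:2*H])       # new candidate memory (not a gate)\n",
"    i = sigmoid(A[:, 2*H:3*H])     # input gate, values in (0, 1)\n",
"    o = sigmoid(A[:, 3*H:])        # output gate, values in (0, 1)\n",
"    c = f * c_prev + g * i         # keep part of the old memory, add new info\n",
"    h = o * np.tanh(c)             # expose part of the memory as hidden state\n",
"    return h, c\n",
"\n",
"rng = np.random.default_rng(0)\n",
"N, D, H = 2, 4, 3\n",
"x = rng.standard_normal((N, D))\n",
"h0 = np.zeros((N, H))\n",
"c0 = np.zeros((N, H))\n",
"Wx = rng.standard_normal((D, 4 * H))\n",
"Wh = rng.standard_normal((H, 4 * H))\n",
"b = np.zeros(4 * H)\n",
"h1, c1 = lstm_step(x, h0, c0, Wx, Wh, b)\n",
"print(h1.shape, c1.shape)  # (2, 3) (2, 3)\n",
"```\n",
"\n",
"Because the memory cell `c` is updated by element-wise products with the gates rather than by a MatMul followed by tanh, the gradient flowing back along `c` is not repeatedly multiplied by the same weight matrix, which is why gated RNNs suffer less from vanishing gradients.\n"
]
},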
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}