Created
September 13, 2020 14:59
-
-
Save CookieBox26/a4391ccc77d54c03e59ba4b3d7e14c47 to your computer and use it in GitHub Desktop.
LSTM / GRU で Sequential MNIST を学習する(GRU の学習の確認)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# LSTM / GRU で Sequential MNIST を学習する(GRU の学習の確認)\n", | |
"\n", | |
"関連Gistの 4. で GRU が学習できることの確認をしていなかったので確認しているだけです。1~3つ目のセルは関連Gistの 4. の「データ生成」「GRUの定義」「学習」のコードと同じです(3セル目の ★ のところは修正しました)。\n", | |
"\n", | |
"原論文では GRU の隠れ状態の次元数は LSTM よりゲートが少ない分パラメータを増やすはずですが(パラメータ数をそろえて比較するため)、その調整はさぼっていて LSTM の隠れ状態の次元数のままになっています。\n", | |
"\n", | |
"### 参考文献\n", | |
"1. [[1803.01271]An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling](https://arxiv.org/abs/1803.01271) ;TCNの原論文。\n", | |
"1. [locuslab/TCN: Sequence modeling benchmarks and temporal convolutional networks](https://github.com/locuslab/TCN) ;TCNの原論文のリポジトリ。\n", | |
"\n", | |
"### 関連Gist\n", | |
"1. [ニューラルネットで足し算する(Temporal Convolutional Network )](https://gist.github.com/CookieBox26/8e314f1164d7d5beea5312d625115fed); TCN による足し算タスク。\n", | |
"2. [LSTM で足し算する](https://gist.github.com/CookieBox26/31a1247c0e31d6109067229a151ead66); LSTM による足し算タスク。\n", | |
"3. [TCN で Sequential MNIST を学習する](https://gist.github.com/CookieBox26/831d03037141f852f9a47a10a9eb4780); TCN による Sequential MNIST タスク。\n", | |
"4. [LSTM / GRU で Sequential MNIST を学習する](https://gist.github.com/CookieBox26/64202c231124f9cc667a04fbd2d2b4a6); LSTM による Sequential MNIST タスク。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"◆ 元々の1バッチのサイズ\n", | |
"torch.Size([64, 1, 28, 28])\n", | |
"◆ 系列データ化後の1バッチのサイズ\n", | |
"torch.Size([64, 1, 784])\n", | |
"◆ さらにLSTM向けに転置した後の1バッチのサイズ\n", | |
"torch.Size([784, 64, 1])\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"from torchvision import datasets, transforms\n", | |
"\n", | |
"\n", | |
"# MNIST をロードする関数\n", | |
"# https://github.com/locuslab/TCN/blob/master/TCN/mnist_pixel/utils.py より\n", | |
"def data_generator(root, batch_size):\n", | |
" train_set = datasets.MNIST(root=root, train=True, download=True,\n", | |
" transform=transforms.Compose([\n", | |
" transforms.ToTensor(),\n", | |
" # (0.1307,), (0.3081,) は MNIST の訓練データの平均と標準偏差らしい\n", | |
" transforms.Normalize((0.1307,), (0.3081,))\n", | |
" ]))\n", | |
" test_set = datasets.MNIST(root=root, train=False, download=True,\n", | |
" transform=transforms.Compose([\n", | |
" transforms.ToTensor(),\n", | |
" # (0.1307,), (0.3081,) は MNIST の訓練データの平均と標準偏差らしい\n", | |
" transforms.Normalize((0.1307,), (0.3081,))\n", | |
" ]))\n", | |
" train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)\n", | |
" test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)\n", | |
" return train_loader, test_loader\n", | |
"\n", | |
"# TCN向けのデータをLSTM向けに転置する関数\n", | |
"# in: torch.Size([batch_size, input_channels, seq_length])\n", | |
"# out: torch.Size([seq_length, batch_size, input_channels])\n", | |
"def view_for_lstm(x):\n", | |
" return x.transpose(0, 2).transpose(1, 2).contiguous()\n", | |
"\n", | |
"root = '../data'\n", | |
"batch_size = 64\n", | |
"train_loader, test_loader = data_generator(root, batch_size)\n", | |
"for batch_idx, (data, target) in enumerate(train_loader):\n", | |
" print('◆ 元々の1バッチのサイズ')\n", | |
" print(data.size())\n", | |
" data = data.view(-1, 1, 784)\n", | |
" print('◆ 系列データ化後の1バッチのサイズ')\n", | |
" print(data.size())\n", | |
" print('◆ さらにLSTM向けに転置した後の1バッチのサイズ')\n", | |
" data = view_for_lstm(data)\n", | |
" print(data.size())\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"◆ モデル\n", | |
"GRU(\n", | |
" (network): Sequential(\n", | |
" (gru): GRU(1, 130)\n", | |
" (linear): Linear(in_features=130, out_features=10, bias=True)\n", | |
" )\n", | |
")\n", | |
"\n", | |
"◆ モデルに Sequential MNIST を流してみる\n", | |
"入力 torch.Size([784, 64, 1])\n", | |
"出力特徴(初期値) torch.Size([1, 64, 130])\n", | |
"========== forward ==========\n", | |
"入力 torch.Size([784, 64, 1])\n", | |
"GRU層の出力 torch.Size([784, 64, 130])\n", | |
"出力特徴 torch.Size([1, 64, 130])\n", | |
"出力 torch.Size([64, 10])\n", | |
"=============================\n", | |
"出力 torch.Size([64, 10])\n", | |
"出力特徴(784ステップ後) torch.Size([1, 64, 130])\n", | |
"モデルの出力\n", | |
"tensor([-2.3676, -2.1991, -2.3292, -2.3177, -2.2764, -2.3562, -2.2945, -2.2423,\n", | |
" -2.3422, -2.3132], grad_fn=<SelectBackward>)\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"import torch.nn.functional as F\n", | |
"import torch.nn as nn\n", | |
"from collections import OrderedDict\n", | |
"\n", | |
"class GRU(nn.Module):\n", | |
" def __init__(self,\n", | |
" input_size=1, # Sequential MNIST タスクなので入力は1次元\n", | |
" output_size=10, # Sequential MNIST タスクなので出力は10次元\n", | |
" num_layers=1, # GRU ブロックの積み重ね数が1\n", | |
" d_hidden=130, # GRU ブロックの出力次元数が130\n", | |
" initial_update_gate_bias=0.5, # 更新ゲートのバイアスの初期値\n", | |
" dropout=0.0):\n", | |
" super(GRU, self).__init__()\n", | |
" self.num_layers = num_layers\n", | |
" self.d_hidden = d_hidden\n", | |
" self.layers = OrderedDict()\n", | |
" self.layers['gru'] = nn.GRU(input_size, \n", | |
" d_hidden,\n", | |
" num_layers=num_layers,\n", | |
" dropout=dropout)\n", | |
" self.layers['linear'] = nn.Linear(d_hidden, output_size)\n", | |
" self.network = nn.Sequential(self.layers)\n", | |
" self.init_weights(initial_update_gate_bias)\n", | |
"\n", | |
" def init_weights(self, initial_update_gate_bias):\n", | |
" # 更新ゲートのバイアスの初期値をセット\n", | |
" for i_layer in range(self.num_layers):\n", | |
" bias = getattr(self.layers['gru'], f'bias_ih_l{i_layer}')\n", | |
" bias.data[self.d_hidden:(2*self.d_hidden)] = initial_update_gate_bias\n", | |
" bias = getattr(self.layers['gru'], f'bias_hh_l{i_layer}')\n", | |
" bias.data[self.d_hidden:(2*self.d_hidden)] = initial_update_gate_bias\n", | |
" self.layers['linear'].weight.data.normal_(0, 0.01)\n", | |
"\n", | |
" def forward(self, x, hidden, debug=False):\n", | |
" if debug: print('========== forward ==========')\n", | |
" if debug: print('入力', x.size())\n", | |
" out, hidden = self.layers['gru'](x, hidden)\n", | |
" if debug: print('GRU層の出力', out.size())\n", | |
" if debug: print('出力特徴', hidden.size())\n", | |
" x = self.layers['linear'](hidden[-1,:,:])\n", | |
" if debug: print('出力', x.size())\n", | |
" if debug: print('=============================')\n", | |
" return F.log_softmax(x, dim=1), hidden\n", | |
" \n", | |
" # バッチサイズを渡すと出力特徴の初期テンソルをつくってくれる\n", | |
" def generate_initial_hidden(self, batch_size):\n", | |
" return torch.zeros([self.num_layers, batch_size, self.d_hidden])\n", | |
"\n", | |
"model = GRU()\n", | |
"print('◆ モデル')\n", | |
"print(model)\n", | |
"\n", | |
"print('\\n◆ モデルに Sequential MNIST を流してみる')\n", | |
"hidden = model.generate_initial_hidden(batch_size)\n", | |
"print('入力', data.size())\n", | |
"print('出力特徴(初期値)', hidden.size())\n", | |
"out, hidden = model.forward(data, hidden, debug=True)\n", | |
"print('出力', out.size())\n", | |
"print('出力特徴(784ステップ後)', hidden.size())\n", | |
"print('モデルの出力')\n", | |
"print(out[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Train Epoch: 1 [1280/60000 (2%)]\tLoss: 2.421090\tSteps: 16464\n", | |
"Train Epoch: 1 [2560/60000 (4%)]\tLoss: 2.303616\tSteps: 32144\n", | |
"Train Epoch: 1 [3840/60000 (6%)]\tLoss: 2.303360\tSteps: 47824\n", | |
"Train Epoch: 1 [5120/60000 (9%)]\tLoss: 2.299382\tSteps: 63504\n", | |
"Train Epoch: 1 [6400/60000 (11%)]\tLoss: 2.294239\tSteps: 79184\n", | |
"\n", | |
"Test set: Average loss: 2.2810, Accuracy: 1261/10000 (13%)\n", | |
"\n", | |
"Train Epoch: 2 [1280/60000 (2%)]\tLoss: 2.307144\tSteps: 95648\n", | |
"Train Epoch: 2 [2560/60000 (4%)]\tLoss: 2.074566\tSteps: 111328\n" | |
] | |
} | |
], | |
"source": [ | |
"# https://github.com/locuslab/TCN/blob/master/TCN/mnist_pixel/pmnist_test.py を LSTM 用にカスタマイズ.\n", | |
"\n", | |
"from torch.autograd import Variable\n", | |
"import torch.optim as optim\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"\n", | |
"root = '../data'\n", | |
"batch_size = 64\n", | |
"train_loader, test_loader = data_generator(root, batch_size)\n", | |
"\n", | |
"model = GRU()\n", | |
"optimizer = optim.RMSprop(model.parameters(), lr=1e-3) # RMSprop\n", | |
"\n", | |
"input_channels = 1\n", | |
"seq_length = 784\n", | |
"epochs = 20\n", | |
"log_interval = 20\n", | |
"\n", | |
"steps = 0\n", | |
"\n", | |
"def train(ep):\n", | |
" global steps\n", | |
" train_loss = 0\n", | |
" model.train()\n", | |
" for batch_idx, (data, target) in enumerate(train_loader):\n", | |
" data = data.view(-1, input_channels, seq_length)\n", | |
" data = view_for_lstm(data)\n", | |
" data, target = Variable(data), Variable(target)\n", | |
" optimizer.zero_grad()\n", | |
" hidden = model.generate_initial_hidden(data.size()[1]) # hidden: batch_size の端数が出うるので data.size()[1] をとる ★\n", | |
" output, hidden = model(data, hidden)\n", | |
" loss = F.nll_loss(output, target)\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # Grad Clip\n", | |
" train_loss += loss\n", | |
" steps += seq_length\n", | |
" if batch_idx > 0 and batch_idx % log_interval == 0:\n", | |
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\tSteps: {}'.format(\n", | |
" ep, batch_idx * batch_size, len(train_loader.dataset),\n", | |
" 100. * batch_idx / len(train_loader), train_loss.item()/log_interval, steps))\n", | |
" train_loss = 0\n", | |
" \n", | |
" # 動作確認のため1エポック目をわざと早く終わらせる\n", | |
" if (ep == 1) and (batch_idx == 100):\n", | |
" break\n", | |
" \n", | |
"def test():\n", | |
" model.eval()\n", | |
" test_loss = 0\n", | |
" correct = 0\n", | |
" with torch.no_grad():\n", | |
" for data, target in test_loader:\n", | |
" data = data.view(-1, input_channels, seq_length)\n", | |
" data = view_for_lstm(data)\n", | |
" data, target = Variable(data), Variable(target)\n", | |
" hidden = model.generate_initial_hidden(data.size()[1]) # hidden: batch_size の端数が出うるので data.size()[1] をとる\n", | |
" output, hidden = model.forward(data, hidden)\n", | |
" test_loss += F.nll_loss(output, target, reduction='sum').item()\n", | |
" pred = output.data.max(1, keepdim=True)[1]\n", | |
" correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", | |
"\n", | |
" test_loss /= len(test_loader.dataset)\n", | |
" print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", | |
" test_loss, correct, len(test_loader.dataset),\n", | |
" 100. * correct / len(test_loader.dataset)))\n", | |
" return test_loss\n", | |
"\n", | |
"\n", | |
"# 実行\n", | |
"for epoch in range(1, epochs + 1):\n", | |
" train(epoch)\n", | |
" test()\n", | |
" if epoch % 10 == 0:\n", | |
" lr /= 10\n", | |
" for param_group in optimizer.param_groups:\n", | |
" param_group['lr'] = lr" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment