@CookieBox26
Created March 26, 2020 13:53
Defining a model in PyTorch and updating its parameters along the gradient direction
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PyTorch でモデルを定義して勾配方向にパラメータを更新してみる\n",
"\n",
"参考文献: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html\n",
"\n",
"- https://pytorch.org/docs/stable/nn.html#conv2d\n",
"\n",
"### 【参考】conv2d の入力次元数と出力次元数の関係\n",
"![conv2d の入力次元数と出力次元数の関係](https://pbs.twimg.com/media/EUB9HjKU0AAi5mc?format=jpg&name=large)\n",
"\n",
"### 【参考】max_pool2d ではプールサイズを敷き詰めたとき余った端の領域は無視される\n",
"![max_pool2d ではプールサイズを敷き詰めたとき余った端の領域は無視される](https://pbs.twimg.com/media/EUCNGd1U4AAhDg9?format=jpg&name=large)\n"
]
},
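{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Reference] Checking the size formulas in code\n",
"\n",
"A minimal sketch (added here as a supplement to the figures above): per the PyTorch docs, with no padding or dilation, Conv2d produces H_out = floor((H_in - kernel_size) / stride) + 1, and max_pool2d applies the same floor under the default ceil_mode=False, which is exactly why the leftover border region is ignored."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"# Conv2d with no padding/dilation: H_out = floor((H_in - K) / S) + 1\n",
"x = torch.randn(1, 1, 32, 32)\n",
"print(nn.Conv2d(1, 6, 3)(x).size(3), math.floor((32 - 3) / 1) + 1)  # 30 30\n",
"\n",
"# max_pool2d floors by default, so the leftover border of a 15x15 map is dropped;\n",
"# ceil_mode=True keeps it instead\n",
"z = torch.randn(1, 1, 15, 15)\n",
"print(F.max_pool2d(z, (2, 2)).size(3))                  # 7 = floor(15 / 2)\n",
"print(F.max_pool2d(z, (2, 2), ceil_mode=True).size(3))  # 8 = ceil(15 / 2)"
]
},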
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 今回のネットワーク構造を流れるデータの次元がどうなっていくかの確認"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 1, 32, 32])\n",
"torch.Size([1, 6, 30, 30])\n",
"torch.Size([1, 6, 15, 15])\n",
"torch.Size([1, 16, 13, 13])\n",
"torch.Size([1, 16, 6, 6])\n",
"torch.Size([1, 576])\n",
"torch.Size([1, 120])\n",
"torch.Size([1, 84])\n",
"torch.Size([1, 10])\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"x = torch.randn(1, 1, 32, 32)\n",
"print(x.size())\n",
"x = F.relu(nn.Conv2d(1, 6, 3)(x))\n",
"print(x.size())\n",
"x = F.max_pool2d(x, (2, 2))\n",
"print(x.size())\n",
"x = F.relu(nn.Conv2d(6, 16, 3)(x))\n",
"print(x.size())\n",
"#print(x[0, 0, 8:, 8:]) # max_pool2d ではプールサイズを敷き詰めたとき余った端の領域は無視されることの確認用\n",
"x = F.max_pool2d(x, (2, 2))\n",
"print(x.size())\n",
"#print(x[0, 0, 3:, 3:]) # max_pool2d ではプールサイズを敷き詰めたとき余った端の領域は無視されることの確認用\n",
"size = x.size()[1:]\n",
"num_features = 1\n",
"for s in size:\n",
" num_features *= s\n",
"x = x.view(-1, num_features)\n",
"print(x.size())\n",
"x = F.relu(nn.Linear(16 * 6 * 6, 120)(x))\n",
"print(x.size())\n",
"x = F.relu(nn.Linear(120, 84)(x))\n",
"print(x.size())\n",
"x = nn.Linear(84, 10)(x)\n",
"print(x.size())"
]
},
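{
"cell_type": "markdown",
"metadata": {},
"source": [
"A side note (added as a sketch, not part of the original gist): the manual num_features loop above multiplies out every dimension except the batch dimension; torch.flatten(x, 1) performs the same flattening in one call."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"x = torch.randn(1, 16, 6, 6)\n",
"# flatten every dimension except dim 0 (the batch dimension): (1, 16, 6, 6) -> (1, 576)\n",
"print(torch.flatten(x, 1).size())\n",
"print(x.view(-1, 16 * 6 * 6).size())  # same shape as the manual approach"
]
},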
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 学習するモデルの定義"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ 今回学習するネットワーク構造\n",
"Net(\n",
" (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))\n",
" (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))\n",
" (fc1): Linear(in_features=576, out_features=120, bias=True)\n",
" (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
" (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
")\n",
"\n",
"◆ 今回学習するパラメータたち\n",
"conv1.weight torch.Size([6, 1, 3, 3])\n",
"conv1.bias torch.Size([6])\n",
"conv2.weight torch.Size([16, 6, 3, 3])\n",
"conv2.bias torch.Size([16])\n",
"fc1.weight torch.Size([120, 576])\n",
"fc1.bias torch.Size([120])\n",
"fc2.weight torch.Size([84, 120])\n",
"fc2.bias torch.Size([84])\n",
"fc3.weight torch.Size([10, 84])\n",
"fc3.bias torch.Size([10])\n",
"\n",
"◆ 今回学習するパラメータが初期状態でどうなっているか\n",
"◇ conv1.weight(1つ目のチャネルのみ)\n",
"tensor([[ 0.1507, 0.2545, 0.2658],\n",
" [ 0.0996, -0.1861, 0.3088],\n",
" [-0.0111, -0.2514, -0.0013]], grad_fn=<SliceBackward>)\n",
"◇ conv1.bias(1つ目のチャネルのみ)\n",
"tensor(-0.1463, grad_fn=<SelectBackward>)\n",
"\n",
"◆ 何も学習していないモデルに適当な入力を入れて出力の勾配をとってみる\n",
"◇ 適当な入力(32×32ピクセルの画像がこのバッチに1つだけという状況)\n",
"tensor([[[[ 0.6834, -3.4134, 0.5350, ..., 0.9054, -0.4111, -0.2674],\n",
" [-0.5976, 0.6849, -0.2461, ..., 0.3538, 1.4160, -2.5196],\n",
" [ 0.8288, 0.1629, -0.2127, ..., 1.1652, 0.5502, 1.4269],\n",
" ...,\n",
" [-0.4244, -1.1404, -0.5931, ..., 0.3465, -0.1097, -0.9825],\n",
" [ 0.9708, 0.4426, -0.6627, ..., -0.9357, -0.8525, -0.0245],\n",
" [-0.4908, 0.3119, 1.4955, ..., -0.5074, 0.1405, -0.9770]]]])\n",
"◇ それに対する出力\n",
"tensor([[-1.0325e-01, -1.4090e-02, 1.0086e-01, 6.0341e-02, -1.3714e-01,\n",
" -2.6546e-02, 1.4347e-04, -1.6284e-01, 8.7880e-03, -7.2704e-03]],\n",
" grad_fn=<AddmmBackward>)\n",
"◇ conv1.weight をどの向きに動かすと out の各成分の和は大きくなるか(1つ目のチャネルのみ)\n",
"tensor([[ 0.0316, 0.0029, -0.0130],\n",
" [ 0.0107, 0.0577, 0.0390],\n",
" [ 0.0594, -0.0445, 0.0349]])\n",
"◇ conv1.bias をどの向きに動かすと out の各成分の和は大きくなるか(1つ目のチャネルのみ)\n",
"tensor(0.0469)\n",
"\n",
"◆ 勾配をリセットする\n",
"◇ conv1.weight をどの向きに動かすと out の各成分の和は大きくなるか(1つ目のチャネルのみ)→ リセットされている\n",
"tensor([[0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.]])\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class Net(nn.Module):\n",
"\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.conv1 = nn.Conv2d(1, 6, 3)\n",
" self.conv2 = nn.Conv2d(6, 16, 3)\n",
" self.fc1 = nn.Linear(16 * 6 * 6, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, 10)\n",
"\n",
" def forward(self, x):\n",
" x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
" x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))\n",
" x = x.view(-1, self.num_flat_features(x))\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x\n",
"\n",
" def num_flat_features(self, x):\n",
" size = x.size()[1:] # 1次元目はバッチサイズなのでそれを除いたデータの次元数を調べる\n",
" num_features = 1\n",
" for s in size:\n",
" num_features *= s\n",
" return num_features\n",
"\n",
"net = Net()\n",
"\n",
"print('◆ 今回学習するネットワーク構造')\n",
"print(net)\n",
"\n",
"print('\\n◆ 今回学習するパラメータたち')\n",
"for name, param in net.named_parameters():\n",
" print(name.ljust(14), param.size())\n",
"\n",
"print('\\n◆ 今回学習するパラメータが初期状態でどうなっているか')\n",
"print('◇ conv1.weight(1つ目のチャネルのみ)')\n",
"print(net.conv1.weight[0, 0, :, :])\n",
"print('◇ conv1.bias(1つ目のチャネルのみ)')\n",
"print(net.conv1.bias[0])\n",
" \n",
"print('\\n◆ 何も学習していないモデルに適当な入力を入れて出力の勾配をとってみる')\n",
"print('◇ 適当な入力(32×32ピクセルの画像がこのバッチに1つだけという状況)')\n",
"input = torch.randn(1, 1, 32, 32)\n",
"print(input)\n",
"print('◇ それに対する出力')\n",
"out = net(input)\n",
"print(out)\n",
"out.backward(torch.ones(1, 10)) # out の各成分の和の勾配\n",
"print('◇ conv1.weight をどの向きに動かすと out の各成分の和は大きくなるか(1つ目のチャネルのみ)')\n",
"print(net.conv1.weight.grad[0, 0, :, :])\n",
"print('◇ conv1.bias をどの向きに動かすと out の各成分の和は大きくなるか(1つ目のチャネルのみ)')\n",
"print(net.conv1.bias.grad[0])\n",
"\n",
"print('\\n◆ 勾配をリセットする')\n",
"net.zero_grad()\n",
"print('◇ conv1.weight をどの向きに動かすと out の各成分の和は大きくなるか(1つ目のチャネルのみ)→ リセットされている')\n",
"print(net.conv1.weight.grad[0, 0, :, :])"
]
},
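{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small check (added as a sketch; it reuses net and input from the cell above): passing torch.ones(1, 10) to backward computes a vector-Jacobian product with an all-ones vector, which is the same as differentiating out.sum()."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"net.zero_grad()\n",
"net(input).sum().backward()  # scalar loss: the sum of the 10 output components\n",
"g1 = net.conv1.weight.grad.clone()\n",
"\n",
"net.zero_grad()\n",
"net(input).backward(torch.ones(1, 10))  # vector-Jacobian product with all ones\n",
"g2 = net.conv1.weight.grad.clone()\n",
"\n",
"print(torch.allclose(g1, g2))  # True: both compute the same gradient"
]
},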
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ ダミー入力とダミー正解を用意して損失を計算してみる\n",
"◇ ダミー入力\n",
"tensor([[[[-0.9004, 0.2428, -0.0449, ..., 0.2112, -0.1005, -1.7895],\n",
" [-1.0153, -0.0603, 0.7541, ..., 0.6913, 0.6475, 0.4360],\n",
" [-0.4437, 1.0565, 1.4228, ..., 0.9775, -1.6461, -0.8933],\n",
" ...,\n",
" [ 0.8332, -0.4080, 1.3895, ..., -0.9083, -0.3131, -0.2984],\n",
" [ 0.4416, 0.1449, 1.7584, ..., -1.3662, 0.8630, -0.3609],\n",
" [ 1.1337, -0.3620, -0.3627, ..., 0.3511, -0.1659, 0.4646]]]])\n",
"◇ ダミー正解\n",
"tensor([[-1.6234, 0.6375, -1.1955, -0.8426, -1.6228, 0.1544, -0.4771, 0.2755,\n",
" -0.1288, 0.6331]])\n",
"◇ ダミー予測値\n",
"tensor([[-0.0952, -0.0213, 0.1116, 0.0807, -0.1407, -0.0380, -0.0079, -0.1717,\n",
" 0.0170, -0.0173]], grad_fn=<AddmmBackward>)\n",
"◇ 損失(平均2乗誤差)\n",
"tensor(0.8428, grad_fn=<MseLossBackward>)\n",
"◇ 損失を自分でも検算する\n",
"0.842791190449175\n",
"\n",
"◆ 勾配をリセットしてからさっきの損失の勾配をとってみる\n",
"◇ conv1.weight をどの向きに動かすと loss は大きくなるか(1つ目のチャネルのみ)→ リセットされている\n",
"tensor([[0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.]])\n",
"◇ conv1.weight をどの向きに動かすと loss は大きくなるか(1つ目のチャネルのみ)\n",
"tensor([[-0.0044, 0.0082, 0.0057],\n",
" [ 0.0060, 0.0163, -0.0241],\n",
" [ 0.0014, -0.0092, -0.0040]])\n",
"\n",
"◆ 損失が減少する方向にパラメータを手動で(手動とは)更新してみる → 勾配の逆向きに更新されている\n",
"◇ conv1.weight の更新前の値(1つ目のチャネルのみ)\n",
"tensor([[ 0.1507, 0.2545, 0.2658],\n",
" [ 0.0996, -0.1861, 0.3088],\n",
" [-0.0111, -0.2514, -0.0013]], grad_fn=<SliceBackward>)\n",
"◇ conv1.weight の更新後の値(1つ目のチャネルのみ)\n",
"tensor([[ 0.1512, 0.2537, 0.2652],\n",
" [ 0.0990, -0.1878, 0.3113],\n",
" [-0.0113, -0.2505, -0.0009]], grad_fn=<SliceBackward>)\n"
]
}
],
"source": [
"print('◆ ダミー入力とダミー正解を用意して損失を計算してみる')\n",
"input = torch.randn(1, 1, 32, 32)\n",
"print('◇ ダミー入力')\n",
"print(input)\n",
"target = torch.randn(10)\n",
"target = target.view(1, -1)\n",
"print('◇ ダミー正解')\n",
"print(target)\n",
"output = net(input)\n",
"print('◇ ダミー予測値')\n",
"print(output)\n",
"print('◇ 損失(平均2乗誤差)')\n",
"criterion = nn.MSELoss()\n",
"loss = criterion(output, target)\n",
"print(loss)\n",
"print('◇ 損失を自分でも検算する')\n",
"import numpy as np\n",
"loss_ = []\n",
"for pred, actual in zip(output[0], target[0]):\n",
" loss_.append((pred.item() - actual.item())**2)\n",
"print(np.mean(loss_))\n",
"\n",
"print('\\n◆ 勾配をリセットしてからさっきの損失の勾配をとってみる')\n",
"net.zero_grad()\n",
"print('◇ conv1.weight をどの向きに動かすと loss は大きくなるか(1つ目のチャネルのみ)→ リセットされている')\n",
"print(net.conv1.weight.grad[0, 0, :, :])\n",
"loss.backward()\n",
"print('◇ conv1.weight をどの向きに動かすと loss は大きくなるか(1つ目のチャネルのみ)')\n",
"print(net.conv1.weight.grad[0, 0, :, :])\n",
"\n",
"print('\\n◆ 損失が減少する方向にパラメータを手動で(手動とは)更新してみる → 勾配の逆向きに更新されている')\n",
"learning_rate = 0.1\n",
"print('◇ conv1.weight の更新前の値(1つ目のチャネルのみ)')\n",
"print(net.conv1.weight[0, 0, :, :])\n",
"for f in net.parameters():\n",
" f.data.sub_(f.grad.data * learning_rate) # パラメータの現在の値から、勾配 × 学習率 をマイナスする\n",
"print('◇ conv1.weight の更新後の値(1つ目のチャネルのみ)')\n",
"print(net.conv1.weight[0, 0, :, :])"
]
},
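{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, a sketch (added here, not in the original gist) of the same manual step written with torch.no_grad() instead of .data, which is the idiom used in more recent PyTorch tutorials; both subtract gradient × learning rate in place."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"learning_rate = 0.1\n",
"with torch.no_grad():  # suspend autograd tracking while mutating the parameters\n",
"    for f in net.parameters():\n",
"        f -= f.grad * learning_rate  # same effect as f.data.sub_(f.grad.data * learning_rate)"
]
},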
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ 実際には手動でネットワークのパラメータを更新するのではなく組み込みのオプティマイザに頼る\n",
"◇ ダミー入力\n",
"tensor([[[[ 1.4960, 0.3085, -0.9062, ..., -0.0718, -0.6352, 0.1182],\n",
" [-0.5006, -0.8628, -0.4996, ..., 0.9583, 0.5484, -0.1718],\n",
" [ 0.6337, 0.0841, 0.2042, ..., 0.1574, -0.7153, -2.1693],\n",
" ...,\n",
" [-0.0395, -1.4018, 1.3060, ..., 2.1325, 0.6054, -0.0117],\n",
" [ 1.1225, 0.9184, -0.5754, ..., 0.9765, -0.7976, 1.9175],\n",
" [-0.0467, -0.1450, 0.6842, ..., -1.3640, -0.3506, 0.1091]]]])\n",
"◇ ダミー正解\n",
"tensor([[ 8.0228e-01, 1.4126e+00, -7.3685e-01, 4.9925e-04, -2.8538e-02,\n",
" 1.4983e+00, 3.8812e-01, -8.6375e-01, -4.6404e-02, 5.9980e-01]])\n",
"◇ 学習のループをまわす\n",
"tensor(0.6753, grad_fn=<MseLossBackward>)\n",
"tensor(0.5702, grad_fn=<MseLossBackward>)\n",
"tensor(0.4575, grad_fn=<MseLossBackward>)\n",
"tensor(0.3053, grad_fn=<MseLossBackward>)\n",
"tensor(0.1177, grad_fn=<MseLossBackward>)\n",
"tensor(0.0240, grad_fn=<MseLossBackward>)\n",
"tensor(0.0056, grad_fn=<MseLossBackward>)\n",
"tensor(0.0015, grad_fn=<MseLossBackward>)\n",
"tensor(0.0006, grad_fn=<MseLossBackward>)\n",
"tensor(0.0003, grad_fn=<MseLossBackward>)\n",
"◇ ダミー予測値\n",
"tensor([[ 0.8078, 1.4422, -0.7402, -0.0166, -0.0344, 1.5241, 0.4023, -0.8833,\n",
" -0.0539, 0.6156]], grad_fn=<AddmmBackward>)\n"
]
}
],
"source": [
"print('◆ 実際には手動でネットワークのパラメータを更新するのではなく組み込みのオプティマイザに頼る')\n",
"\n",
"import torch.optim as optim\n",
"optimizer = optim.SGD(net.parameters(), lr=0.1)\n",
"\n",
"input = torch.randn(1, 1, 32, 32)\n",
"print('◇ ダミー入力')\n",
"print(input)\n",
"target = torch.randn(10)\n",
"target = target.view(1, -1)\n",
"print('◇ ダミー正解')\n",
"print(target)\n",
"\n",
"print('◇ 学習のループをまわす')\n",
"for i in range(10):\n",
" optimizer.zero_grad() # 勾配をリセット\n",
" output = net(input) # 現在の予測値を出力\n",
" loss = criterion(output, target) # 損失を計算\n",
" print(loss)\n",
" loss.backward() # 損失の勾配を計算\n",
" optimizer.step()\n",
" \n",
"print('◇ ダミー予測値')\n",
"print(output)"
]
},
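{
"cell_type": "markdown",
"metadata": {},
"source": [
"A closing check (added as a sketch; it reuses net, input, target, and criterion from the cells above): one step of optim.SGD without momentum should coincide with the manual update p = p - lr * grad from earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.optim as optim\n",
"\n",
"optimizer = optim.SGD(net.parameters(), lr=0.1)  # plain SGD: no momentum, no weight decay\n",
"optimizer.zero_grad()\n",
"criterion(net(input), target).backward()\n",
"expected = net.conv1.weight.detach().clone() - 0.1 * net.conv1.weight.grad\n",
"optimizer.step()\n",
"print(torch.allclose(net.conv1.weight.detach(), expected))  # True: identical updates"
]
},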
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Comment from @CookieBox26 (author):

In the figure "[Reference] Relationship between conv2d input and output dimensions," the ceiling in the formula at the lower right should apply to the whole expression (as in the original formula at the upper left).
