Created
August 17, 2022 14:59
-
-
Save CookieBox26/195e58558af6c40899db0a6e98f0e105 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "43e6e6ce-268b-48cf-ac48-0ebd6f227706", | |
"metadata": {}, | |
"source": [ | |
"## AlexNet 編\n", | |
"\n", | |
"#### 参考文献\n", | |
"[1] [Probabilistic Machine Learning: An Introduction](https://probml.github.io/pml-book/book1.html) (テキスト)\n", | |
"\n", | |
"#### AlexNet とは [1]\n", | |
"2012 年の ImageNet challenge において誤り率の記録をそれまでのものから塗り替えた畳み込みニューラルネットモデルである。 \n", | |
"訓練には当時で2枚のGPUを要した。\n", | |
"\n", | |
"#### ImageNet の利用について\n", | |
"ImageNet を以下のコードだけかいてロードしようとするとエラーが出る。エラーに「このファイルがない」とあるように [ImageNet 公式](https://image-net.org/index.php) から必要なファイルをダウンロードする必要がある。ダウンロードの許可を得るにはフリーでないメールアドレスを用いてサインアップする必要があるが、何もなければすぐ許可が得られるはずである。しかし、データが大きいため私の環境ではダウンロードに膨大な時間がかかる。\n", | |
"```python\n", | |
"import torchvision\n", | |
"trainset = torchvision.datasets.ImageNet(root=root, train=True, download=True)\n", | |
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n", | |
"testset = torchvision.datasets.ImageNet(root=root, train=False, download=True)\n", | |
"testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "cd800190-8ed0-4cbe-9555-8f2f1027e175", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"◆ テキスト474ページの図14.16を実現するのに必要な層を特定する\n", | |
"※ バッチサイズ32, 3チャネル, 224x224ピクセルのダミーデータを流す。\n", | |
"入力直後\n", | |
"torch.Size([32, 3, 224, 224])\n", | |
"1番目の畳み込み後\n", | |
"torch.Size([32, 96, 54, 54])\n", | |
"1番目のプール後\n", | |
"torch.Size([32, 96, 26, 26])\n", | |
"2番目の畳み込み後\n", | |
"torch.Size([32, 256, 26, 26])\n", | |
"2番目のプール後\n", | |
"torch.Size([32, 256, 12, 12])\n", | |
"3番目の畳み込み後\n", | |
"torch.Size([32, 384, 12, 12])\n", | |
"4番目の畳み込み後\n", | |
"torch.Size([32, 384, 12, 12])\n", | |
"5番目の畳み込み後\n", | |
"torch.Size([32, 384, 12, 12])\n", | |
"3番目のプール後\n", | |
"torch.Size([32, 384, 5, 5])\n", | |
"リシェイプした後\n", | |
"torch.Size([32, 9600])\n", | |
"1番目の全結合した後\n", | |
"torch.Size([32, 4096])\n", | |
"2番目の全結合した後\n", | |
"torch.Size([32, 4096])\n", | |
"3番目の全結合した後\n", | |
"torch.Size([32, 1000])\n" | |
] | |
} | |
], | |
"source": [ | |
"import warnings\n", | |
"warnings.simplefilter('ignore')\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"print('◆ テキスト474ページの図14.16を実現するのに必要な層を特定する')\n", | |
"batch_size = 32\n", | |
"channel = 3\n", | |
"height = 224\n", | |
"width = 224\n", | |
"print(f'※ バッチサイズ{batch_size}, {channel}チャネル, {height}x{width}ピクセルのダミーデータを流す。')\n", | |
"x = torch.randn(batch_size, channel, height, width)\n", | |
"print(f'入力直後\\n{x.size()}')\n", | |
"x = F.relu(nn.Conv2d(3, 96, 11, stride=4)(x))\n", | |
"print(f'1番目の畳み込み後\\n{x.size()}')\n", | |
"x = nn.MaxPool2d((3, 3), stride=2)(x)\n", | |
"print(f'1番目のプール後\\n{x.size()}')\n", | |
"x = F.relu(nn.Conv2d(96, 256, 5, padding=2)(x))\n", | |
"print(f'2番目の畳み込み後\\n{x.size()}')\n", | |
"x = nn.MaxPool2d((3, 3), stride=2)(x)\n", | |
"print(f'2番目のプール後\\n{x.size()}')\n", | |
"x = F.relu(nn.Conv2d(256, 384, 3, padding=1)(x))\n", | |
"print(f'3番目の畳み込み後\\n{x.size()}')\n", | |
"x = F.relu(nn.Conv2d(384, 384, 3, padding=1)(x))\n", | |
"print(f'4番目の畳み込み後\\n{x.size()}')\n", | |
"x = F.relu(nn.Conv2d(384, 384, 3, padding=1)(x))\n", | |
"print(f'5番目の畳み込み後\\n{x.size()}')\n", | |
"x = nn.MaxPool2d((3, 3), stride=2)(x)\n", | |
"print(f'3番目のプール後\\n{x.size()}')\n", | |
"size = x.size()[1:]\n", | |
"num_features = 1\n", | |
"for s in size:\n", | |
" num_features *= s\n", | |
"x = x.view(-1, num_features)\n", | |
"print(f'リシェイプした後\\n{x.size()}')\n", | |
"x = F.relu(nn.Linear(num_features, 4096)(x))\n", | |
"print(f'1番目の全結合した後\\n{x.size()}')\n", | |
"x = F.relu(nn.Linear(4096, 4096)(x))\n", | |
"print(f'2番目の全結合した後\\n{x.size()}')\n", | |
"x = nn.Linear(4096, 1000)(x)\n", | |
"print(f'3番目の全結合した後\\n{x.size()}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "ff33ac99-9790-4910-aaa5-7d2ead76d020", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"◆ 今回学習するネットワーク\n", | |
"AlexNet(\n", | |
" (max_pool): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
" (conv1): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))\n", | |
" (conv2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", | |
" (conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
" (conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
" (conv5): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
" (fc1): Linear(in_features=9600, out_features=4096, bias=True)\n", | |
" (fc2): Linear(in_features=4096, out_features=4096, bias=True)\n", | |
" (fc3): Linear(in_features=4096, out_features=1000, bias=True)\n", | |
")\n", | |
"\n", | |
"◆ 今回学習するパラメータたち\n", | |
"conv1.weight torch.Size([96, 3, 11, 11])\n", | |
"conv1.bias torch.Size([96])\n", | |
"conv2.weight torch.Size([256, 96, 5, 5])\n", | |
"conv2.bias torch.Size([256])\n", | |
"conv3.weight torch.Size([384, 256, 3, 3])\n", | |
"conv3.bias torch.Size([384])\n", | |
"conv4.weight torch.Size([384, 384, 3, 3])\n", | |
"conv4.bias torch.Size([384])\n", | |
"conv5.weight torch.Size([384, 384, 3, 3])\n", | |
"conv5.bias torch.Size([384])\n", | |
"fc1.weight torch.Size([4096, 9600])\n", | |
"fc1.bias torch.Size([4096])\n", | |
"fc2.weight torch.Size([4096, 4096])\n", | |
"fc2.bias torch.Size([4096])\n", | |
"fc3.weight torch.Size([1000, 4096])\n", | |
"fc3.bias torch.Size([1000])\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"class AlexNet(nn.Module):\n", | |
"\n", | |
" def __init__(self):\n", | |
" super(AlexNet, self).__init__()\n", | |
" self.max_pool = nn.MaxPool2d((3, 3), stride=2)\n", | |
" self.conv1 = nn.Conv2d(3, 96, 11, stride=4)\n", | |
" self.conv2 = nn.Conv2d(96, 256, 5, padding=2)\n", | |
" self.conv3 = nn.Conv2d(256, 384, 3, padding=1)\n", | |
" self.conv4 = nn.Conv2d(384, 384, 3, padding=1)\n", | |
" self.conv5 = nn.Conv2d(384, 384, 3, padding=1)\n", | |
" self.fc1 = nn.Linear(9600, 4096) # TODO: ここが 224x224 ピクセル決め打ちになっている\n", | |
" self.fc2 = nn.Linear(4096, 4096)\n", | |
" self.fc3 = nn.Linear(4096, 1000)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" # TODO: ドロップアウト\n", | |
" x = self.max_pool(F.relu(self.conv1(x)))\n", | |
" x = self.max_pool(F.relu(self.conv2(x)))\n", | |
" x = F.relu(self.conv3(x))\n", | |
" x = F.relu(self.conv4(x))\n", | |
" x = F.relu(self.conv5(x))\n", | |
" x = x.view(-1, 9600)\n", | |
" x = F.relu(self.fc1(x))\n", | |
" x = F.relu(self.fc2(x))\n", | |
" x = self.fc3(x)\n", | |
" return x\n", | |
"\n", | |
"net = AlexNet()\n", | |
"\n", | |
"print('◆ 今回学習するネットワーク')\n", | |
"print(net)\n", | |
"print('\\n◆ 今回学習するパラメータたち')\n", | |
"for name, param in net.named_parameters():\n", | |
" print(name.ljust(14), param.size())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "75e79400-39d3-4ac6-8b55-4bb514fbc202", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment