Last active
June 23, 2021 09:10
-
-
Save tttamaki/64ecf3611228f8c1fc941efabd12628e to your computer and use it in GitHub Desktop.
事前学習済みResNet50を切り貼りしてABNを作ってみた
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.8" | |
}, | |
"orig_nbformat": 4, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3.8.8 64-bit ('base': conda)" | |
}, | |
"interpreter": { | |
"hash": "98b0a9b7b4eaaa670588a142fd0a9b87eaafe866f1db4228be72b4211d12040f" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2, | |
"cells": [ | |
{ | |
"source": [ | |
"事前学習済みのResNet50をバラバラにして,ABNを作ってみる." | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"import torchvision\n", | |
"from torchvision import transforms, models, datasets\n", | |
"\n", | |
"import torchinfo\n", | |
"import torchsummary\n", | |
"import copy\n", | |
"from tqdm.notebook import tqdm" | |
], | |
"cell_type": "code", | |
"metadata": {}, | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# mean/std below are the values commonly used with ImageNet-pretrained torchvision models.\n", | |
"transform_train = transforms.Compose([\n", | |
"    transforms.RandomCrop(32, padding=4),\n", | |
"    transforms.RandomHorizontalFlip(),\n", | |
"    transforms.ToTensor(),\n", | |
"    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", | |
"                         std=[0.229, 0.224, 0.225]),\n", | |
"])\n", | |
"transform_val = transforms.Compose([\n", | |
"    transforms.ToTensor(),\n", | |
"    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", | |
"                         std=[0.229, 0.224, 0.225]),\n", | |
"])\n", | |
"\n", | |
"num_classes = 100\n", | |
"batch_size = 2048\n", | |
"num_workers = 24\n", | |
"\n", | |
"# download=False: CIFAR-100 must already be present under ./data.\n", | |
"train_set = torchvision.datasets.CIFAR100(\n", | |
"    root='./data',\n", | |
"    train=True,\n", | |
"    download=False,\n", | |
"    transform=transform_train\n", | |
")\n", | |
"val_set = torchvision.datasets.CIFAR100(\n", | |
"    root='./data',\n", | |
"    train=False,\n", | |
"    download=False,\n", | |
"    transform=transform_val\n", | |
")\n", | |
"\n", | |
"train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,\n", | |
"                                           shuffle=True, num_workers=num_workers)\n", | |
"# Evaluation metrics do not depend on sample order, so the validation loader is\n", | |
"# not shuffled: iteration is deterministic and does not consume the RNG.\n", | |
"val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size,\n", | |
"                                         shuffle=False, num_workers=num_workers)\n" | |
] | |
}, | |
{ | |
"source": [ | |
"まずはoriginalのResNet50.最終層だけfine-tune用に付け替える." | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Baseline: ImageNet-pretrained ResNet-50 whose fc is replaced by a new num_classes-way head.\n", | |
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", | |
"model_org = models.resnet50(pretrained=True)\n", | |
"model_org.fc = nn.Linear(in_features=model_org.fc.in_features, out_features=num_classes)\n", | |
"model_org = model_org.to(device)" | |
] | |
}, | |
{ | |
"source": [ | |
"次は,事前学習済みResNetを,前半と後半に分けて,それらをつなげた再構成モデルを作り,originalと同じ出力が得られるかどうかを確認する." | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class ReconstructResNet50(nn.Module):\n", | |
"    \"\"\"Pretrained ResNet-50 split into two nn.Sequential halves.\n", | |
"\n", | |
"    Bottom half: conv1, bn1, relu, maxpool, layer1, layer2.\n", | |
"    Top half:    layer3, layer4, avgpool, flatten, fc.\n", | |
"    Running the two halves back to back reproduces the plain ResNet-50 forward pass.\n", | |
"\n", | |
"    Args:\n", | |
"        num_classes: output size of a newly created fc head; ignored when `fc` is given.\n", | |
"        fc: optional head to use instead. It is deep-copied, so the module passed in\n", | |
"            is never shared with this model.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, num_classes: int = 1000, fc=None) -> None:\n", | |
"        super().__init__()\n", | |
"        model = models.resnet50(pretrained=True)\n", | |
"        if fc is None:\n", | |
"            model.fc = nn.Linear(model.fc.in_features, num_classes)\n", | |
"        else:\n", | |
"            model.fc = copy.deepcopy(fc)\n", | |
"\n", | |
"        # (\"bottop\" is a typo for \"bottom\"; the name is kept because other cells\n", | |
"        # index into it and it appears in state_dict keys.)\n", | |
"        self.resnet50_bottop_half = nn.Sequential(\n", | |
"            model.conv1,\n", | |
"            model.bn1,\n", | |
"            model.relu,\n", | |
"            model.maxpool,\n", | |
"            model.layer1,\n", | |
"            model.layer2,\n", | |
"        )\n", | |
"\n", | |
"        self.resnet50_top_half = nn.Sequential(\n", | |
"            model.layer3,\n", | |
"            model.layer4,\n", | |
"            model.avgpool,\n", | |
"            nn.Flatten(),  # (N, 2048, 1, 1) -> (N, 2048) before the fc head\n", | |
"            model.fc,\n", | |
"        )\n", | |
"\n", | |
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n", | |
"        \"\"\"Return class logits for the image batch `x`.\"\"\"\n", | |
"        x = self.resnet50_bottop_half(x)\n", | |
"        return self.resnet50_top_half(x)" | |
] | |
}, | |
{ | |
"source": [ | |
"fine-tune用に付けかえた最終層はoriginalのものを借用(deepcopyする).\n", | |
"これをせずに新たにnn.Linearを作ってしまうと,初期化がランダムなため同じ出力が得られず,一致を確認できない." | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Reuse model_org's fc (deep-copied inside the class) so both models start from identical weights.\n", | |
"model_new = ReconstructResNet50(num_classes=num_classes, fc=model_org.fc).to(device)" | |
] | |
}, | |
{ | |
"source": [ | |
"originalのモデルの中身を確認" | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"==========================================================================================\n", | |
"Layer (type (var_name)) Input Shape Output Shape\n", | |
"==========================================================================================\n", | |
"ResNet -- --\n", | |
"├─Conv2d (conv1) [64, 3, 32, 32] [64, 64, 16, 16]\n", | |
"├─BatchNorm2d (bn1) [64, 64, 16, 16] [64, 64, 16, 16]\n", | |
"├─ReLU (relu) [64, 64, 16, 16] [64, 64, 16, 16]\n", | |
"├─MaxPool2d (maxpool) [64, 64, 16, 16] [64, 64, 8, 8]\n", | |
"├─Sequential (layer1) [64, 64, 8, 8] [64, 256, 8, 8]\n", | |
"│ └─Bottleneck (0) [64, 64, 8, 8] [64, 256, 8, 8]\n", | |
"│ └─Bottleneck (1) [64, 256, 8, 8] [64, 256, 8, 8]\n", | |
"│ └─Bottleneck (2) [64, 256, 8, 8] [64, 256, 8, 8]\n", | |
"├─Sequential (layer2) [64, 256, 8, 8] [64, 512, 4, 4]\n", | |
"│ └─Bottleneck (0) [64, 256, 8, 8] [64, 512, 4, 4]\n", | |
"│ └─Bottleneck (1) [64, 512, 4, 4] [64, 512, 4, 4]\n", | |
"│ └─Bottleneck (2) [64, 512, 4, 4] [64, 512, 4, 4]\n", | |
"│ └─Bottleneck (3) [64, 512, 4, 4] [64, 512, 4, 4]\n", | |
"├─Sequential (layer3) [64, 512, 4, 4] [64, 1024, 2, 2]\n", | |
"│ └─Bottleneck (0) [64, 512, 4, 4] [64, 1024, 2, 2]\n", | |
"│ └─Bottleneck (1) [64, 1024, 2, 2] [64, 1024, 2, 2]\n", | |
"│ └─Bottleneck (2) [64, 1024, 2, 2] [64, 1024, 2, 2]\n", | |
"│ └─Bottleneck (3) [64, 1024, 2, 2] [64, 1024, 2, 2]\n", | |
"│ └─Bottleneck (4) [64, 1024, 2, 2] [64, 1024, 2, 2]\n", | |
"│ └─Bottleneck (5) [64, 1024, 2, 2] [64, 1024, 2, 2]\n", | |
"├─Sequential (layer4) [64, 1024, 2, 2] [64, 2048, 1, 1]\n", | |
"│ └─Bottleneck (0) [64, 1024, 2, 2] [64, 2048, 1, 1]\n", | |
"│ └─Bottleneck (1) [64, 2048, 1, 1] [64, 2048, 1, 1]\n", | |
"│ └─Bottleneck (2) [64, 2048, 1, 1] [64, 2048, 1, 1]\n", | |
"├─AdaptiveAvgPool2d (avgpool) [64, 2048, 1, 1] [64, 2048, 1, 1]\n", | |
"├─Linear (fc) [64, 2048] [64, 100]\n", | |
"==========================================================================================\n", | |
"Total params: 23,712,932\n", | |
"Trainable params: 23,712,932\n", | |
"Non-trainable params: 0\n", | |
"Total mult-adds (G): 5.35\n", | |
"==========================================================================================\n", | |
"Input size (MB): 0.79\n", | |
"Forward/backward pass size (MB): 232.31\n", | |
"Params size (MB): 94.85\n", | |
"Estimated Total Size (MB): 327.95\n", | |
"==========================================================================================" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 6 | |
} | |
], | |
"source": [ | |
"# Layer-by-layer shapes of the original model for a 64 x 3 x 32 x 32 input.\n", | |
"summary_columns = [\"input_size\", \"output_size\"]\n", | |
"torchinfo.summary(model_org, (64, 3, 32, 32), depth=2,\n", | |
"                  col_names=summary_columns, row_settings=(\"var_names\",))" | |
] | |
}, | |
{ | |
"source": [ | |
"再構成したモデルの中身.前半・後半をnn.Sequentialで包んだ分だけdepthが一つ深くなっているが,パラメータ数その他は同一であることが分かる." | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"===============================================================================================\n", | |
"Layer (type (var_name)) Input Shape Output Shape\n", | |
"===============================================================================================\n", | |
"ReconstructResNet50 -- --\n", | |
"├─Sequential (resnet50_bottop_half) [64, 3, 32, 32] [64, 512, 4, 4]\n", | |
"│ └─Conv2d (0) [64, 3, 32, 32] [64, 64, 16, 16]\n", | |
"│ └─BatchNorm2d (1) [64, 64, 16, 16] [64, 64, 16, 16]\n", | |
"│ └─ReLU (2) [64, 64, 16, 16] [64, 64, 16, 16]\n", | |
"│ └─MaxPool2d (3) [64, 64, 16, 16] [64, 64, 8, 8]\n", | |
"│ └─Sequential (4) [64, 64, 8, 8] [64, 256, 8, 8]\n", | |
"│ │ └─Bottleneck (0) [64, 64, 8, 8] [64, 256, 8, 8]\n", | |
"│ │ └─Bottleneck (1) [64, 256, 8, 8] [64, 256, 8, 8]\n", | |
"│ │ └─Bottleneck (2) [64, 256, 8, 8] [64, 256, 8, 8]\n", | |
"│ └─Sequential (5) [64, 256, 8, 8] [64, 512, 4, 4]\n", | |
"│ │ └─Bottleneck (0) [64, 256, 8, 8] [64, 512, 4, 4]\n", | |
"│ │ └─Bottleneck (1) [64, 512, 4, 4] [64, 512, 4, 4]\n", | |
"│ │ └─Bottleneck (2) [64, 512, 4, 4] [64, 512, 4, 4]\n", | |
"│ │ └─Bottleneck (3) [64, 512, 4, 4] [64, 512, 4, 4]\n", | |
"├─Sequential (resnet50_top_half) [64, 512, 4, 4] [64, 100]\n", | |
"│ └─Sequential (0) [64, 512, 4, 4] [64, 1024, 2, 2]\n", | |
"│ │ └─Bottleneck (0) [64, 512, 4, 4] [64, 1024, 2, 2]\n", | |
"│ │ └─Bottleneck (1) [64, 1024, 2, 2] [64, 1024, 2, 2]\n", | |
"│ │ └─Bottleneck (2) [64, 1024, 2, 2] [64, 1024, 2, 2]\n", | |
"│ │ └─Bottleneck (3) [64, 1024, 2, 2] [64, 1024, 2, 2]\n", | |
"│ │ └─Bottleneck (4) [64, 1024, 2, 2] [64, 1024, 2, 2]\n", | |
"│ │ └─Bottleneck (5) [64, 1024, 2, 2] [64, 1024, 2, 2]\n", | |
"│ └─Sequential (1) [64, 1024, 2, 2] [64, 2048, 1, 1]\n", | |
"│ │ └─Bottleneck (0) [64, 1024, 2, 2] [64, 2048, 1, 1]\n", | |
"│ │ └─Bottleneck (1) [64, 2048, 1, 1] [64, 2048, 1, 1]\n", | |
"│ │ └─Bottleneck (2) [64, 2048, 1, 1] [64, 2048, 1, 1]\n", | |
"│ └─AdaptiveAvgPool2d (2) [64, 2048, 1, 1] [64, 2048, 1, 1]\n", | |
"│ └─Flatten (3) [64, 2048, 1, 1] [64, 2048]\n", | |
"│ └─Linear (4) [64, 2048] [64, 100]\n", | |
"===============================================================================================\n", | |
"Total params: 23,712,932\n", | |
"Trainable params: 23,712,932\n", | |
"Non-trainable params: 0\n", | |
"Total mult-adds (G): 5.35\n", | |
"===============================================================================================\n", | |
"Input size (MB): 0.79\n", | |
"Forward/backward pass size (MB): 232.31\n", | |
"Params size (MB): 94.85\n", | |
"Estimated Total Size (MB): 327.95\n", | |
"===============================================================================================" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 7 | |
} | |
], | |
"source": [ | |
"# Same summary for the reconstructed model; depth=3 because of the extra nn.Sequential level.\n", | |
"summary_columns = [\"input_size\", \"output_size\"]\n", | |
"torchinfo.summary(model_new, (64, 3, 32, 32), depth=3,\n", | |
"                  col_names=summary_columns, row_settings=(\"var_names\",))" | |
] | |
}, | |
{ | |
"source": [ | |
"ではデータを流し込んでみる" | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Random input batch shaped like a CIFAR mini-batch (64 images, 3 x 32 x 32).\n", | |
"data = torch.randn((64, 3, 32, 32)).to(device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"torch.return_types.max(\nvalues=tensor([1.3454, 1.1356, 1.2690, 1.3834, 1.0672, 1.2274, 1.0207, 1.0614, 1.3978,\n 1.2313, 1.4156, 1.4666, 1.2339, 1.0205, 0.9802, 1.0208, 1.2230, 1.1357,\n 1.1173, 1.1000, 1.2028, 1.2953, 1.0749, 1.1699, 1.2398, 1.2197, 1.3249,\n 1.3639, 1.3710, 1.4766, 1.0805, 1.2381, 1.3166, 1.3227, 1.2800, 1.3676,\n 1.1806, 1.2021, 1.2806, 1.2725, 1.0955, 1.3845, 1.3075, 1.1250, 1.2184,\n 1.4227, 1.0932, 1.1074, 1.1536, 1.1790, 1.3503, 1.3610, 1.2805, 1.2181,\n 1.3869, 1.2473, 1.2844, 1.2438, 1.0619, 1.3060, 1.1839, 1.2812, 1.1054,\n 1.1871], device='cuda:0', grad_fn=<MaxBackward0>),\nindices=tensor([90, 78, 78, 78, 78, 44, 88, 57, 44, 78, 78, 78, 78, 78, 27, 88, 44, 44,\n 27, 44, 90, 78, 44, 90, 27, 44, 44, 78, 78, 44, 27, 78, 78, 78, 44, 78,\n 78, 27, 90, 78, 84, 44, 44, 78, 78, 78, 78, 27, 78, 44, 44, 44, 90, 78,\n 44, 90, 78, 78, 84, 78, 44, 44, 27, 90], device='cuda:0'))\ntensor([[-0.1791, 0.0600, 0.2247, -0.0900, -0.6020],\n [-0.4647, 0.0351, 0.2943, -0.2365, -0.3093],\n [-0.4240, -0.1189, 0.5508, -0.4110, -0.5454],\n [-0.6184, 0.0310, 0.0049, -0.0531, -0.7490],\n [-0.3527, 0.1487, 0.3425, -0.0758, -0.5206],\n [-0.3394, 0.2014, 0.3621, -0.1127, -0.3296],\n [-0.6504, 0.2399, 0.4901, -0.1620, -0.3928],\n [-0.4113, 0.3839, 0.2992, -0.3384, -0.4561],\n [-0.4094, 0.1961, 0.1317, 0.1345, -0.6089],\n [-0.2959, 0.3001, 0.2256, -0.4830, -0.5723]], device='cuda:0',\n grad_fn=<SliceBackward>)\n" | |
] | |
} | |
], | |
"source": [ | |
"model_new.eval()\n", | |
"# A single forward pass without building an autograd graph, reused for both prints.\n", | |
"with torch.no_grad():\n", | |
"    out_new = model_new(data)\n", | |
"print(out_new.max(axis=1))\n", | |
"print(out_new[:10, :5])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"torch.return_types.max(\n", | |
"values=tensor([1.3454, 1.1356, 1.2690, 1.3834, 1.0672, 1.2274, 1.0207, 1.0614, 1.3978,\n", | |
" 1.2313, 1.4156, 1.4666, 1.2339, 1.0205, 0.9802, 1.0208, 1.2230, 1.1357,\n", | |
" 1.1173, 1.1000, 1.2028, 1.2953, 1.0749, 1.1699, 1.2398, 1.2197, 1.3249,\n", | |
" 1.3639, 1.3710, 1.4766, 1.0805, 1.2381, 1.3166, 1.3227, 1.2800, 1.3676,\n", | |
" 1.1806, 1.2021, 1.2806, 1.2725, 1.0955, 1.3845, 1.3075, 1.1250, 1.2184,\n", | |
" 1.4227, 1.0932, 1.1074, 1.1536, 1.1790, 1.3503, 1.3610, 1.2805, 1.2181,\n", | |
" 1.3869, 1.2473, 1.2844, 1.2438, 1.0619, 1.3060, 1.1839, 1.2812, 1.1054,\n", | |
" 1.1871], device='cuda:0', grad_fn=<MaxBackward0>),\n", | |
"indices=tensor([90, 78, 78, 78, 78, 44, 88, 57, 44, 78, 78, 78, 78, 78, 27, 88, 44, 44,\n", | |
" 27, 44, 90, 78, 44, 90, 27, 44, 44, 78, 78, 44, 27, 78, 78, 78, 44, 78,\n", | |
" 78, 27, 90, 78, 84, 44, 44, 78, 78, 78, 78, 27, 78, 44, 44, 44, 90, 78,\n", | |
" 44, 90, 78, 78, 84, 78, 44, 44, 27, 90], device='cuda:0'))\n", | |
"tensor([[-0.1791, 0.0600, 0.2247, -0.0900, -0.6020],\n", | |
" [-0.4647, 0.0351, 0.2943, -0.2365, -0.3093],\n", | |
" [-0.4240, -0.1189, 0.5508, -0.4110, -0.5454],\n", | |
" [-0.6184, 0.0310, 0.0049, -0.0531, -0.7490],\n", | |
" [-0.3527, 0.1487, 0.3425, -0.0758, -0.5206],\n", | |
" [-0.3394, 0.2014, 0.3621, -0.1127, -0.3296],\n", | |
" [-0.6504, 0.2399, 0.4901, -0.1620, -0.3928],\n", | |
" [-0.4113, 0.3839, 0.2992, -0.3384, -0.4561],\n", | |
" [-0.4094, 0.1961, 0.1317, 0.1345, -0.6089],\n", | |
" [-0.2959, 0.3001, 0.2256, -0.4830, -0.5723]], device='cuda:0',\n", | |
" grad_fn=<SliceBackward>)\n" | |
] | |
} | |
], | |
"source": [ | |
"model_org.eval()\n", | |
"model_new.eval()\n", | |
"with torch.no_grad():\n", | |
"    out_org = model_org(data)\n", | |
"    out_new = model_new(data)\n", | |
"# Compare the full output tensors, not only the slice printed below.\n", | |
"assert torch.allclose(out_org, out_new, atol=1e-6), \"reconstructed model diverges from the original\"\n", | |
"print(out_org.max(axis=1))\n", | |
"print(out_org[:10, :5])" | |
] | |
}, | |
{ | |
"source": [ | |
"どちらも同じ出力が得られたことが分かる.\n", | |
"\n", | |
"ではoptimizerを設定して,勾配も一致するかどうか確認する." | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"criterion = nn.CrossEntropyLoss()\n", | |
"# Identical SGD settings for both models so their behaviour can be compared.\n", | |
"sgd_kwargs = dict(lr=0.1, momentum=0.9, weight_decay=5e-4)\n", | |
"optimizer_org = torch.optim.SGD(model_org.parameters(), **sgd_kwargs)\n", | |
"optimizer_new = torch.optim.SGD(model_new.parameters(), **sgd_kwargs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Random class labels in [0, num_classes), one per sample of `data`.\n", | |
"target = torch.randint(low=0, high=num_classes, size=(64,)).to(device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor(4.7663, device='cuda:0', grad_fn=<NllLossBackward>)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 13 | |
} | |
], | |
"source": [ | |
"# Forward + backward only (no optimizer.step()), so the parameters stay unchanged\n", | |
"# and the gradients can be compared with the original model's.\n", | |
"optimizer_new.zero_grad()\n", | |
"model_new.train()\n", | |
"output_new = model_new(data)\n", | |
"loss_new = criterion(output_new, target)\n", | |
"loss_new.backward()\n", | |
"loss_new" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor(4.7663, device='cuda:0', grad_fn=<NllLossBackward>)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 14 | |
} | |
], | |
"source": [ | |
"# Same forward + backward (no optimizer.step()) on the original model.\n", | |
"optimizer_org.zero_grad()\n", | |
"model_org.train()\n", | |
"output_org = model_org(data)\n", | |
"loss_org = criterion(output_org, target)\n", | |
"loss_org.backward()\n", | |
"loss_org" | |
] | |
}, | |
{ | |
"source": [ | |
"lossは一致した.\n", | |
"では重みとその勾配は一致するかどうかを確認する.対象は最初のconvの5x5に限定." | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.0133, 0.0147, -0.0154, -0.0230, -0.0409],\n", | |
" [ 0.0041, 0.0058, 0.0149, 0.0206, 0.0022],\n", | |
" [ 0.0223, 0.0236, 0.0161, 0.0588, 0.1028],\n", | |
" [ 0.0232, 0.0042, -0.0459, -0.0487, -0.0164],\n", | |
" [-0.0009, 0.0278, -0.0101, -0.0554, -0.1272]], device='cuda:0',\n", | |
" grad_fn=<SliceBackward>)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 15 | |
} | |
], | |
"source": [ | |
"# Index 0 of the bottom half is ResNet-50's conv1.\n", | |
"conv1_new = model_new.resnet50_bottop_half[0]\n", | |
"conv1_new.weight[0, 0, :5, :5]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.0133, 0.0147, -0.0154, -0.0230, -0.0409],\n", | |
" [ 0.0041, 0.0058, 0.0149, 0.0206, 0.0022],\n", | |
" [ 0.0223, 0.0236, 0.0161, 0.0588, 0.1028],\n", | |
" [ 0.0232, 0.0042, -0.0459, -0.0487, -0.0164],\n", | |
" [-0.0009, 0.0278, -0.0101, -0.0554, -0.1272]], device='cuda:0',\n", | |
" grad_fn=<SliceBackward>)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 16 | |
} | |
], | |
"source": [ | |
"# Same 5x5 patch of conv1 in the original model.\n", | |
"conv1_org = model_org.conv1\n", | |
"conv1_org.weight[0, 0, :5, :5]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.0495, -0.0230, 0.0685, 0.0177, -0.0276],\n", | |
" [ 0.0195, -0.0643, 0.0361, 0.0360, -0.0751],\n", | |
" [ 0.0223, 0.0310, 0.0489, 0.0656, -0.1489],\n", | |
" [-0.0438, 0.1419, 0.0153, -0.1075, 0.1269],\n", | |
" [-0.1402, -0.0279, -0.1216, 0.1160, 0.0147]], device='cuda:0')" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 17 | |
} | |
], | |
"source": [ | |
"# Gradient of that conv1 patch in the reconstructed model.\n", | |
"conv1_new = model_new.resnet50_bottop_half[0]\n", | |
"conv1_new.weight.grad[0, 0, :5, :5]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.0495, -0.0230, 0.0685, 0.0177, -0.0276],\n", | |
" [ 0.0195, -0.0643, 0.0361, 0.0360, -0.0751],\n", | |
" [ 0.0223, 0.0310, 0.0489, 0.0656, -0.1489],\n", | |
" [-0.0438, 0.1419, 0.0153, -0.1075, 0.1269],\n", | |
" [-0.1402, -0.0279, -0.1216, 0.1160, 0.0147]], device='cuda:0')" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 18 | |
} | |
], | |
"source": [ | |
"# Gradient of the same conv1 patch in the original model.\n", | |
"conv1_org = model_org.conv1\n", | |
"conv1_org.weight.grad[0, 0, :5, :5]" | |
] | |
}, | |
{ | |
"source": [ | |
"これで一致したことが確認できた.\n", | |
"\n", | |
"では事前学習済みResNetを色々と切り貼りして,ABNを作ってみる." | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"\n", | |
"class ABNResNet50(nn.Module):\n", | |
"    \"\"\"Attention Branch Network (ABN) assembled from a pretrained ResNet-50.\n", | |
"\n", | |
"    * Feature extractor (`resnet50_bottop`): conv1 ... layer2 of the pretrained model.\n", | |
"    * Attention branch: `attention_branch1` (copy of layer2[3] + BN + 1x1 conv + ReLU),\n", | |
"      then `attention_branch2` (1x1 conv + BN + sigmoid) giving a 512-channel\n", | |
"      attention map, and `attention_branch3` (1x1 conv to `num_classes` channels)\n", | |
"      followed by global average pooling for the attention-branch logits.\n", | |
"    * Perception branch (`resnet50_top`): layer3, layer4, avgpool and a new fc head,\n", | |
"      applied to the attention-weighted features.\n", | |
"\n", | |
"    forward() returns (perception logits, attention-branch logits, attention map).\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, num_classes: int = 1000) -> None:\n", | |
"        super().__init__()\n", | |
"        model = models.resnet50(pretrained=True)\n", | |
"\n", | |
"        # (\"bottop\" is a typo for \"bottom\"; kept so that state_dict keys stay stable.)\n", | |
"        self.resnet50_bottop = nn.Sequential(\n", | |
"            model.conv1,\n", | |
"            model.bn1,\n", | |
"            model.relu,\n", | |
"            model.maxpool,\n", | |
"            model.layer1,\n", | |
"            model.layer2,\n", | |
"        )\n", | |
"\n", | |
"        self.resnet50_top = nn.Sequential(\n", | |
"            model.layer3,\n", | |
"            model.layer4,\n", | |
"            model.avgpool,\n", | |
"            nn.Flatten(),\n", | |
"            nn.Linear(model.fc.in_features, num_classes),\n", | |
"        )\n", | |
"\n", | |
"        self.attention_branch1 = nn.Sequential(\n", | |
"            # deepcopy: otherwise this block would share its weights with the\n", | |
"            # layer2[3] already used inside self.resnet50_bottop.\n", | |
"            copy.deepcopy(model.layer2[3]),\n", | |
"            nn.BatchNorm2d(512),\n", | |
"            nn.Conv2d(512, 512, kernel_size=1),\n", | |
"            nn.ReLU(inplace=True),\n", | |
"        )\n", | |
"        self.attention_branch2 = nn.Sequential(\n", | |
"            nn.Conv2d(512, 512, kernel_size=1),\n", | |
"            nn.BatchNorm2d(512),\n", | |
"            nn.Sigmoid(),\n", | |
"        )\n", | |
"        self.attention_branch3 = nn.Sequential(\n", | |
"            nn.Conv2d(512, num_classes, kernel_size=1),\n", | |
"        )\n", | |
"\n", | |
"        # Attention map of the most recent forward(); None until the first call.\n", | |
"        self.attn = None\n", | |
"\n", | |
"    def get_attn(self):\n", | |
"        \"\"\"Return the attention map stored by the last forward() (None before any call).\n", | |
"\n", | |
"        NOTE(review): under nn.DataParallel, forward() runs on per-device replicas, so\n", | |
"        the wrapped module's `attn` is presumably not updated; prefer the third value\n", | |
"        returned by forward().\n", | |
"        \"\"\"\n", | |
"        return self.attn\n", | |
"\n", | |
"    def forward(self, x: torch.Tensor):\n", | |
"        \"\"\"Return (perception logits, attention-branch logits, attention map) for `x`.\"\"\"\n", | |
"        x = self.resnet50_bottop(x)\n", | |
"\n", | |
"        ax = self.attention_branch1(x)\n", | |
"        attn = self.attention_branch2(ax)\n", | |
"\n", | |
"        # Perception branch on the attention-weighted features.\n", | |
"        # NOTE(review): the reference ABN implementation applies attention residually\n", | |
"        # (x * attn + x); this model uses x * attn only -- confirm this is intended.\n", | |
"        x = self.resnet50_top(x * attn)\n", | |
"\n", | |
"        # Attention-branch logits via global average pooling. adaptive_avg_pool2d(., 1)\n", | |
"        # also handles non-square maps, and flatten(1) keeps the batch dimension even\n", | |
"        # for a batch (or DataParallel chunk) of size 1, which .squeeze() would drop.\n", | |
"        ax = self.attention_branch3(ax)\n", | |
"        ax = torch.flatten(F.adaptive_avg_pool2d(ax, 1), 1)\n", | |
"\n", | |
"        self.attn = attn\n", | |
"        return x, ax, attn\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Build the ABN model on `device` and wrap it so batches are split across the visible GPUs.\n", | |
"model_new3 = nn.DataParallel(ABNResNet50(num_classes=num_classes).to(device))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# SGD with the same hyper-parameters as the equivalence check above.\n", | |
"# Alternative: torch.optim.Adam(model_new3.parameters(), lr=0.1, weight_decay=5e-4)\n", | |
"optimizer_new3 = torch.optim.SGD(\n", | |
"    model_new3.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4\n", | |
")" | |
] | |
}, | |
{ | |
"source": [ | |
"便利関数を作っておく." | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class AverageMeter(object):\n", | |
"    \"\"\"\n", | |
"    Computes and stores the average and current value\n", | |
"    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262\n", | |
"    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59\n", | |
"\n", | |
"    update(val, n) treats `val` as the mean over `n` samples, so `avg` is the\n", | |
"    sample-weighted mean of all values seen since the last reset().\n", | |
"    \"\"\"\n", | |
"    def __init__(self):\n", | |
"        self.reset()\n", | |
"\n", | |
"    def reset(self):\n", | |
"        \"\"\"Clear all statistics.\"\"\"\n", | |
"        self.val = 0\n", | |
"        self.avg = 0\n", | |
"        self.sum = 0\n", | |
"        self.count = 0\n", | |
"\n", | |
"    def update(self, val, n=1):\n", | |
"        \"\"\"Record `val` (a number or a one-element tensor) as the mean over `n` samples.\"\"\"\n", | |
"        # isinstance also accepts Tensor subclasses such as nn.Parameter; .item()\n", | |
"        # turns the value into a Python number so the meter never keeps a graph alive.\n", | |
"        if isinstance(val, torch.Tensor):\n", | |
"            val = val.item()\n", | |
"        self.val = val\n", | |
"        self.sum += val * n\n", | |
"        self.count += n\n", | |
"        if self.count:  # avoid ZeroDivisionError when only n=0 updates were made\n", | |
"            self.avg = self.sum / self.count\n", | |
"\n", | |
"\n", | |
"def top1(outputs, targets):\n", | |
"    \"\"\"Fraction of rows of `outputs` (logits, e.g. shape (N, C)) whose argmax equals `targets`.\n", | |
"\n", | |
"    Returns 0.0 for an empty batch instead of dividing by zero.\n", | |
"    \"\"\"\n", | |
"    batch_size = outputs.size(0)\n", | |
"    if batch_size == 0:\n", | |
"        return 0.0\n", | |
"    _, predicted = outputs.max(1)\n", | |
"    return predicted.eq(targets).sum().item() / batch_size" | |
] | |
}, | |
{ | |
"source": [ | |
"CIFAR100で学習" | |
], | |
"cell_type": "markdown", | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "aecc7007ad984f12aedf6ad3cd5ed2b6" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "9913ee6cf57449f0a4dd65b09051205c" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "7cf38e5bf4084979ba124aae7f57e0d9" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "17763f0424d84d0ba426206c66fb8d28" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "7f64b12b947349858bb4f3337d6ebccb" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "f5ea7701afc043d697238db42345879e" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "9689ab0caf904f49afab76134c39aab9" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "90d2e07b71e241c89318c00505929693" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "b61432c10ec748babaa6405f35f8d407" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "a2592030922d4c68bac0d91ef13b1dab" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "22759d1dcea1473b8c2e1e01df558ba3" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n\n" | |
] | |
} | |
], | |
"source": [ | |
"num_epochs = 5\n", | |
"\n", | |
"\n", | |
"def _run_one_epoch(model, loader, desc, optimizer=None):\n", | |
"    \"\"\"Run one pass of the ABN `model` over `loader`; return (loss_meter, acc_meter).\n", | |
"\n", | |
"    With `optimizer` the pass trains (zero_grad / backward / step); without it the\n", | |
"    pass only evaluates, with autograd disabled. The loss is the mean of the\n", | |
"    perception-branch and attention-branch cross-entropy losses, and top-1 accuracy\n", | |
"    is measured on the perception-branch logits. Both meters are weighted by the\n", | |
"    size of each batch, including a smaller final batch.\n", | |
"    \"\"\"\n", | |
"    is_train = optimizer is not None\n", | |
"    model.train(is_train)\n", | |
"    loss_meter = AverageMeter()\n", | |
"    acc_meter = AverageMeter()\n", | |
"\n", | |
"    with torch.set_grad_enabled(is_train):\n", | |
"        with tqdm(loader, desc=desc, leave=True) as pbar:\n", | |
"            for inputs, targets in pbar:\n", | |
"                inputs, targets = inputs.to(device), targets.to(device)\n", | |
"                bs = inputs.size(0)  # size of *this* batch; the last one is usually smaller\n", | |
"\n", | |
"                if is_train:\n", | |
"                    optimizer.zero_grad()\n", | |
"                outputs_x, outputs_ax, _ = model(inputs)\n", | |
"                loss = (criterion(outputs_x, targets) + criterion(outputs_ax, targets)) / 2\n", | |
"                if is_train:\n", | |
"                    loss.backward()\n", | |
"                    optimizer.step()\n", | |
"\n", | |
"                loss_meter.update(loss, bs)\n", | |
"                acc_meter.update(top1(outputs_x, targets), bs)\n", | |
"                pbar.set_postfix_str(\n", | |
"                    ' | loss={:6.04f} , top1={:6.04f}'.format(loss_meter.avg, acc_meter.avg))\n", | |
"    return loss_meter, acc_meter\n", | |
"\n", | |
"\n", | |
"with tqdm(range(num_epochs)) as pbar_epoch:\n", | |
"    for epoch in pbar_epoch:\n", | |
"        pbar_epoch.set_description(\"[Epoch %d]\" % (epoch))\n", | |
"        train_loss_new3, train_acc_new3 = _run_one_epoch(\n", | |
"            model_new3, train_loader, \"[train]\", optimizer=optimizer_new3)\n", | |
"        val_loss_new3, val_acc_new3 = _run_one_epoch(model_new3, val_loader, \"[val]\")" | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment