Skip to content

Instantly share code, notes, and snippets.

@tttamaki
Last active June 23, 2021 09:10
Show Gist options
  • Save tttamaki/64ecf3611228f8c1fc941efabd12628e to your computer and use it in GitHub Desktop.
Save tttamaki/64ecf3611228f8c1fc941efabd12628e to your computer and use it in GitHub Desktop.
事前学習済みResNet50を切り貼りしてABNを作ってみた
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"orig_nbformat": 4,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.8.8 64-bit ('base': conda)"
},
"interpreter": {
"hash": "98b0a9b7b4eaaa670588a142fd0a9b87eaafe866f1db4228be72b4211d12040f"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"事前学習済みのResNet50をバラバラにして,ABNを作ってみる."
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import torchvision\n",
"from torchvision import transforms, models, datasets\n",
"\n",
"import torchinfo\n",
"import torchsummary\n",
"import copy\n",
"from tqdm.notebook import tqdm"
],
"cell_type": "code",
"metadata": {},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"transform_train = transforms.Compose([\n",
" transforms.RandomCrop(32, padding=4),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406], \n",
" std=[0.229, 0.224, 0.225]),\n",
" ])\n",
"transform_val = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406], \n",
" std=[0.229, 0.224, 0.225]),\n",
" ])\n",
"\n",
"num_classes = 100\n",
"batch_size = 2048\n",
"num_workers = 24\n",
"\n",
"train_set = torchvision.datasets.CIFAR100(\n",
" root='./data',\n",
" train=True,\n",
" download=False,\n",
" transform=transform_train\n",
")\n",
"val_set = torchvision.datasets.CIFAR100(\n",
" root='./data',\n",
" train=False,\n",
" download=False,\n",
" transform=transform_val\n",
")\n",
"\n",
"train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,\n",
" shuffle=True, num_workers=num_workers)\n",
"val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size,\n",
" shuffle=True, num_workers=num_workers)\n"
]
},
{
"source": [
"まずはoriginalのResNet50.最終層だけfine-tune用に付け替える."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"model_org = models.resnet50(pretrained=True)\n",
"model_org.fc = nn.Linear(model_org.fc.in_features, num_classes)\n",
"model_org = model_org.to(device)"
]
},
{
"source": [
"次は,事前学習済みResNetを,前半と後半に分けて,それらをつなげた再構成モデルを作り,originalと同じ出力が得られるかどうかを確認する."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class ReconstructResNet50(nn.Module):\n",
" def __init__(self, num_classes: int = 1000, fc = None) -> None:\n",
" super().__init__()\n",
" model = models.resnet50(pretrained=True)\n",
" if fc == None:\n",
" model.fc = nn.Linear(model.fc.in_features, num_classes)\n",
" else:\n",
" model.fc = copy.deepcopy(fc)\n",
"\n",
" self.resnet50_bottop_half = nn.Sequential(\n",
" model.conv1,\n",
" model.bn1,\n",
" model.relu,\n",
" model.maxpool,\n",
" model.layer1,\n",
" model.layer2\n",
" )\n",
"\n",
" self.resnet50_top_half = nn.Sequential(\n",
" model.layer3,\n",
" model.layer4,\n",
" model.avgpool,\n",
" nn.Flatten(),\n",
" model.fc\n",
" )\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" x = self.resnet50_bottop_half(x)\n",
" x = self.resnet50_top_half(x)\n",
" return x"
]
},
{
"source": [
"fine-tune用に付けかえた最終層はoriginalのものを借用(deepcopyする).\n",
"これをせずに新たにnn.Linearを作ってしまうと,初期化がランダムのために,同じ出力が得られず確認できない..."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"model_new = ReconstructResNet50(\n",
" num_classes=num_classes,\n",
" fc=model_org.fc\n",
").to(device)"
]
},
{
"source": [
"originalのモデルの中身を確認"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type (var_name)) Input Shape Output Shape\n",
"==========================================================================================\n",
"ResNet -- --\n",
"├─Conv2d (conv1) [64, 3, 32, 32] [64, 64, 16, 16]\n",
"├─BatchNorm2d (bn1) [64, 64, 16, 16] [64, 64, 16, 16]\n",
"├─ReLU (relu) [64, 64, 16, 16] [64, 64, 16, 16]\n",
"├─MaxPool2d (maxpool) [64, 64, 16, 16] [64, 64, 8, 8]\n",
"├─Sequential (layer1) [64, 64, 8, 8] [64, 256, 8, 8]\n",
"│ └─Bottleneck (0) [64, 64, 8, 8] [64, 256, 8, 8]\n",
"│ └─Bottleneck (1) [64, 256, 8, 8] [64, 256, 8, 8]\n",
"│ └─Bottleneck (2) [64, 256, 8, 8] [64, 256, 8, 8]\n",
"├─Sequential (layer2) [64, 256, 8, 8] [64, 512, 4, 4]\n",
"│ └─Bottleneck (0) [64, 256, 8, 8] [64, 512, 4, 4]\n",
"│ └─Bottleneck (1) [64, 512, 4, 4] [64, 512, 4, 4]\n",
"│ └─Bottleneck (2) [64, 512, 4, 4] [64, 512, 4, 4]\n",
"│ └─Bottleneck (3) [64, 512, 4, 4] [64, 512, 4, 4]\n",
"├─Sequential (layer3) [64, 512, 4, 4] [64, 1024, 2, 2]\n",
"│ └─Bottleneck (0) [64, 512, 4, 4] [64, 1024, 2, 2]\n",
"│ └─Bottleneck (1) [64, 1024, 2, 2] [64, 1024, 2, 2]\n",
"│ └─Bottleneck (2) [64, 1024, 2, 2] [64, 1024, 2, 2]\n",
"│ └─Bottleneck (3) [64, 1024, 2, 2] [64, 1024, 2, 2]\n",
"│ └─Bottleneck (4) [64, 1024, 2, 2] [64, 1024, 2, 2]\n",
"│ └─Bottleneck (5) [64, 1024, 2, 2] [64, 1024, 2, 2]\n",
"├─Sequential (layer4) [64, 1024, 2, 2] [64, 2048, 1, 1]\n",
"│ └─Bottleneck (0) [64, 1024, 2, 2] [64, 2048, 1, 1]\n",
"│ └─Bottleneck (1) [64, 2048, 1, 1] [64, 2048, 1, 1]\n",
"│ └─Bottleneck (2) [64, 2048, 1, 1] [64, 2048, 1, 1]\n",
"├─AdaptiveAvgPool2d (avgpool) [64, 2048, 1, 1] [64, 2048, 1, 1]\n",
"├─Linear (fc) [64, 2048] [64, 100]\n",
"==========================================================================================\n",
"Total params: 23,712,932\n",
"Trainable params: 23,712,932\n",
"Non-trainable params: 0\n",
"Total mult-adds (G): 5.35\n",
"==========================================================================================\n",
"Input size (MB): 0.79\n",
"Forward/backward pass size (MB): 232.31\n",
"Params size (MB): 94.85\n",
"Estimated Total Size (MB): 327.95\n",
"=========================================================================================="
]
},
"metadata": {},
"execution_count": 6
}
],
"source": [
"torchinfo.summary(\n",
" model_org,\n",
" (64, 3, 32, 32),\n",
" depth=2,\n",
" col_names=[\"input_size\",\n",
" \"output_size\"],\n",
" row_settings=(\"var_names\",)\n",
" )"
]
},
{
"source": [
"再構成したモデルの中身.depthが一つ深くなっているが,パラメータ数その他は同一であることが分かる."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"===============================================================================================\n",
"Layer (type (var_name)) Input Shape Output Shape\n",
"===============================================================================================\n",
"ReconstructResNet50 -- --\n",
"├─Sequential (resnet50_bottop_half) [64, 3, 32, 32] [64, 512, 4, 4]\n",
"│ └─Conv2d (0) [64, 3, 32, 32] [64, 64, 16, 16]\n",
"│ └─BatchNorm2d (1) [64, 64, 16, 16] [64, 64, 16, 16]\n",
"│ └─ReLU (2) [64, 64, 16, 16] [64, 64, 16, 16]\n",
"│ └─MaxPool2d (3) [64, 64, 16, 16] [64, 64, 8, 8]\n",
"│ └─Sequential (4) [64, 64, 8, 8] [64, 256, 8, 8]\n",
"│ │ └─Bottleneck (0) [64, 64, 8, 8] [64, 256, 8, 8]\n",
"│ │ └─Bottleneck (1) [64, 256, 8, 8] [64, 256, 8, 8]\n",
"│ │ └─Bottleneck (2) [64, 256, 8, 8] [64, 256, 8, 8]\n",
"│ └─Sequential (5) [64, 256, 8, 8] [64, 512, 4, 4]\n",
"│ │ └─Bottleneck (0) [64, 256, 8, 8] [64, 512, 4, 4]\n",
"│ │ └─Bottleneck (1) [64, 512, 4, 4] [64, 512, 4, 4]\n",
"│ │ └─Bottleneck (2) [64, 512, 4, 4] [64, 512, 4, 4]\n",
"│ │ └─Bottleneck (3) [64, 512, 4, 4] [64, 512, 4, 4]\n",
"├─Sequential (resnet50_top_half) [64, 512, 4, 4] [64, 100]\n",
"│ └─Sequential (0) [64, 512, 4, 4] [64, 1024, 2, 2]\n",
"│ │ └─Bottleneck (0) [64, 512, 4, 4] [64, 1024, 2, 2]\n",
"│ │ └─Bottleneck (1) [64, 1024, 2, 2] [64, 1024, 2, 2]\n",
"│ │ └─Bottleneck (2) [64, 1024, 2, 2] [64, 1024, 2, 2]\n",
"│ │ └─Bottleneck (3) [64, 1024, 2, 2] [64, 1024, 2, 2]\n",
"│ │ └─Bottleneck (4) [64, 1024, 2, 2] [64, 1024, 2, 2]\n",
"│ │ └─Bottleneck (5) [64, 1024, 2, 2] [64, 1024, 2, 2]\n",
"│ └─Sequential (1) [64, 1024, 2, 2] [64, 2048, 1, 1]\n",
"│ │ └─Bottleneck (0) [64, 1024, 2, 2] [64, 2048, 1, 1]\n",
"│ │ └─Bottleneck (1) [64, 2048, 1, 1] [64, 2048, 1, 1]\n",
"│ │ └─Bottleneck (2) [64, 2048, 1, 1] [64, 2048, 1, 1]\n",
"│ └─AdaptiveAvgPool2d (2) [64, 2048, 1, 1] [64, 2048, 1, 1]\n",
"│ └─Flatten (3) [64, 2048, 1, 1] [64, 2048]\n",
"│ └─Linear (4) [64, 2048] [64, 100]\n",
"===============================================================================================\n",
"Total params: 23,712,932\n",
"Trainable params: 23,712,932\n",
"Non-trainable params: 0\n",
"Total mult-adds (G): 5.35\n",
"===============================================================================================\n",
"Input size (MB): 0.79\n",
"Forward/backward pass size (MB): 232.31\n",
"Params size (MB): 94.85\n",
"Estimated Total Size (MB): 327.95\n",
"==============================================================================================="
]
},
"metadata": {},
"execution_count": 7
}
],
"source": [
"torchinfo.summary(\n",
" model_new,\n",
" (64, 3, 32, 32),\n",
" depth=3,\n",
" col_names=[\"input_size\",\n",
" \"output_size\"],\n",
" row_settings=(\"var_names\",)\n",
" )"
]
},
{
"source": [
"ではデータを流し込んでみる"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"data = torch.randn(64, 3, 32, 32).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.return_types.max(\nvalues=tensor([1.3454, 1.1356, 1.2690, 1.3834, 1.0672, 1.2274, 1.0207, 1.0614, 1.3978,\n 1.2313, 1.4156, 1.4666, 1.2339, 1.0205, 0.9802, 1.0208, 1.2230, 1.1357,\n 1.1173, 1.1000, 1.2028, 1.2953, 1.0749, 1.1699, 1.2398, 1.2197, 1.3249,\n 1.3639, 1.3710, 1.4766, 1.0805, 1.2381, 1.3166, 1.3227, 1.2800, 1.3676,\n 1.1806, 1.2021, 1.2806, 1.2725, 1.0955, 1.3845, 1.3075, 1.1250, 1.2184,\n 1.4227, 1.0932, 1.1074, 1.1536, 1.1790, 1.3503, 1.3610, 1.2805, 1.2181,\n 1.3869, 1.2473, 1.2844, 1.2438, 1.0619, 1.3060, 1.1839, 1.2812, 1.1054,\n 1.1871], device='cuda:0', grad_fn=<MaxBackward0>),\nindices=tensor([90, 78, 78, 78, 78, 44, 88, 57, 44, 78, 78, 78, 78, 78, 27, 88, 44, 44,\n 27, 44, 90, 78, 44, 90, 27, 44, 44, 78, 78, 44, 27, 78, 78, 78, 44, 78,\n 78, 27, 90, 78, 84, 44, 44, 78, 78, 78, 78, 27, 78, 44, 44, 44, 90, 78,\n 44, 90, 78, 78, 84, 78, 44, 44, 27, 90], device='cuda:0'))\ntensor([[-0.1791, 0.0600, 0.2247, -0.0900, -0.6020],\n [-0.4647, 0.0351, 0.2943, -0.2365, -0.3093],\n [-0.4240, -0.1189, 0.5508, -0.4110, -0.5454],\n [-0.6184, 0.0310, 0.0049, -0.0531, -0.7490],\n [-0.3527, 0.1487, 0.3425, -0.0758, -0.5206],\n [-0.3394, 0.2014, 0.3621, -0.1127, -0.3296],\n [-0.6504, 0.2399, 0.4901, -0.1620, -0.3928],\n [-0.4113, 0.3839, 0.2992, -0.3384, -0.4561],\n [-0.4094, 0.1961, 0.1317, 0.1345, -0.6089],\n [-0.2959, 0.3001, 0.2256, -0.4830, -0.5723]], device='cuda:0',\n grad_fn=<SliceBackward>)\n"
]
}
],
"source": [
"model_new.eval()\n",
"print(model_new(data).max(axis=1))\n",
"print(model_new(data)[:10, :5])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.return_types.max(\n",
"values=tensor([1.3454, 1.1356, 1.2690, 1.3834, 1.0672, 1.2274, 1.0207, 1.0614, 1.3978,\n",
" 1.2313, 1.4156, 1.4666, 1.2339, 1.0205, 0.9802, 1.0208, 1.2230, 1.1357,\n",
" 1.1173, 1.1000, 1.2028, 1.2953, 1.0749, 1.1699, 1.2398, 1.2197, 1.3249,\n",
" 1.3639, 1.3710, 1.4766, 1.0805, 1.2381, 1.3166, 1.3227, 1.2800, 1.3676,\n",
" 1.1806, 1.2021, 1.2806, 1.2725, 1.0955, 1.3845, 1.3075, 1.1250, 1.2184,\n",
" 1.4227, 1.0932, 1.1074, 1.1536, 1.1790, 1.3503, 1.3610, 1.2805, 1.2181,\n",
" 1.3869, 1.2473, 1.2844, 1.2438, 1.0619, 1.3060, 1.1839, 1.2812, 1.1054,\n",
" 1.1871], device='cuda:0', grad_fn=<MaxBackward0>),\n",
"indices=tensor([90, 78, 78, 78, 78, 44, 88, 57, 44, 78, 78, 78, 78, 78, 27, 88, 44, 44,\n",
" 27, 44, 90, 78, 44, 90, 27, 44, 44, 78, 78, 44, 27, 78, 78, 78, 44, 78,\n",
" 78, 27, 90, 78, 84, 44, 44, 78, 78, 78, 78, 27, 78, 44, 44, 44, 90, 78,\n",
" 44, 90, 78, 78, 84, 78, 44, 44, 27, 90], device='cuda:0'))\n",
"tensor([[-0.1791, 0.0600, 0.2247, -0.0900, -0.6020],\n",
" [-0.4647, 0.0351, 0.2943, -0.2365, -0.3093],\n",
" [-0.4240, -0.1189, 0.5508, -0.4110, -0.5454],\n",
" [-0.6184, 0.0310, 0.0049, -0.0531, -0.7490],\n",
" [-0.3527, 0.1487, 0.3425, -0.0758, -0.5206],\n",
" [-0.3394, 0.2014, 0.3621, -0.1127, -0.3296],\n",
" [-0.6504, 0.2399, 0.4901, -0.1620, -0.3928],\n",
" [-0.4113, 0.3839, 0.2992, -0.3384, -0.4561],\n",
" [-0.4094, 0.1961, 0.1317, 0.1345, -0.6089],\n",
" [-0.2959, 0.3001, 0.2256, -0.4830, -0.5723]], device='cuda:0',\n",
" grad_fn=<SliceBackward>)\n"
]
}
],
"source": [
"model_org.eval()\n",
"print(model_org(data).max(axis=1))\n",
"print(model_org(data)[:10, :5])"
]
},
{
"source": [
"どちらも同じ出力が得られたことが分かる.\n",
"\n",
"ではoptimizerを設定して,勾配も一致するかどうか確認する."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.CrossEntropyLoss()\n",
"optimizer_org = torch.optim.SGD(model_org.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)\n",
"optimizer_new = torch.optim.SGD(model_new.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"target = torch.randint(num_classes, (64,)).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(4.7663, device='cuda:0', grad_fn=<NllLossBackward>)"
]
},
"metadata": {},
"execution_count": 13
}
],
"source": [
"model_new.train()\n",
"optimizer_new.zero_grad()\n",
"output_new = model_new(data)\n",
"loss_new = criterion(output_new, target)\n",
"loss_new.backward()\n",
"loss_new"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(4.7663, device='cuda:0', grad_fn=<NllLossBackward>)"
]
},
"metadata": {},
"execution_count": 14
}
],
"source": [
"model_org.train()\n",
"optimizer_org.zero_grad()\n",
"output_org = model_org(data)\n",
"loss_org = criterion(output_org, target)\n",
"loss_org.backward()\n",
"loss_org"
]
},
{
"source": [
"lossは一致した.\n",
"では重みとその勾配は一致するかどうかを確認する.対象は最初のconvの5x5に限定."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.0133, 0.0147, -0.0154, -0.0230, -0.0409],\n",
" [ 0.0041, 0.0058, 0.0149, 0.0206, 0.0022],\n",
" [ 0.0223, 0.0236, 0.0161, 0.0588, 0.1028],\n",
" [ 0.0232, 0.0042, -0.0459, -0.0487, -0.0164],\n",
" [-0.0009, 0.0278, -0.0101, -0.0554, -0.1272]], device='cuda:0',\n",
" grad_fn=<SliceBackward>)"
]
},
"metadata": {},
"execution_count": 15
}
],
"source": [
"model_new.resnet50_bottop_half[0].weight[0, 0, :5, :5]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.0133, 0.0147, -0.0154, -0.0230, -0.0409],\n",
" [ 0.0041, 0.0058, 0.0149, 0.0206, 0.0022],\n",
" [ 0.0223, 0.0236, 0.0161, 0.0588, 0.1028],\n",
" [ 0.0232, 0.0042, -0.0459, -0.0487, -0.0164],\n",
" [-0.0009, 0.0278, -0.0101, -0.0554, -0.1272]], device='cuda:0',\n",
" grad_fn=<SliceBackward>)"
]
},
"metadata": {},
"execution_count": 16
}
],
"source": [
"model_org.conv1.weight[0, 0, :5, :5]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.0495, -0.0230, 0.0685, 0.0177, -0.0276],\n",
" [ 0.0195, -0.0643, 0.0361, 0.0360, -0.0751],\n",
" [ 0.0223, 0.0310, 0.0489, 0.0656, -0.1489],\n",
" [-0.0438, 0.1419, 0.0153, -0.1075, 0.1269],\n",
" [-0.1402, -0.0279, -0.1216, 0.1160, 0.0147]], device='cuda:0')"
]
},
"metadata": {},
"execution_count": 17
}
],
"source": [
"model_new.resnet50_bottop_half[0].weight.grad[0, 0, :5, :5]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.0495, -0.0230, 0.0685, 0.0177, -0.0276],\n",
" [ 0.0195, -0.0643, 0.0361, 0.0360, -0.0751],\n",
" [ 0.0223, 0.0310, 0.0489, 0.0656, -0.1489],\n",
" [-0.0438, 0.1419, 0.0153, -0.1075, 0.1269],\n",
" [-0.1402, -0.0279, -0.1216, 0.1160, 0.0147]], device='cuda:0')"
]
},
"metadata": {},
"execution_count": 18
}
],
"source": [
"model_org.conv1.weight.grad[0, 0, :5, :5]"
]
},
{
"source": [
"これで一致したことが確認できた.\n",
"\n",
"では事前学習済みResNetを色々と切り貼りして,ABNを作ってみる."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"class ABNResNet50(nn.Module):\n",
" def __init__(self, num_classes: int = 1000) -> None:\n",
" super().__init__()\n",
" model = models.resnet50(pretrained=True)\n",
"\n",
" self.resnet50_bottop = nn.Sequential(\n",
" model.conv1,\n",
" model.bn1,\n",
" model.relu,\n",
" model.maxpool,\n",
" model.layer1,\n",
" model.layer2\n",
" )\n",
"\n",
" self.resnet50_top = nn.Sequential(\n",
" model.layer3,\n",
" model.layer4,\n",
" model.avgpool,\n",
" nn.Flatten(),\n",
" nn.Linear(model.fc.in_features, num_classes)\n",
" )\n",
"\n",
" self.attention_branch1 = nn.Sequential(\n",
" copy.deepcopy(model.layer2[3]), # deepcopyしないと,上で使ったものと重みが共有されてしまう\n",
" nn.BatchNorm2d(512),\n",
" nn.Conv2d(512, 512, kernel_size=1),\n",
" nn.ReLU(inplace=True)\n",
" )\n",
" self.attention_branch2 = nn.Sequential(\n",
" nn.Conv2d(512, 512, kernel_size=1),\n",
" nn.BatchNorm2d(512),\n",
" nn.Sigmoid()\n",
" )\n",
" self.attention_branch3 = nn.Sequential(\n",
" nn.Conv2d(512, num_classes, kernel_size=1),\n",
" )\n",
"\n",
" def get_attn(self):\n",
" return self.attn\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" x = self.resnet50_bottop(x)\n",
"\n",
" ax = self.attention_branch1(x)\n",
" attn = self.attention_branch2(ax)\n",
"\n",
" x = x * attn\n",
" x = self.resnet50_top(x)\n",
"\n",
" ax = self.attention_branch3(ax)\n",
" ax = F.avg_pool2d(ax, kernel_size=ax.shape[2]).squeeze()\n",
"\n",
" self.attn = attn\n",
" return x, ax, attn\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"model_new3 = ABNResNet50(num_classes=num_classes).to(device)\n",
"model_new3 = nn.DataParallel(model_new3)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"optimizer_new3 = torch.optim.SGD(model_new3.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)\n",
"# optimizer_new3 = torch.optim.Adam(model_new3.parameters(), lr=0.1, weight_decay=5e-4)"
]
},
{
"source": [
"便利関数を作っておく."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"class AverageMeter(object):\n",
" \"\"\"\n",
" Computes and stores the average and current value\n",
" Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262\n",
" https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59\n",
" \"\"\"\n",
" def __init__(self):\n",
" self.reset()\n",
"\n",
" def reset(self):\n",
" self.val = 0\n",
" self.avg = 0\n",
" self.sum = 0\n",
" self.count = 0\n",
"\n",
" def update(self, val, n=1):\n",
" if type(val) == torch.Tensor:\n",
" val = val.item()\n",
" self.val = val\n",
" self.sum += val * n\n",
" self.count += n\n",
" self.avg = self.sum / self.count\n",
"\n",
"def top1(outputs, targets):\n",
" batch_size = outputs.size(0)\n",
" _, predicted = outputs.max(1)\n",
" return predicted.eq(targets).sum().item() / batch_size"
]
},
{
"source": [
"CIFAR100で学習"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "aecc7007ad984f12aedf6ad3cd5ed2b6"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "9913ee6cf57449f0a4dd65b09051205c"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "7cf38e5bf4084979ba124aae7f57e0d9"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "17763f0424d84d0ba426206c66fb8d28"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "7f64b12b947349858bb4f3337d6ebccb"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "f5ea7701afc043d697238db42345879e"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "9689ab0caf904f49afab76134c39aab9"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "90d2e07b71e241c89318c00505929693"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "b61432c10ec748babaa6405f35f8d407"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "a2592030922d4c68bac0d91ef13b1dab"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "22759d1dcea1473b8c2e1e01df558ba3"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n\n"
]
}
],
"source": [
"num_epochs = 5\n",
"\n",
"with tqdm(range(num_epochs)) as pbar_epoch:\n",
" for epoch in pbar_epoch:\n",
" pbar_epoch.set_description(\"[Epoch %d]\" % (epoch))\n",
"\n",
"\n",
" with tqdm(enumerate(train_loader),\n",
" total=len(train_loader),\n",
" leave=True) as pbar_loss:\n",
"\n",
" train_loss_new3 = AverageMeter()\n",
" train_acc_new3 = AverageMeter()\n",
" correct_new3 = AverageMeter()\n",
" model_new3.train()\n",
"\n",
" for batch_idx, (inputs, targets) in pbar_loss:\n",
" pbar_loss.set_description(\"[train]\")\n",
"\n",
" inputs, targets = inputs.to(device), targets.to(device)\n",
" bs = inputs.size(0) # current batch size, may vary at the end of the epoch\n",
"\n",
" optimizer_new3.zero_grad()\n",
" outputs_x_new3, outputs_ax_new3, _ = model_new3(inputs)\n",
" loss_x_new3 = criterion(outputs_x_new3, targets)\n",
" loss_ax_new3 = criterion(outputs_ax_new3, targets)\n",
" loss_new3 = (loss_x_new3 + loss_ax_new3) / 2\n",
" loss_new3.backward()\n",
" optimizer_new3.step()\n",
" train_loss_new3.update(loss_new3, bs)\n",
" train_acc_new3.update(top1(outputs_x_new3, targets), bs)\n",
"\n",
" pbar_loss.set_postfix_str(\n",
" ' | loss={:6.04f} , top1={:6.04f}'\n",
" ''.format(\n",
" train_loss_new3.avg, train_acc_new3.avg,\n",
" ))\n",
"\n",
"\n",
" with torch.no_grad(), \\\n",
" tqdm(enumerate(val_loader),\n",
" total=len(val_loader),\n",
" leave=True) as pbar_loss:\n",
"\n",
" train_loss_new3 = AverageMeter()\n",
" train_acc_new3 = AverageMeter()\n",
" correct_new3 = AverageMeter()\n",
" model_new3.eval()\n",
"\n",
" for batch_idx, (inputs, targets) in pbar_loss:\n",
" pbar_loss.set_description(\"[val]\")\n",
"\n",
" inputs, targets = inputs.to(device), targets.to(device)\n",
"\n",
" # optimizer_new3.zero_grad()\n",
" outputs_x_new3, outputs_ax_new3, _ = model_new3(inputs)\n",
" loss_x_new3 = criterion(outputs_x_new3, targets)\n",
" loss_ax_new3 = criterion(outputs_ax_new3, targets)\n",
" loss_new3 = (loss_x_new3 + loss_ax_new3) / 2\n",
" # loss_new3.backward()\n",
" # optimizer_new3.step()\n",
" train_loss_new3.update(loss_new3, bs)\n",
" train_acc_new3.update(top1(outputs_x_new3, targets), bs)\n",
"\n",
"\n",
"\n",
" pbar_loss.set_postfix_str(\n",
" ' | loss={:6.04f} , top1={:6.04f}'\n",
" ''.format(\n",
" train_loss_new3.avg, train_acc_new3.avg,\n",
" ))\n",
"\n",
"\n",
"\n"
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment