@tttamaki · Created August 30, 2021
Building an ABN by cutting and pasting a pretrained ResNet

{
"cells": [
{
"cell_type": "markdown",
"source": [
"事前学習済みのResNet50をバラバラにして,ABNを作ってみる."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 1,
"source": [
"!pip install torchinfo\n",
"!pip install git+https://github.com/facebookresearch/fvcore.git"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: torchinfo in /opt/conda/lib/python3.8/site-packages (1.5.2)\n",
"\u001b[33mWARNING: You are using pip version 21.1.3; however, version 21.2.4 is available.\n",
"You should consider upgrading via the '/opt/conda/bin/python -m pip install --upgrade pip' command.\u001b[0m\n",
"Defaulting to user installation because normal site-packages is not writeable\n",
"Collecting git+https://github.com/facebookresearch/fvcore.git\n",
" Cloning https://github.com/facebookresearch/fvcore.git to /tmp/pip-req-build-khia8ald\n",
" Running command git clone -q https://github.com/facebookresearch/fvcore.git /tmp/pip-req-build-khia8ald\n",
"Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (1.20.1)\n",
"Requirement already satisfied: yacs>=0.1.6 in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (0.1.8)\n",
"Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (5.4.1)\n",
"Requirement already satisfied: tqdm in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (4.53.0)\n",
"Requirement already satisfied: termcolor>=1.1 in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (1.1.0)\n",
"Requirement already satisfied: Pillow in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (8.3.1)\n",
"Requirement already satisfied: tabulate in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (0.8.9)\n",
"Requirement already satisfied: iopath>=0.1.7 in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (0.1.9)\n",
"Requirement already satisfied: portalocker in /opt/conda/lib/python3.8/site-packages (from iopath>=0.1.7->fvcore==0.1.5) (2.3.0)\n",
"\u001b[33mWARNING: You are using pip version 21.1.3; however, version 21.2.4 is available.\n",
"You should consider upgrading via the '/opt/conda/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import torchvision\n",
"from torchvision import transforms\n",
"\n",
"import torchinfo\n",
"import copy\n",
"from tqdm import tqdm\n",
"# from tqdm.notebook import tqdm"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"データローダの準備"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"transform_train = transforms.Compose([\n",
" transforms.RandomCrop(32, padding=1),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406], \n",
" std=[0.229, 0.224, 0.225]),\n",
" ])\n",
"transform_val = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406], \n",
" std=[0.229, 0.224, 0.225]),\n",
" ])\n",
"\n",
"num_classes = 10\n",
"batch_size = 64\n",
"num_workers = 2\n",
"\n",
"train_set = torchvision.datasets.CIFAR10(\n",
" root='./data',\n",
" train=True,\n",
" download=True,\n",
" transform=transform_train\n",
")\n",
"val_set = torchvision.datasets.CIFAR10(\n",
" root='./data',\n",
" train=False,\n",
" download=True,\n",
" transform=transform_val\n",
")\n",
"\n",
"train_loader = torch.utils.data.DataLoader(\n",
" train_set, \n",
" batch_size=batch_size,\n",
" shuffle=True, \n",
" num_workers=num_workers)\n",
"val_loader = torch.utils.data.DataLoader(\n",
" val_set,\n",
" batch_size=batch_size,\n",
" shuffle=True, \n",
" num_workers=num_workers)\n"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
],
"metadata": {}
},
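{
"cell_type": "markdown",
"source": [
"As a quick sanity check (a minimal sketch, not part of the original run), pull one batch from the train loader and confirm its shape and value range after normalization."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch: inspect a single batch from the train loader.\n",
"images, labels = next(iter(train_loader))\n",
"print(images.shape)  # expected: torch.Size([64, 3, 32, 32])\n",
"print(labels.shape)  # expected: torch.Size([64])\n",
"print(images.min().item(), images.max().item())  # roughly -2.1 to 2.6 after Normalize"
],
"outputs": [],
"metadata": {}
},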
{
"cell_type": "markdown",
"source": [
"まずはoriginalのResNet50.\n",
"どのモデルでもよいが,とりあえず\n",
"https://github.com/chenyaofo/pytorch-cifar-models\n",
"から,CIFAR10用のpretrainモデルを取得."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"model_org = torch.hub.load(\"chenyaofo/pytorch-cifar-models\", \n",
" \"cifar100_resnet56\", \n",
" pretrained=True)\n",
"model_org = model_org.to(device)"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Using cache found in /home/tamaki/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"summaryを確認."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"source": [
"torchinfo.summary(\n",
" model_org,\n",
" (64, 3, 32, 32),\n",
" depth=2,\n",
" col_names=[\"input_size\",\n",
" \"output_size\",\n",
" \"kernel_size\"],\n",
" row_settings=(\"var_names\",)\n",
" )"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"===================================================================================================================\n",
"Layer (type (var_name)) Input Shape Output Shape Kernel Shape\n",
"===================================================================================================================\n",
"CifarResNet -- -- --\n",
"├─Conv2d (conv1) [64, 3, 32, 32] [64, 16, 32, 32] [3, 16, 3, 3]\n",
"├─BatchNorm2d (bn1) [64, 16, 32, 32] [64, 16, 32, 32] [16]\n",
"├─ReLU (relu) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"├─Sequential (layer1) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─BasicBlock (0) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─BasicBlock (1) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─BasicBlock (2) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─BasicBlock (3) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─BasicBlock (4) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─BasicBlock (5) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─BasicBlock (6) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─BasicBlock (7) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─BasicBlock (8) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"├─Sequential (layer2) [64, 16, 32, 32] [64, 32, 16, 16] --\n",
"│ └─BasicBlock (0) [64, 16, 32, 32] [64, 32, 16, 16] --\n",
"│ └─BasicBlock (1) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ └─BasicBlock (2) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ └─BasicBlock (3) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ └─BasicBlock (4) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ └─BasicBlock (5) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ └─BasicBlock (6) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ └─BasicBlock (7) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ └─BasicBlock (8) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"├─Sequential (layer3) [64, 32, 16, 16] [64, 64, 8, 8] --\n",
"│ └─BasicBlock (0) [64, 32, 16, 16] [64, 64, 8, 8] --\n",
"│ └─BasicBlock (1) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ └─BasicBlock (2) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ └─BasicBlock (3) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ └─BasicBlock (4) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ └─BasicBlock (5) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ └─BasicBlock (6) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ └─BasicBlock (7) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ └─BasicBlock (8) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"├─AdaptiveAvgPool2d (avgpool) [64, 64, 8, 8] [64, 64, 1, 1] --\n",
"├─Linear (fc) [64, 64] [64, 100] [64, 100]\n",
"===================================================================================================================\n",
"Total params: 861,620\n",
"Trainable params: 861,620\n",
"Non-trainable params: 0\n",
"Total mult-adds (G): 8.05\n",
"===================================================================================================================\n",
"Input size (MB): 0.79\n",
"Forward/backward pass size (MB): 557.89\n",
"Params size (MB): 3.45\n",
"Estimated Total Size (MB): 562.13\n",
"==================================================================================================================="
]
},
"metadata": {},
"execution_count": 5
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"次は,事前学習済みResNetを,前半と後半に分けて,それらをつなげた再構成モデルを作り,originalと同じ出力が得られるかどうかを確認する."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"source": [
"class ReconstructResNet50(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" model = torch.hub.load(\"chenyaofo/pytorch-cifar-models\", \n",
" \"cifar100_resnet56\", \n",
" pretrained=True)\n",
"\n",
" self.resnet50_bottom_half = nn.Sequential(\n",
" model.conv1,\n",
" model.bn1,\n",
" model.relu,\n",
" model.layer1,\n",
" model.layer2,\n",
" model.layer3,\n",
" )\n",
"\n",
" self.resnet50_top_half = nn.Sequential(\n",
" model.avgpool,\n",
" nn.Flatten(),\n",
" model.fc\n",
" )\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" x = self.resnet50_bottom_half(x)\n",
" x = self.resnet50_top_half(x)\n",
" return x"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"source": [
"model_new = ReconstructResNet50()\n",
"model_new = model_new.to(device)"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Using cache found in /home/tamaki/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"再構成したモデルの中身.depthが一つ深くなっているが,パラメータ数その他は同一であることが分かる."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"source": [
"torchinfo.summary(\n",
" model_new,\n",
" (64, 3, 32, 32),\n",
" depth=3,\n",
" col_names=[\"input_size\",\n",
" \"output_size\",\n",
" \"kernel_size\"],\n",
" row_settings=(\"var_names\",)\n",
" )"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"========================================================================================================================\n",
"Layer (type (var_name)) Input Shape Output Shape Kernel Shape\n",
"========================================================================================================================\n",
"ReconstructResNet50 -- -- --\n",
"├─Sequential (resnet50_bottom_half) [64, 3, 32, 32] [64, 64, 8, 8] --\n",
"│ └─Conv2d (0) [64, 3, 32, 32] [64, 16, 32, 32] [3, 16, 3, 3]\n",
"│ └─BatchNorm2d (1) [64, 16, 32, 32] [64, 16, 32, 32] [16]\n",
"│ └─ReLU (2) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─Sequential (3) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ │ └─BasicBlock (0) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ │ └─BasicBlock (1) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ │ └─BasicBlock (2) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ │ └─BasicBlock (3) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ │ └─BasicBlock (4) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ │ └─BasicBlock (5) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ │ └─BasicBlock (6) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ │ └─BasicBlock (7) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ │ └─BasicBlock (8) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─Sequential (4) [64, 16, 32, 32] [64, 32, 16, 16] --\n",
"│ │ └─BasicBlock (0) [64, 16, 32, 32] [64, 32, 16, 16] --\n",
"│ │ └─BasicBlock (1) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ │ └─BasicBlock (2) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ │ └─BasicBlock (3) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ │ └─BasicBlock (4) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ │ └─BasicBlock (5) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ │ └─BasicBlock (6) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ │ └─BasicBlock (7) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ │ └─BasicBlock (8) [64, 32, 16, 16] [64, 32, 16, 16] --\n",
"│ └─Sequential (5) [64, 32, 16, 16] [64, 64, 8, 8] --\n",
"│ │ └─BasicBlock (0) [64, 32, 16, 16] [64, 64, 8, 8] --\n",
"│ │ └─BasicBlock (1) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ │ └─BasicBlock (2) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ │ └─BasicBlock (3) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ │ └─BasicBlock (4) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ │ └─BasicBlock (5) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ │ └─BasicBlock (6) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ │ └─BasicBlock (7) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ │ └─BasicBlock (8) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"├─Sequential (resnet50_top_half) [64, 64, 8, 8] [64, 100] --\n",
"│ └─AdaptiveAvgPool2d (0) [64, 64, 8, 8] [64, 64, 1, 1] --\n",
"│ └─Flatten (1) [64, 64, 1, 1] [64, 64] --\n",
"│ └─Linear (2) [64, 64] [64, 100] [64, 100]\n",
"========================================================================================================================\n",
"Total params: 861,620\n",
"Trainable params: 861,620\n",
"Non-trainable params: 0\n",
"Total mult-adds (G): 8.05\n",
"========================================================================================================================\n",
"Input size (MB): 0.79\n",
"Forward/backward pass size (MB): 557.89\n",
"Params size (MB): 3.45\n",
"Estimated Total Size (MB): 562.13\n",
"========================================================================================================================"
]
},
"metadata": {},
"execution_count": 8
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"ではデータを流し込んでみる.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"source": [
"data = torch.randn(64, 3, 32, 32).to(device)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"model_new.eval()\n",
"print(model_new(data).max(axis=1))\n",
"print(model_new(data)[:10, :5])"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.return_types.max(\n",
"values=tensor([10.9443, 12.4955, 13.0963, 10.5355, 10.8972, 11.2695, 10.1008, 11.6357,\n",
" 13.1638, 11.4222, 10.9982, 10.4561, 10.5532, 10.5805, 12.4815, 12.0393,\n",
" 13.4050, 10.7294, 10.4074, 11.0421, 12.0968, 11.0542, 10.8104, 13.8886,\n",
" 10.2925, 10.4051, 13.7182, 11.5399, 11.1024, 10.0979, 11.6706, 10.5574,\n",
" 12.0685, 13.5889, 14.2758, 12.7623, 11.5654, 9.5789, 11.1102, 10.3502,\n",
" 11.3670, 11.1035, 11.8346, 12.4422, 10.3109, 10.8246, 10.7729, 10.2088,\n",
" 11.5740, 11.5479, 11.4090, 12.2235, 10.6360, 14.7207, 10.1188, 10.0799,\n",
" 10.8706, 11.0431, 12.1793, 10.3579, 9.0835, 11.4152, 9.4510, 10.4705],\n",
" device='cuda:0', grad_fn=<MaxBackward0>),\n",
"indices=tensor([10, 82, 10, 57, 57, 57, 10, 57, 82, 83, 57, 57, 57, 10, 10, 10, 10, 57,\n",
" 82, 10, 10, 10, 82, 57, 10, 10, 83, 10, 57, 10, 10, 57, 10, 10, 10, 10,\n",
" 57, 57, 82, 10, 10, 57, 10, 57, 57, 57, 10, 82, 83, 9, 78, 10, 82, 10,\n",
" 10, 10, 57, 10, 57, 57, 10, 10, 82, 10], device='cuda:0'))\n",
"tensor([[ 6.0794e+00, -1.1938e-01, -1.2621e-01, -5.1893e-02, -4.3699e+00],\n",
" [ 6.5469e+00, -2.5111e-01, -6.4007e-01, -2.0096e+00, -3.5481e+00],\n",
" [ 5.7094e+00, 3.2534e-01, 5.0513e-01, 1.5728e-01, -4.6752e+00],\n",
" [ 6.8923e+00, 9.4589e-01, 1.2946e+00, 1.6836e+00, -2.7677e+00],\n",
" [ 6.2214e+00, 1.3083e+00, -1.7860e-01, 3.7956e-01, -4.4028e+00],\n",
" [ 5.7689e+00, 3.6209e-01, -1.4508e+00, 4.5521e-01, -4.2887e+00],\n",
" [ 5.1389e+00, -4.2951e-01, -9.8702e-01, 2.5621e-02, -3.7350e+00],\n",
" [ 5.1116e+00, 1.6931e+00, 6.2985e-04, -1.1464e+00, -4.3173e+00],\n",
" [ 5.7331e+00, -3.3522e-01, -1.2759e+00, -4.5809e-01, -3.7536e+00],\n",
" [ 4.2918e+00, 2.0628e+00, 7.2408e-02, 2.6386e-01, -4.4757e+00]],\n",
" device='cuda:0', grad_fn=<SliceBackward>)\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 11,
"source": [
"model_org.eval()\n",
"print(model_org(data).max(axis=1))\n",
"print(model_org(data)[:10, :5])"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.return_types.max(\n",
"values=tensor([10.9443, 12.4955, 13.0963, 10.5355, 10.8972, 11.2695, 10.1008, 11.6357,\n",
" 13.1638, 11.4222, 10.9982, 10.4561, 10.5532, 10.5805, 12.4815, 12.0393,\n",
" 13.4050, 10.7294, 10.4074, 11.0421, 12.0968, 11.0542, 10.8104, 13.8886,\n",
" 10.2925, 10.4051, 13.7182, 11.5399, 11.1024, 10.0979, 11.6706, 10.5574,\n",
" 12.0685, 13.5889, 14.2758, 12.7623, 11.5654, 9.5789, 11.1102, 10.3502,\n",
" 11.3670, 11.1035, 11.8346, 12.4422, 10.3109, 10.8246, 10.7729, 10.2088,\n",
" 11.5740, 11.5479, 11.4090, 12.2235, 10.6360, 14.7207, 10.1188, 10.0799,\n",
" 10.8706, 11.0431, 12.1793, 10.3579, 9.0835, 11.4152, 9.4510, 10.4705],\n",
" device='cuda:0', grad_fn=<MaxBackward0>),\n",
"indices=tensor([10, 82, 10, 57, 57, 57, 10, 57, 82, 83, 57, 57, 57, 10, 10, 10, 10, 57,\n",
" 82, 10, 10, 10, 82, 57, 10, 10, 83, 10, 57, 10, 10, 57, 10, 10, 10, 10,\n",
" 57, 57, 82, 10, 10, 57, 10, 57, 57, 57, 10, 82, 83, 9, 78, 10, 82, 10,\n",
" 10, 10, 57, 10, 57, 57, 10, 10, 82, 10], device='cuda:0'))\n",
"tensor([[ 6.0794e+00, -1.1938e-01, -1.2621e-01, -5.1893e-02, -4.3699e+00],\n",
" [ 6.5469e+00, -2.5111e-01, -6.4007e-01, -2.0096e+00, -3.5481e+00],\n",
" [ 5.7094e+00, 3.2534e-01, 5.0513e-01, 1.5728e-01, -4.6752e+00],\n",
" [ 6.8923e+00, 9.4589e-01, 1.2946e+00, 1.6836e+00, -2.7677e+00],\n",
" [ 6.2214e+00, 1.3083e+00, -1.7860e-01, 3.7956e-01, -4.4028e+00],\n",
" [ 5.7689e+00, 3.6209e-01, -1.4508e+00, 4.5521e-01, -4.2887e+00],\n",
" [ 5.1389e+00, -4.2951e-01, -9.8702e-01, 2.5621e-02, -3.7350e+00],\n",
" [ 5.1116e+00, 1.6931e+00, 6.2985e-04, -1.1464e+00, -4.3173e+00],\n",
" [ 5.7331e+00, -3.3522e-01, -1.2759e+00, -4.5809e-01, -3.7536e+00],\n",
" [ 4.2918e+00, 2.0628e+00, 7.2408e-02, 2.6386e-01, -4.4757e+00]],\n",
" device='cuda:0', grad_fn=<SliceBackward>)\n"
]
}
],
"metadata": {}
},
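{
"cell_type": "markdown",
"source": [
"Instead of eyeballing the printed tensors, here is a minimal sketch that compares the two outputs numerically with torch.allclose."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch: verify the reconstructed model matches the original numerically.\n",
"with torch.no_grad():\n",
"    print(torch.allclose(model_org(data), model_new(data)))  # expected: True"
],
"outputs": [],
"metadata": {}
},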
{
"cell_type": "markdown",
"source": [
"どちらも同じ出力が得られたことが分かる.\n",
"\n",
"ではoptimizerを設定して,勾配も一致するかどうか確認する."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 12,
"source": [
"criterion = nn.CrossEntropyLoss()\n",
"optimizer_org = torch.optim.SGD(model_org.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)\n",
"optimizer_new = torch.optim.SGD(model_new.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 13,
"source": [
"target = torch.randint(num_classes, (64,)).to(device)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 14,
"source": [
"model_org.train()\n",
"optimizer_org.zero_grad()\n",
"output_org = model_org(data)\n",
"loss_org = criterion(output_org, target)\n",
"loss_org.backward()\n",
"loss_org"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(8.3755, device='cuda:0', grad_fn=<NllLossBackward>)"
]
},
"metadata": {},
"execution_count": 14
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 15,
"source": [
"model_new.train()\n",
"optimizer_new.zero_grad()\n",
"output_new = model_new(data)\n",
"loss_new = criterion(output_new, target)\n",
"loss_new.backward()\n",
"loss_new"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(8.3755, device='cuda:0', grad_fn=<NllLossBackward>)"
]
},
"metadata": {},
"execution_count": 15
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"lossは一致した.\n",
"では重みとその勾配は一致するかどうかを確認する.対象は最初のconvに限定."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 16,
"source": [
"model_new.resnet50_bottom_half[0].weight[0, 0, :, :]"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.0220, -0.1125, -0.0150],\n",
" [-0.1283, -0.2929, -0.1424],\n",
" [-0.0580, -0.1338, -0.0578]], device='cuda:0', grad_fn=<SliceBackward>)"
]
},
"metadata": {},
"execution_count": 16
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 17,
"source": [
"model_org.conv1.weight[0, 0, :, :]"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.0220, -0.1125, -0.0150],\n",
" [-0.1283, -0.2929, -0.1424],\n",
" [-0.0580, -0.1338, -0.0578]], device='cuda:0', grad_fn=<SliceBackward>)"
]
},
"metadata": {},
"execution_count": 17
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 18,
"source": [
"model_new.resnet50_bottom_half[0].weight.grad[0, 0, :, :]"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-0.1526, 0.0275, -0.3145],\n",
" [-0.4756, 0.3217, 0.0965],\n",
" [-0.5709, -0.3611, -0.4316]], device='cuda:0')"
]
},
"metadata": {},
"execution_count": 18
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 19,
"source": [
"model_org.conv1.weight.grad[0, 0, :, :]"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-0.1526, 0.0275, -0.3145],\n",
" [-0.4756, 0.3217, 0.0965],\n",
" [-0.5709, -0.3611, -0.4316]], device='cuda:0')"
]
},
"metadata": {},
"execution_count": 19
}
],
"metadata": {}
},
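{
"cell_type": "markdown",
"source": [
"Beyond the first conv, a minimal sketch to sweep every parameter pair; this assumes that both models enumerate their parameters in the same order, which holds here given how the two halves were assembled."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch: check that every gradient matches, not just the first conv.\n",
"all_match = all(\n",
"    torch.allclose(p_org.grad, p_new.grad)\n",
"    for p_org, p_new in zip(model_org.parameters(), model_new.parameters())\n",
")\n",
"print(all_match)  # expected: True"
],
"outputs": [],
"metadata": {}
},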
{
"cell_type": "markdown",
"source": [
"これで一致したことが確認できた.\n",
"\n",
"では事前学習済みResNetを色々と切り貼りして,ABNを作ってみる."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 20,
"source": [
"class ABNResNet50(nn.Module):\n",
" def __init__(self, num_classes=100, pretrained=True):\n",
" super().__init__()\n",
" model = torch.hub.load(\"chenyaofo/pytorch-cifar-models\", \n",
" \"cifar100_resnet56\", \n",
" pretrained=pretrained)\n",
"\n",
" self.resnet50_bottom = nn.Sequential(\n",
" model.conv1,\n",
" model.bn1,\n",
" model.relu,\n",
" model.layer1,\n",
" model.layer2,\n",
" model.layer3,\n",
" )\n",
" r50b_out_features = \\\n",
" list(self.resnet50_bottom.modules())[-1].num_features\n",
"\n",
" self.resnet50_top = nn.Sequential(\n",
" model.avgpool,\n",
" nn.Flatten(),\n",
" model.fc\n",
" )\n",
"\n",
" self.attention_branch1 = nn.Sequential(\n",
" # ここは入出力サイズが同じlayer3[1:8]を再利用してしまおう\n",
" # deepcopyしないと,上で使ったものと重みが共有されてしまうので注意\n",
" copy.deepcopy(model.layer3[1:8]),\n",
"\n",
" nn.BatchNorm2d(r50b_out_features),\n",
" nn.Conv2d(r50b_out_features, num_classes, kernel_size=1),\n",
" nn.ReLU(inplace=True)\n",
" )\n",
" self.attention_branch2 = nn.Sequential(\n",
" nn.Conv2d(num_classes, 1, kernel_size=1),\n",
" nn.BatchNorm2d(1),\n",
" nn.Sigmoid()\n",
" )\n",
" self.attention_branch3 = nn.Sequential(\n",
" nn.Conv2d(num_classes, num_classes, kernel_size=1),\n",
" )\n",
"\n",
" def get_attn(self):\n",
" return self.attn\n",
"\n",
" def forward(self, x):\n",
" x = self.resnet50_bottom(x)\n",
"\n",
" ax = self.attention_branch1(x)\n",
" attn = self.attention_branch2(ax)\n",
"\n",
" x = x * attn\n",
" x = self.resnet50_top(x)\n",
"\n",
" ax = self.attention_branch3(ax)\n",
" ax = F.avg_pool2d(ax, kernel_size=ax.shape[2]).squeeze() # GAP\n",
"\n",
" self.attn = attn\n",
" return x, ax, attn\n"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 21,
"source": [
"model = ABNResNet50(pretrained=True)\n",
"model.to(device)\n",
"model.train()\n",
"\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Using cache found in /home/tamaki/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master\n"
]
}
],
"metadata": {}
},
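{
"cell_type": "markdown",
"source": [
"A minimal forward-pass sketch: run a random batch through the ABN and confirm that the three outputs have the expected shapes."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch: shapes of the ABN outputs for a random batch.\n",
"# x: class logits, ax: attention-branch logits, attn: attention map.\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    x, ax, attn = model(torch.randn(64, 3, 32, 32).to(device))\n",
"model.train()  # restore train mode for the training loop below\n",
"print(x.shape)     # expected: torch.Size([64, 100])\n",
"print(ax.shape)    # expected: torch.Size([64, 100])\n",
"print(attn.shape)  # expected: torch.Size([64, 1, 8, 8])"
],
"outputs": [],
"metadata": {}
},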
{
"cell_type": "markdown",
"source": [
"ABNのパラメータを確認."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 22,
"source": [
"torchinfo.summary(\n",
" model,\n",
" (64, 3, 32, 32),\n",
" depth=2,\n",
" col_names=[\"input_size\",\n",
" \"output_size\",\n",
" \"kernel_size\"],\n",
" row_settings=(\"var_names\",)\n",
" )"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"========================================================================================================================\n",
"Layer (type (var_name)) Input Shape Output Shape Kernel Shape\n",
"========================================================================================================================\n",
"ABNResNet50 -- -- --\n",
"├─Sequential (resnet50_bottom) [64, 3, 32, 32] [64, 64, 8, 8] --\n",
"│ └─Conv2d (0) [64, 3, 32, 32] [64, 16, 32, 32] [3, 16, 3, 3]\n",
"│ └─BatchNorm2d (1) [64, 16, 32, 32] [64, 16, 32, 32] [16]\n",
"│ └─ReLU (2) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─Sequential (3) [64, 16, 32, 32] [64, 16, 32, 32] --\n",
"│ └─Sequential (4) [64, 16, 32, 32] [64, 32, 16, 16] --\n",
"│ └─Sequential (5) [64, 32, 16, 16] [64, 64, 8, 8] --\n",
"├─Sequential (attention_branch1) [64, 64, 8, 8] [64, 100, 8, 8] --\n",
"│ └─Sequential (0) [64, 64, 8, 8] [64, 64, 8, 8] --\n",
"│ └─BatchNorm2d (1) [64, 64, 8, 8] [64, 64, 8, 8] [64]\n",
"│ └─Conv2d (2) [64, 64, 8, 8] [64, 100, 8, 8] [64, 100, 1, 1]\n",
"│ └─ReLU (3) [64, 100, 8, 8] [64, 100, 8, 8] --\n",
"├─Sequential (attention_branch2) [64, 100, 8, 8] [64, 1, 8, 8] --\n",
"│ └─Conv2d (0) [64, 100, 8, 8] [64, 1, 8, 8] [100, 1, 1, 1]\n",
"│ └─BatchNorm2d (1) [64, 1, 8, 8] [64, 1, 8, 8] [1]\n",
"│ └─Sigmoid (2) [64, 1, 8, 8] [64, 1, 8, 8] --\n",
"├─Sequential (resnet50_top) [64, 64, 8, 8] [64, 100] --\n",
"│ └─AdaptiveAvgPool2d (0) [64, 64, 8, 8] [64, 64, 1, 1] --\n",
"│ └─Flatten (1) [64, 64, 1, 1] [64, 64] --\n",
"│ └─Linear (2) [64, 64] [64, 100] [64, 100]\n",
"├─Sequential (attention_branch3) [64, 100, 8, 8] [64, 100, 8, 8] --\n",
"│ └─Conv2d (0) [64, 100, 8, 8] [64, 100, 8, 8] [100, 100, 1, 1]\n",
"========================================================================================================================\n",
"Total params: 1,396,339\n",
"Trainable params: 1,396,339\n",
"Non-trainable params: 0\n",
"Total mult-adds (G): 10.23\n",
"========================================================================================================================\n",
"Input size (MB): 0.79\n",
"Forward/backward pass size (MB): 625.33\n",
"Params size (MB): 5.59\n",
"Estimated Total Size (MB): 631.70\n",
"========================================================================================================================"
]
},
"metadata": {},
"execution_count": 22
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"便利関数を作っておく."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 23,
"source": [
"class AverageMeter(object):\n",
" \"\"\"\n",
" Computes and stores the average and current value\n",
" Imported from https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L363-L380\n",
" https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59\n",
" \"\"\"\n",
" def __init__(self):\n",
" self.reset()\n",
"\n",
" def reset(self):\n",
" self.val = 0\n",
" self.avg = 0\n",
" self.sum = 0\n",
" self.count = 0\n",
"\n",
" def update(self, val, n=1):\n",
" if type(val) == torch.Tensor:\n",
" val = val.item()\n",
" self.val = val\n",
" self.sum += val * n\n",
" self.count += n\n",
" self.avg = self.sum / self.count\n",
"\n",
"\n",
"def accuracy(output, target, topk=(1,)):\n",
" \"\"\"\n",
" Computes the accuracy over the k top predictions for the specified values of k\n",
" https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L411\n",
" \"\"\"\n",
" with torch.no_grad():\n",
" maxk = max(topk)\n",
" batch_size = target.size(0)\n",
"\n",
" _, pred = output.topk(maxk, 1, True, True)\n",
" pred = pred.t()\n",
" correct = pred.eq(target.view(1, -1).expand_as(pred))\n",
"\n",
" res = []\n",
" for k in topk:\n",
" correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n",
" res.append(correct_k.mul_(100.0 / batch_size))\n",
" return res\n"
],
"outputs": [],
"metadata": {}
},
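{
"cell_type": "markdown",
"source": [
"A tiny usage sketch for accuracy(): on random 100-class logits, top-1 should land near 1% and top-5 near 5%."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch: accuracy() on random predictions as a sanity check.\n",
"logits = torch.randn(1000, 100)\n",
"labels = torch.randint(100, (1000,))\n",
"top1, top5 = accuracy(logits, labels, topk=(1, 5))\n",
"print(top1.item(), top5.item())  # roughly 1.0 and 5.0"
],
"outputs": [],
"metadata": {}
},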
{
"cell_type": "code",
"execution_count": 24,
"source": [
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"epoch_num = 10"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"では学習と評価."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 25,
"source": [
"def one_epoch_process(data_loader, is_train=True):\n",
" \"\"\"train or evaluation for one epoch\n",
"\n",
" Args:\n",
" data_loader (DataLoader): data loader of train or val set\n",
" model (nn.Module): CNN model\n",
" is_train (bool, optional): flag of train or val. Defaults to True.\n",
" \"\"\"\n",
" with tqdm(enumerate(data_loader),\n",
" total=len(data_loader),\n",
" leave=True) as pbar_loss:\n",
"\n",
" log_loss = AverageMeter()\n",
" log_top1 = AverageMeter()\n",
" log_top5 = AverageMeter()\n",
" correct = AverageMeter()\n",
" model.train()\n",
" \n",
" for batch_idx, (image, label) in pbar_loss:\n",
" pbar_loss.set_description(\"[{}]\".format('train' if is_train else 'val'))\n",
"\n",
" current_batch_size = image.size(0)\n",
"\n",
" image, label = image.to(device), label.to(device)\n",
"\n",
" y, ay, _ = model(image)\n",
"\n",
" loss = criterion(y, label)\n",
" loss_ay = criterion(ay, label)\n",
" loss_all = (loss + loss_ay) / 2\n",
" log_loss.update(loss_all, current_batch_size)\n",
"\n",
" if is_train:\n",
" optimizer.zero_grad()\n",
" loss_all.backward()\n",
" optimizer.step()\n",
" \n",
" acc1, acc5 = accuracy(y, label, topk=(1, 5))\n",
" log_top1.update(acc1, current_batch_size)\n",
" log_top5.update(acc5, current_batch_size)\n",
"\n",
" pbar_loss.set_postfix_str(\n",
" ' | loss={:6.04f} acc top1={:6.04f} top5={:6.04f}'\n",
" ' err top1={:6.04f} top5={:6.04f}'\n",
" ''.format(\n",
" log_loss.avg, \n",
" log_top1.avg,\n",
" log_top5.avg,\n",
" 100 - log_top1.avg,\n",
" 100 - log_top5.avg,\n",
" ))\n",
"\n",
"\n",
"def train():\n",
" one_epoch_process(train_loader, is_train=True)\n",
"\n",
"def val():\n",
" with torch.no_grad():\n",
" one_epoch_process(val_loader, is_train=False)\n",
"\n",
"\n",
"with tqdm(range(epoch_num)) as pbar_epoch:\n",
" for epoch in pbar_epoch:\n",
" pbar_epoch.set_description(\"[Epoch %d]\" % (epoch))\n",
"\n",
" train()\n",
"\n",
" val()\n"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"[train]: 100%|██████████| 782/782 [00:59<00:00, 13.10it/s, | loss=0.7773 acc top1=78.2380 top5=97.6020 err top1=21.7620 top5=2.3980]\n",
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.31it/s, | loss=0.5266 acc top1=82.5600 top5=99.2100 err top1=17.4400 top5=0.7900]\n",
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.99it/s, | loss=0.4180 acc top1=85.9420 top5=99.4820 err top1=14.0580 top5=0.5180]\n",
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.71it/s, | loss=0.4284 acc top1=85.2400 top5=99.4400 err top1=14.7600 top5=0.5600]\n",
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.82it/s, | loss=0.3465 acc top1=88.2740 top5=99.6420 err top1=11.7260 top5=0.3580]\n",
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.58it/s, | loss=0.4238 acc top1=85.8700 top5=99.4100 err top1=14.1300 top5=0.5900]\n",
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.98it/s, | loss=0.3014 acc top1=89.8680 top5=99.6780 err top1=10.1320 top5=0.3220]\n",
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.73it/s, | loss=0.3660 acc top1=87.6900 top5=99.6300 err top1=12.3100 top5=0.3700]\n",
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.99it/s, | loss=0.2664 acc top1=90.9200 top5=99.7340 err top1=9.0800 top5=0.2660]\n",
"[val]: 100%|██████████| 157/157 [00:05<00:00, 31.18it/s, | loss=0.3526 acc top1=88.3100 top5=99.6000 err top1=11.6900 top5=0.4000]\n",
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.88it/s, | loss=0.2407 acc top1=91.9240 top5=99.8040 err top1=8.0760 top5=0.1960]\n",
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.45it/s, | loss=0.3370 acc top1=89.0900 top5=99.6700 err top1=10.9100 top5=0.3300]\n",
"[train]: 100%|██████████| 782/782 [00:59<00:00, 13.20it/s, | loss=0.2209 acc top1=92.5220 top5=99.8440 err top1=7.4780 top5=0.1560]\n",
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.61it/s, | loss=0.2876 acc top1=90.2700 top5=99.7300 err top1=9.7300 top5=0.2700]\n",
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.88it/s, | loss=0.2077 acc top1=92.9760 top5=99.8880 err top1=7.0240 top5=0.1120]\n",
"[val]: 100%|██████████| 157/157 [00:05<00:00, 31.10it/s, | loss=0.3089 acc top1=90.0600 top5=99.7000 err top1=9.9400 top5=0.3000]\n",
"[Epoch 8]: 80%|████████ | 8/10 [08:45<02:11, 65.62s/it]"
]
}
],
"metadata": {}
},
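{
"cell_type": "markdown",
"source": [
"After training, the attention maps can be inspected. A minimal sketch, assuming matplotlib is available: upsample the 8x8 map returned by the model to the input resolution and show it next to the image."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch: visualize ABN attention maps on one validation batch.\n",
"import matplotlib.pyplot as plt  # assumed available\n",
"\n",
"model.eval()\n",
"images, _ = next(iter(val_loader))\n",
"with torch.no_grad():\n",
"    _, _, attn = model(images.to(device))\n",
"# upsample [64, 1, 8, 8] -> [64, 1, 32, 32] for display\n",
"attn = F.interpolate(attn, size=images.shape[-2:],\n",
"                     mode='bilinear', align_corners=False)\n",
"\n",
"fig, axes = plt.subplots(2, 8, figsize=(16, 4))\n",
"for i in range(8):\n",
"    img = images[i].permute(1, 2, 0)  # CHW -> HWC, still normalized\n",
"    axes[0, i].imshow((img - img.min()) / (img.max() - img.min()))\n",
"    axes[0, i].axis('off')\n",
"    axes[1, i].imshow(attn[i, 0].cpu(), cmap='jet')\n",
"    axes[1, i].axis('off')\n",
"plt.show()"
],
"outputs": [],
"metadata": {}
},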
{
"cell_type": "code",
"execution_count": null,
"source": [],
"outputs": [],
"metadata": {}
}
],
"metadata": {
"orig_nbformat": 4,
"language_info": {
"name": "python",
"version": "3.8.8",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.8.8 64-bit ('base': conda)"
},
"interpreter": {
"hash": "98b0a9b7b4eaaa670588a142fd0a9b87eaafe866f1db4228be72b4211d12040f"
}
},
"nbformat": 4,
"nbformat_minor": 2
}