Created
August 30, 2021 08:25
-
-
Save tttamaki/f839c9aa26b0174c85aa4359d8971a0c to your computer and use it in GitHub Desktop.
事前学習済みResNetを切り貼りしてABNを作ってみる
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"事前学習済みのResNet50をバラバラにして,ABNを作ってみる." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"source": [ | |
"# Use %pip instead of !pip so the packages are installed into the\n", | |
"# environment of the running kernel.\n", | |
"%pip install torchinfo\n", | |
"%pip install git+https://github.com/facebookresearch/fvcore.git" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Defaulting to user installation because normal site-packages is not writeable\n", | |
"Requirement already satisfied: torchinfo in /opt/conda/lib/python3.8/site-packages (1.5.2)\n", | |
"\u001b[33mWARNING: You are using pip version 21.1.3; however, version 21.2.4 is available.\n", | |
"You should consider upgrading via the '/opt/conda/bin/python -m pip install --upgrade pip' command.\u001b[0m\n", | |
"Defaulting to user installation because normal site-packages is not writeable\n", | |
"Collecting git+https://github.com/facebookresearch/fvcore.git\n", | |
" Cloning https://github.com/facebookresearch/fvcore.git to /tmp/pip-req-build-khia8ald\n", | |
" Running command git clone -q https://github.com/facebookresearch/fvcore.git /tmp/pip-req-build-khia8ald\n", | |
"Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (1.20.1)\n", | |
"Requirement already satisfied: yacs>=0.1.6 in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (0.1.8)\n", | |
"Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (5.4.1)\n", | |
"Requirement already satisfied: tqdm in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (4.53.0)\n", | |
"Requirement already satisfied: termcolor>=1.1 in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (1.1.0)\n", | |
"Requirement already satisfied: Pillow in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (8.3.1)\n", | |
"Requirement already satisfied: tabulate in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (0.8.9)\n", | |
"Requirement already satisfied: iopath>=0.1.7 in /opt/conda/lib/python3.8/site-packages (from fvcore==0.1.5) (0.1.9)\n", | |
"Requirement already satisfied: portalocker in /opt/conda/lib/python3.8/site-packages (from iopath>=0.1.7->fvcore==0.1.5) (2.3.0)\n", | |
"\u001b[33mWARNING: You are using pip version 21.1.3; however, version 21.2.4 is available.\n", | |
"You should consider upgrading via the '/opt/conda/bin/python -m pip install --upgrade pip' command.\u001b[0m\n" | |
] | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"import torchvision\n", | |
"from torchvision import transforms\n", | |
"\n", | |
"import torchinfo\n", | |
"import copy\n", | |
"from tqdm import tqdm\n", | |
"# from tqdm.notebook import tqdm" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"データローダの準備" | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"source": [ | |
"# NOTE(review): these look like the ImageNet channel statistics; confirm they\n", | |
"# match the preprocessing the pretrained CIFAR weights were trained with.\n", | |
"transform_train = transforms.Compose([\n", | |
"    transforms.RandomCrop(32, padding=1),\n", | |
"    transforms.RandomHorizontalFlip(),\n", | |
"    transforms.ToTensor(),\n", | |
"    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", | |
"                         std=[0.229, 0.224, 0.225]),\n", | |
"])\n", | |
"# Validation uses no augmentation, only tensor conversion and normalization.\n", | |
"transform_val = transforms.Compose([\n", | |
"    transforms.ToTensor(),\n", | |
"    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", | |
"                         std=[0.229, 0.224, 0.225]),\n", | |
"])\n", | |
"\n", | |
"num_classes = 10\n", | |
"batch_size = 64\n", | |
"num_workers = 2\n", | |
"\n", | |
"train_set = torchvision.datasets.CIFAR10(\n", | |
"    root='./data',\n", | |
"    train=True,\n", | |
"    download=True,\n", | |
"    transform=transform_train\n", | |
")\n", | |
"val_set = torchvision.datasets.CIFAR10(\n", | |
"    root='./data',\n", | |
"    train=False,\n", | |
"    download=True,\n", | |
"    transform=transform_val\n", | |
")\n", | |
"\n", | |
"train_loader = torch.utils.data.DataLoader(\n", | |
"    train_set,\n", | |
"    batch_size=batch_size,\n", | |
"    shuffle=True,\n", | |
"    num_workers=num_workers)\n", | |
"# Shuffling is only useful for training; keep the validation order fixed so\n", | |
"# that evaluation runs are repeatable.\n", | |
"val_loader = torch.utils.data.DataLoader(\n", | |
"    val_set,\n", | |
"    batch_size=batch_size,\n", | |
"    shuffle=False,\n", | |
"    num_workers=num_workers)\n" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Files already downloaded and verified\n", | |
"Files already downloaded and verified\n" | |
] | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"まずはoriginalのResNet.\n", | |
"どのモデルでもよいが,とりあえず\n", | |
"https://github.com/chenyaofo/pytorch-cifar-models\n", | |
"から,CIFAR用のpretrainモデル(ここではCIFAR-100で学習済みのResNet56)を取得." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"source": [ | |
"# Use the first GPU when CUDA is available, otherwise fall back to the CPU.\n", | |
"if torch.cuda.is_available():\n", | |
"    device = torch.device(\"cuda:0\")\n", | |
"else:\n", | |
"    device = torch.device(\"cpu\")\n", | |
"\n", | |
"# Load the pretrained network from torch.hub and move it to the chosen device.\n", | |
"model_org = torch.hub.load(\n", | |
"    \"chenyaofo/pytorch-cifar-models\",\n", | |
"    \"cifar100_resnet56\",\n", | |
"    pretrained=True,\n", | |
").to(device)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Using cache found in /home/tamaki/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master\n" | |
] | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"summaryを確認." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"source": [ | |
"# Layer-by-layer summary of the original network, two levels deep,\n", | |
"# for an input batch of shape (64, 3, 32, 32).\n", | |
"torchinfo.summary(\n", | |
"    model=model_org,\n", | |
"    input_size=(64, 3, 32, 32),\n", | |
"    depth=2,\n", | |
"    col_names=[\"input_size\", \"output_size\", \"kernel_size\"],\n", | |
"    row_settings=(\"var_names\",),\n", | |
")" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"===================================================================================================================\n", | |
"Layer (type (var_name)) Input Shape Output Shape Kernel Shape\n", | |
"===================================================================================================================\n", | |
"CifarResNet -- -- --\n", | |
"├─Conv2d (conv1) [64, 3, 32, 32] [64, 16, 32, 32] [3, 16, 3, 3]\n", | |
"├─BatchNorm2d (bn1) [64, 16, 32, 32] [64, 16, 32, 32] [16]\n", | |
"├─ReLU (relu) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"├─Sequential (layer1) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─BasicBlock (0) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─BasicBlock (1) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─BasicBlock (2) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─BasicBlock (3) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─BasicBlock (4) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─BasicBlock (5) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─BasicBlock (6) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─BasicBlock (7) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─BasicBlock (8) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"├─Sequential (layer2) [64, 16, 32, 32] [64, 32, 16, 16] --\n", | |
"│ └─BasicBlock (0) [64, 16, 32, 32] [64, 32, 16, 16] --\n", | |
"│ └─BasicBlock (1) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ └─BasicBlock (2) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ └─BasicBlock (3) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ └─BasicBlock (4) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ └─BasicBlock (5) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ └─BasicBlock (6) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ └─BasicBlock (7) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ └─BasicBlock (8) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"├─Sequential (layer3) [64, 32, 16, 16] [64, 64, 8, 8] --\n", | |
"│ └─BasicBlock (0) [64, 32, 16, 16] [64, 64, 8, 8] --\n", | |
"│ └─BasicBlock (1) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ └─BasicBlock (2) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ └─BasicBlock (3) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ └─BasicBlock (4) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ └─BasicBlock (5) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ └─BasicBlock (6) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ └─BasicBlock (7) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ └─BasicBlock (8) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"├─AdaptiveAvgPool2d (avgpool) [64, 64, 8, 8] [64, 64, 1, 1] --\n", | |
"├─Linear (fc) [64, 64] [64, 100] [64, 100]\n", | |
"===================================================================================================================\n", | |
"Total params: 861,620\n", | |
"Trainable params: 861,620\n", | |
"Non-trainable params: 0\n", | |
"Total mult-adds (G): 8.05\n", | |
"===================================================================================================================\n", | |
"Input size (MB): 0.79\n", | |
"Forward/backward pass size (MB): 557.89\n", | |
"Params size (MB): 3.45\n", | |
"Estimated Total Size (MB): 562.13\n", | |
"===================================================================================================================" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 5 | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"次は,事前学習済みResNetを,前半と後半に分けて,それらをつなげた再構成モデルを作り,originalと同じ出力が得られるかどうかを確認する." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"source": [ | |
"class ReconstructResNet50(nn.Module):\n", | |
"    \"\"\"Pretrained CIFAR ResNet re-assembled from two sequential stages.\n", | |
"\n", | |
"    ``resnet50_bottom_half`` holds conv1, bn1, relu and layer1-3 of the loaded\n", | |
"    network; ``resnet50_top_half`` holds avgpool, a flatten step and fc.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"        backbone = torch.hub.load(\n", | |
"            \"chenyaofo/pytorch-cifar-models\",\n", | |
"            \"cifar100_resnet56\",\n", | |
"            pretrained=True,\n", | |
"        )\n", | |
"\n", | |
"        # The layers are reused as-is, so their pretrained weights carry over.\n", | |
"        feature_layers = [\n", | |
"            backbone.conv1,\n", | |
"            backbone.bn1,\n", | |
"            backbone.relu,\n", | |
"            backbone.layer1,\n", | |
"            backbone.layer2,\n", | |
"            backbone.layer3,\n", | |
"        ]\n", | |
"        self.resnet50_bottom_half = nn.Sequential(*feature_layers)\n", | |
"\n", | |
"        # Flatten is inserted explicitly between pooling and the classifier.\n", | |
"        head_layers = [backbone.avgpool, nn.Flatten(), backbone.fc]\n", | |
"        self.resnet50_top_half = nn.Sequential(*head_layers)\n", | |
"\n", | |
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n", | |
"        \"\"\"Apply the bottom half and then the top half to ``x``.\"\"\"\n", | |
"        features = self.resnet50_bottom_half(x)\n", | |
"        return self.resnet50_top_half(features)" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"source": [ | |
"# Build the re-assembled model and move it to the selected device in one step.\n", | |
"model_new = ReconstructResNet50().to(device)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Using cache found in /home/tamaki/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master\n" | |
] | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"再構成したモデルの中身.depthが一つ深くなっているが,パラメータ数その他は同一であることが分かる." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"source": [ | |
"# Layer-by-layer summary of the re-assembled network, three levels deep,\n", | |
"# for an input batch of shape (64, 3, 32, 32).\n", | |
"torchinfo.summary(\n", | |
"    model=model_new,\n", | |
"    input_size=(64, 3, 32, 32),\n", | |
"    depth=3,\n", | |
"    col_names=[\"input_size\", \"output_size\", \"kernel_size\"],\n", | |
"    row_settings=(\"var_names\",),\n", | |
")" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"========================================================================================================================\n", | |
"Layer (type (var_name)) Input Shape Output Shape Kernel Shape\n", | |
"========================================================================================================================\n", | |
"ReconstructResNet50 -- -- --\n", | |
"├─Sequential (resnet50_bottom_half) [64, 3, 32, 32] [64, 64, 8, 8] --\n", | |
"│ └─Conv2d (0) [64, 3, 32, 32] [64, 16, 32, 32] [3, 16, 3, 3]\n", | |
"│ └─BatchNorm2d (1) [64, 16, 32, 32] [64, 16, 32, 32] [16]\n", | |
"│ └─ReLU (2) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─Sequential (3) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ │ └─BasicBlock (0) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ │ └─BasicBlock (1) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ │ └─BasicBlock (2) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ │ └─BasicBlock (3) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ │ └─BasicBlock (4) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ │ └─BasicBlock (5) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ │ └─BasicBlock (6) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ │ └─BasicBlock (7) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ │ └─BasicBlock (8) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─Sequential (4) [64, 16, 32, 32] [64, 32, 16, 16] --\n", | |
"│ │ └─BasicBlock (0) [64, 16, 32, 32] [64, 32, 16, 16] --\n", | |
"│ │ └─BasicBlock (1) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ │ └─BasicBlock (2) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ │ └─BasicBlock (3) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ │ └─BasicBlock (4) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ │ └─BasicBlock (5) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ │ └─BasicBlock (6) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ │ └─BasicBlock (7) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ │ └─BasicBlock (8) [64, 32, 16, 16] [64, 32, 16, 16] --\n", | |
"│ └─Sequential (5) [64, 32, 16, 16] [64, 64, 8, 8] --\n", | |
"│ │ └─BasicBlock (0) [64, 32, 16, 16] [64, 64, 8, 8] --\n", | |
"│ │ └─BasicBlock (1) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ │ └─BasicBlock (2) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ │ └─BasicBlock (3) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ │ └─BasicBlock (4) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ │ └─BasicBlock (5) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ │ └─BasicBlock (6) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ │ └─BasicBlock (7) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ │ └─BasicBlock (8) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"├─Sequential (resnet50_top_half) [64, 64, 8, 8] [64, 100] --\n", | |
"│ └─AdaptiveAvgPool2d (0) [64, 64, 8, 8] [64, 64, 1, 1] --\n", | |
"│ └─Flatten (1) [64, 64, 1, 1] [64, 64] --\n", | |
"│ └─Linear (2) [64, 64] [64, 100] [64, 100]\n", | |
"========================================================================================================================\n", | |
"Total params: 861,620\n", | |
"Trainable params: 861,620\n", | |
"Non-trainable params: 0\n", | |
"Total mult-adds (G): 8.05\n", | |
"========================================================================================================================\n", | |
"Input size (MB): 0.79\n", | |
"Forward/backward pass size (MB): 557.89\n", | |
"Params size (MB): 3.45\n", | |
"Estimated Total Size (MB): 562.13\n", | |
"========================================================================================================================" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 8 | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"ではデータを流し込んでみる.\n" | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"source": [ | |
"# Random input batch of shape (64, 3, 32, 32), moved to the selected device.\n", | |
"data = torch.randn(64, 3, 32, 32)\n", | |
"data = data.to(device)" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"source": [ | |
"model_new.eval()\n", | |
"# Inference-only check: one forward pass, without building the autograd graph.\n", | |
"with torch.no_grad():\n", | |
"    logits_new = model_new(data)\n", | |
"print(logits_new.max(axis=1))\n", | |
"print(logits_new[:10, :5])" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"torch.return_types.max(\n", | |
"values=tensor([10.9443, 12.4955, 13.0963, 10.5355, 10.8972, 11.2695, 10.1008, 11.6357,\n", | |
" 13.1638, 11.4222, 10.9982, 10.4561, 10.5532, 10.5805, 12.4815, 12.0393,\n", | |
" 13.4050, 10.7294, 10.4074, 11.0421, 12.0968, 11.0542, 10.8104, 13.8886,\n", | |
" 10.2925, 10.4051, 13.7182, 11.5399, 11.1024, 10.0979, 11.6706, 10.5574,\n", | |
" 12.0685, 13.5889, 14.2758, 12.7623, 11.5654, 9.5789, 11.1102, 10.3502,\n", | |
" 11.3670, 11.1035, 11.8346, 12.4422, 10.3109, 10.8246, 10.7729, 10.2088,\n", | |
" 11.5740, 11.5479, 11.4090, 12.2235, 10.6360, 14.7207, 10.1188, 10.0799,\n", | |
" 10.8706, 11.0431, 12.1793, 10.3579, 9.0835, 11.4152, 9.4510, 10.4705],\n", | |
" device='cuda:0', grad_fn=<MaxBackward0>),\n", | |
"indices=tensor([10, 82, 10, 57, 57, 57, 10, 57, 82, 83, 57, 57, 57, 10, 10, 10, 10, 57,\n", | |
" 82, 10, 10, 10, 82, 57, 10, 10, 83, 10, 57, 10, 10, 57, 10, 10, 10, 10,\n", | |
" 57, 57, 82, 10, 10, 57, 10, 57, 57, 57, 10, 82, 83, 9, 78, 10, 82, 10,\n", | |
" 10, 10, 57, 10, 57, 57, 10, 10, 82, 10], device='cuda:0'))\n", | |
"tensor([[ 6.0794e+00, -1.1938e-01, -1.2621e-01, -5.1893e-02, -4.3699e+00],\n", | |
" [ 6.5469e+00, -2.5111e-01, -6.4007e-01, -2.0096e+00, -3.5481e+00],\n", | |
" [ 5.7094e+00, 3.2534e-01, 5.0513e-01, 1.5728e-01, -4.6752e+00],\n", | |
" [ 6.8923e+00, 9.4589e-01, 1.2946e+00, 1.6836e+00, -2.7677e+00],\n", | |
" [ 6.2214e+00, 1.3083e+00, -1.7860e-01, 3.7956e-01, -4.4028e+00],\n", | |
" [ 5.7689e+00, 3.6209e-01, -1.4508e+00, 4.5521e-01, -4.2887e+00],\n", | |
" [ 5.1389e+00, -4.2951e-01, -9.8702e-01, 2.5621e-02, -3.7350e+00],\n", | |
" [ 5.1116e+00, 1.6931e+00, 6.2985e-04, -1.1464e+00, -4.3173e+00],\n", | |
" [ 5.7331e+00, -3.3522e-01, -1.2759e+00, -4.5809e-01, -3.7536e+00],\n", | |
" [ 4.2918e+00, 2.0628e+00, 7.2408e-02, 2.6386e-01, -4.4757e+00]],\n", | |
" device='cuda:0', grad_fn=<SliceBackward>)\n" | |
] | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"source": [ | |
"model_org.eval()\n", | |
"# Inference-only check: one forward pass, without building the autograd graph.\n", | |
"with torch.no_grad():\n", | |
"    logits_org = model_org(data)\n", | |
"print(logits_org.max(axis=1))\n", | |
"print(logits_org[:10, :5])" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"torch.return_types.max(\n", | |
"values=tensor([10.9443, 12.4955, 13.0963, 10.5355, 10.8972, 11.2695, 10.1008, 11.6357,\n", | |
" 13.1638, 11.4222, 10.9982, 10.4561, 10.5532, 10.5805, 12.4815, 12.0393,\n", | |
" 13.4050, 10.7294, 10.4074, 11.0421, 12.0968, 11.0542, 10.8104, 13.8886,\n", | |
" 10.2925, 10.4051, 13.7182, 11.5399, 11.1024, 10.0979, 11.6706, 10.5574,\n", | |
" 12.0685, 13.5889, 14.2758, 12.7623, 11.5654, 9.5789, 11.1102, 10.3502,\n", | |
" 11.3670, 11.1035, 11.8346, 12.4422, 10.3109, 10.8246, 10.7729, 10.2088,\n", | |
" 11.5740, 11.5479, 11.4090, 12.2235, 10.6360, 14.7207, 10.1188, 10.0799,\n", | |
" 10.8706, 11.0431, 12.1793, 10.3579, 9.0835, 11.4152, 9.4510, 10.4705],\n", | |
" device='cuda:0', grad_fn=<MaxBackward0>),\n", | |
"indices=tensor([10, 82, 10, 57, 57, 57, 10, 57, 82, 83, 57, 57, 57, 10, 10, 10, 10, 57,\n", | |
" 82, 10, 10, 10, 82, 57, 10, 10, 83, 10, 57, 10, 10, 57, 10, 10, 10, 10,\n", | |
" 57, 57, 82, 10, 10, 57, 10, 57, 57, 57, 10, 82, 83, 9, 78, 10, 82, 10,\n", | |
" 10, 10, 57, 10, 57, 57, 10, 10, 82, 10], device='cuda:0'))\n", | |
"tensor([[ 6.0794e+00, -1.1938e-01, -1.2621e-01, -5.1893e-02, -4.3699e+00],\n", | |
" [ 6.5469e+00, -2.5111e-01, -6.4007e-01, -2.0096e+00, -3.5481e+00],\n", | |
" [ 5.7094e+00, 3.2534e-01, 5.0513e-01, 1.5728e-01, -4.6752e+00],\n", | |
" [ 6.8923e+00, 9.4589e-01, 1.2946e+00, 1.6836e+00, -2.7677e+00],\n", | |
" [ 6.2214e+00, 1.3083e+00, -1.7860e-01, 3.7956e-01, -4.4028e+00],\n", | |
" [ 5.7689e+00, 3.6209e-01, -1.4508e+00, 4.5521e-01, -4.2887e+00],\n", | |
" [ 5.1389e+00, -4.2951e-01, -9.8702e-01, 2.5621e-02, -3.7350e+00],\n", | |
" [ 5.1116e+00, 1.6931e+00, 6.2985e-04, -1.1464e+00, -4.3173e+00],\n", | |
" [ 5.7331e+00, -3.3522e-01, -1.2759e+00, -4.5809e-01, -3.7536e+00],\n", | |
" [ 4.2918e+00, 2.0628e+00, 7.2408e-02, 2.6386e-01, -4.4757e+00]],\n", | |
" device='cuda:0', grad_fn=<SliceBackward>)\n" | |
] | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"どちらも同じ出力が得られたことが分かる.\n", | |
"\n", | |
"ではoptimizerを設定して,勾配も一致するかどうか確認する." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"source": [ | |
"# One shared loss, and one SGD optimizer per model with identical settings.\n", | |
"criterion = nn.CrossEntropyLoss()\n", | |
"optimizer_org = torch.optim.SGD(\n", | |
"    model_org.parameters(),\n", | |
"    lr=0.1,\n", | |
"    momentum=0.9,\n", | |
"    weight_decay=5e-4,\n", | |
")\n", | |
"optimizer_new = torch.optim.SGD(\n", | |
"    model_new.parameters(),\n", | |
"    lr=0.1,\n", | |
"    momentum=0.9,\n", | |
"    weight_decay=5e-4,\n", | |
")" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"source": [ | |
"# 64 random integer labels drawn from [0, num_classes), moved to the device.\n", | |
"target = torch.randint(0, num_classes, (64,))\n", | |
"target = target.to(device)" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"source": [ | |
"# Training-mode forward and backward pass of the original model.\n", | |
"optimizer_org.zero_grad()\n", | |
"model_org.train()\n", | |
"output_org = model_org(data)\n", | |
"loss_org = criterion(output_org, target)\n", | |
"loss_org.backward()\n", | |
"# Display the loss value as the cell output.\n", | |
"loss_org" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor(8.3755, device='cuda:0', grad_fn=<NllLossBackward>)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 14 | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"source": [ | |
"# Training-mode forward and backward pass of the re-assembled model.\n", | |
"optimizer_new.zero_grad()\n", | |
"model_new.train()\n", | |
"output_new = model_new(data)\n", | |
"loss_new = criterion(output_new, target)\n", | |
"loss_new.backward()\n", | |
"# Display the loss value as the cell output.\n", | |
"loss_new" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor(8.3755, device='cuda:0', grad_fn=<NllLossBackward>)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 15 | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"lossは一致した.\n", | |
"では重みとその勾配は一致するかどうかを確認する.対象は最初のconvに限定." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"source": [ | |
"# Slice [0, 0, :, :] of the weight of the first layer in resnet50_bottom_half.\n", | |
"model_new.resnet50_bottom_half[0].weight[0, 0, :, :]" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.0220, -0.1125, -0.0150],\n", | |
" [-0.1283, -0.2929, -0.1424],\n", | |
" [-0.0580, -0.1338, -0.0578]], device='cuda:0', grad_fn=<SliceBackward>)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 16 | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"source": [ | |
"# Same slice [0, 0, :, :] of conv1's weight in the original model.\n", | |
"model_org.conv1.weight[0, 0, :, :]" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.0220, -0.1125, -0.0150],\n", | |
" [-0.1283, -0.2929, -0.1424],\n", | |
" [-0.0580, -0.1338, -0.0578]], device='cuda:0', grad_fn=<SliceBackward>)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 17 | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"source": [ | |
"# Slice [0, 0, :, :] of the gradient of the first layer's weight (re-assembled model).\n", | |
"model_new.resnet50_bottom_half[0].weight.grad[0, 0, :, :]" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[-0.1526, 0.0275, -0.3145],\n", | |
" [-0.4756, 0.3217, 0.0965],\n", | |
" [-0.5709, -0.3611, -0.4316]], device='cuda:0')" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 18 | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"source": [ | |
"# Slice [0, 0, :, :] of the gradient of conv1's weight (original model).\n", | |
"model_org.conv1.weight.grad[0, 0, :, :]" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[-0.1526, 0.0275, -0.3145],\n", | |
" [-0.4756, 0.3217, 0.0965],\n", | |
" [-0.5709, -0.3611, -0.4316]], device='cuda:0')" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 19 | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"これで一致したことが確認できた.\n", | |
"\n", | |
"では事前学習済みResNetを色々と切り貼りして,ABNを作ってみる." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"source": [ | |
"class ABNResNet50(nn.Module):\n", | |
"    \"\"\"Attention Branch Network assembled from a pretrained CIFAR ResNet.\n", | |
"\n", | |
"    The backbone (``cifar100_resnet56`` from chenyaofo/pytorch-cifar-models)\n", | |
"    is split into a feature extractor (``resnet50_bottom``) and a classifier\n", | |
"    head (``resnet50_top``). An attention branch computes a one-channel\n", | |
"    attention map that rescales the extracted features before the head, and\n", | |
"    also produces its own class scores.\n", | |
"\n", | |
"    Args:\n", | |
"        num_classes (int, optional): number of classes predicted by both\n", | |
"            heads. If it differs from the output size of the loaded\n", | |
"            classifier, the classifier is replaced by a newly initialized\n", | |
"            ``nn.Linear``. Defaults to 100.\n", | |
"        pretrained (bool, optional): forwarded to ``torch.hub.load``.\n", | |
"            Defaults to True.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, num_classes=100, pretrained=True):\n", | |
"        super().__init__()\n", | |
"        model = torch.hub.load(\"chenyaofo/pytorch-cifar-models\",\n", | |
"                               \"cifar100_resnet56\",\n", | |
"                               pretrained=pretrained)\n", | |
"\n", | |
"        self.resnet50_bottom = nn.Sequential(\n", | |
"            model.conv1,\n", | |
"            model.bn1,\n", | |
"            model.relu,\n", | |
"            model.layer1,\n", | |
"            model.layer2,\n", | |
"            model.layer3,\n", | |
"        )\n", | |
"        # Channel count leaving the feature extractor. The classifier consumes\n", | |
"        # the pooled and flattened features, so this equals fc.in_features.\n", | |
"        r50b_out_features = model.fc.in_features\n", | |
"\n", | |
"        # The loaded classifier has a fixed output size; replace it when a\n", | |
"        # different number of classes is requested so that both heads agree.\n", | |
"        fc = model.fc\n", | |
"        if fc.out_features != num_classes:\n", | |
"            fc = nn.Linear(fc.in_features, num_classes)\n", | |
"\n", | |
"        self.resnet50_top = nn.Sequential(\n", | |
"            model.avgpool,\n", | |
"            nn.Flatten(),\n", | |
"            fc\n", | |
"        )\n", | |
"\n", | |
"        self.attention_branch1 = nn.Sequential(\n", | |
"            # Reuse layer3[1:8], whose input and output sizes are the same.\n", | |
"            # deepcopy is required; otherwise the weights would be shared with\n", | |
"            # the blocks already used in resnet50_bottom.\n", | |
"            copy.deepcopy(model.layer3[1:8]),\n", | |
"\n", | |
"            nn.BatchNorm2d(r50b_out_features),\n", | |
"            nn.Conv2d(r50b_out_features, num_classes, kernel_size=1),\n", | |
"            nn.ReLU(inplace=True)\n", | |
"        )\n", | |
"        self.attention_branch2 = nn.Sequential(\n", | |
"            nn.Conv2d(num_classes, 1, kernel_size=1),\n", | |
"            nn.BatchNorm2d(1),\n", | |
"            nn.Sigmoid()\n", | |
"        )\n", | |
"        self.attention_branch3 = nn.Sequential(\n", | |
"            nn.Conv2d(num_classes, num_classes, kernel_size=1),\n", | |
"        )\n", | |
"\n", | |
"        # Most recent attention map; None until forward() has been called.\n", | |
"        self.attn = None\n", | |
"\n", | |
"    def get_attn(self):\n", | |
"        \"\"\"Return the attention map stored by the most recent forward pass.\"\"\"\n", | |
"        return self.attn\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Compute both class-score outputs and the attention map.\n", | |
"\n", | |
"        Returns:\n", | |
"            tuple: ``(x, ax, attn)`` where ``x`` is the output of\n", | |
"            ``resnet50_top`` on the attention-weighted features, ``ax`` is\n", | |
"            the globally average-pooled output of the attention branch with\n", | |
"            shape (batch, channels), and ``attn`` is the attention map\n", | |
"            produced by ``attention_branch2``.\n", | |
"        \"\"\"\n", | |
"        x = self.resnet50_bottom(x)\n", | |
"\n", | |
"        ax = self.attention_branch1(x)\n", | |
"        attn = self.attention_branch2(ax)\n", | |
"\n", | |
"        # attn has a single channel and is broadcast over the feature channels.\n", | |
"        x = x * attn\n", | |
"        x = self.resnet50_top(x)\n", | |
"\n", | |
"        ax = self.attention_branch3(ax)\n", | |
"        # Global average pooling. flatten(1) keeps the batch dimension even\n", | |
"        # for a batch of one, which a bare squeeze() would drop.\n", | |
"        ax = F.adaptive_avg_pool2d(ax, 1).flatten(1)\n", | |
"\n", | |
"        self.attn = attn\n", | |
"        return x, ax, attn\n" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"source": [ | |
"# Build the ABN on the selected device and put it in training mode.\n", | |
"model = ABNResNet50(pretrained=True).to(device)\n", | |
"model.train()\n", | |
"\n", | |
"# Plain SGD with momentum over every parameter of the ABN.\n", | |
"optimizer = torch.optim.SGD(\n", | |
"    model.parameters(),\n", | |
"    lr=0.01,\n", | |
"    momentum=0.9,\n", | |
")\n" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Using cache found in /home/tamaki/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master\n" | |
] | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"source": [], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"ABNのパラメータを確認." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"source": [ | |
"# Layer-by-layer summary of the ABN, two levels deep,\n", | |
"# for an input batch of shape (64, 3, 32, 32).\n", | |
"torchinfo.summary(\n", | |
"    model=model,\n", | |
"    input_size=(64, 3, 32, 32),\n", | |
"    depth=2,\n", | |
"    col_names=[\"input_size\", \"output_size\", \"kernel_size\"],\n", | |
"    row_settings=(\"var_names\",),\n", | |
")" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"========================================================================================================================\n", | |
"Layer (type (var_name)) Input Shape Output Shape Kernel Shape\n", | |
"========================================================================================================================\n", | |
"ABNResNet50 -- -- --\n", | |
"├─Sequential (resnet50_bottom) [64, 3, 32, 32] [64, 64, 8, 8] --\n", | |
"│ └─Conv2d (0) [64, 3, 32, 32] [64, 16, 32, 32] [3, 16, 3, 3]\n", | |
"│ └─BatchNorm2d (1) [64, 16, 32, 32] [64, 16, 32, 32] [16]\n", | |
"│ └─ReLU (2) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─Sequential (3) [64, 16, 32, 32] [64, 16, 32, 32] --\n", | |
"│ └─Sequential (4) [64, 16, 32, 32] [64, 32, 16, 16] --\n", | |
"│ └─Sequential (5) [64, 32, 16, 16] [64, 64, 8, 8] --\n", | |
"├─Sequential (attention_branch1) [64, 64, 8, 8] [64, 100, 8, 8] --\n", | |
"│ └─Sequential (0) [64, 64, 8, 8] [64, 64, 8, 8] --\n", | |
"│ └─BatchNorm2d (1) [64, 64, 8, 8] [64, 64, 8, 8] [64]\n", | |
"│ └─Conv2d (2) [64, 64, 8, 8] [64, 100, 8, 8] [64, 100, 1, 1]\n", | |
"│ └─ReLU (3) [64, 100, 8, 8] [64, 100, 8, 8] --\n", | |
"├─Sequential (attention_branch2) [64, 100, 8, 8] [64, 1, 8, 8] --\n", | |
"│ └─Conv2d (0) [64, 100, 8, 8] [64, 1, 8, 8] [100, 1, 1, 1]\n", | |
"│ └─BatchNorm2d (1) [64, 1, 8, 8] [64, 1, 8, 8] [1]\n", | |
"│ └─Sigmoid (2) [64, 1, 8, 8] [64, 1, 8, 8] --\n", | |
"├─Sequential (resnet50_top) [64, 64, 8, 8] [64, 100] --\n", | |
"│ └─AdaptiveAvgPool2d (0) [64, 64, 8, 8] [64, 64, 1, 1] --\n", | |
"│ └─Flatten (1) [64, 64, 1, 1] [64, 64] --\n", | |
"│ └─Linear (2) [64, 64] [64, 100] [64, 100]\n", | |
"├─Sequential (attention_branch3) [64, 100, 8, 8] [64, 100, 8, 8] --\n", | |
"│ └─Conv2d (0) [64, 100, 8, 8] [64, 100, 8, 8] [100, 100, 1, 1]\n", | |
"========================================================================================================================\n", | |
"Total params: 1,396,339\n", | |
"Trainable params: 1,396,339\n", | |
"Non-trainable params: 0\n", | |
"Total mult-adds (G): 10.23\n", | |
"========================================================================================================================\n", | |
"Input size (MB): 0.79\n", | |
"Forward/backward pass size (MB): 625.33\n", | |
"Params size (MB): 5.59\n", | |
"Estimated Total Size (MB): 631.70\n", | |
"========================================================================================================================" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 22 | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"便利関数を作っておく." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"source": [ | |
"class AverageMeter(object):\n", | |
"    \"\"\"\n", | |
"    Computes and stores the average and current value\n", | |
"    Imported from https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L363-L380\n", | |
"    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59\n", | |
"    \"\"\"\n", | |
"    def __init__(self):\n", | |
"        self.reset()\n", | |
"\n", | |
"    def reset(self):\n", | |
"        \"\"\"Reset every statistic to zero.\"\"\"\n", | |
"        self.val = 0    # most recently recorded value\n", | |
"        self.avg = 0    # running weighted average (sum / count)\n", | |
"        self.sum = 0    # weighted sum of the recorded values\n", | |
"        self.count = 0  # total weight recorded so far\n", | |
"\n", | |
"    def update(self, val, n=1):\n", | |
"        \"\"\"Record ``val`` with weight ``n`` and refresh the running average.\n", | |
"\n", | |
"        Args:\n", | |
"            val: new value; tensors are converted to a Python scalar.\n", | |
"            n (int, optional): weight of ``val``, e.g. the batch size it was\n", | |
"                averaged over. Defaults to 1.\n", | |
"        \"\"\"\n", | |
"        # isinstance also covers Tensor subclasses such as nn.Parameter,\n", | |
"        # which an exact type comparison would let through unconverted.\n", | |
"        if isinstance(val, torch.Tensor):\n", | |
"            val = val.item()\n", | |
"        self.val = val\n", | |
"        self.sum += val * n\n", | |
"        self.count += n\n", | |
"        self.avg = self.sum / self.count\n", | |
"\n", | |
"\n", | |
"def accuracy(output, target, topk=(1,)):\n", | |
"    \"\"\"\n", | |
"    Top-k accuracy, in percent, of ``output`` against ``target`` for each k in ``topk``\n", | |
"    https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L411\n", | |
"\n", | |
"    Returns a list with one single-element tensor per entry of ``topk``.\n", | |
"    \"\"\"\n", | |
"    with torch.no_grad():\n", | |
"        num_samples = target.size(0)\n", | |
"\n", | |
"        # Indices of the max(topk) highest scores per sample, transposed so\n", | |
"        # that row r holds every sample's rank-r prediction.\n", | |
"        ranked = output.topk(max(topk), 1, True, True).indices.t()\n", | |
"        hits = ranked.eq(target.view(1, -1).expand_as(ranked))\n", | |
"\n", | |
"        # A sample counts for top-k when its label is among the first k rows.\n", | |
"        return [\n", | |
"            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)\n", | |
"            for k in topk\n", | |
"        ]\n" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"source": [ | |
"# Loss function applied to the model outputs in the training/evaluation loop.\n", | |
"criterion = nn.CrossEntropyLoss()\n", | |
"\n", | |
"# Number of iterations of the epoch loop.\n", | |
"epoch_num = 10" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"では学習と評価." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"source": [ | |
"def one_epoch_process(data_loader, is_train=True):\n", | |
"    \"\"\"train or evaluation for one epoch\n", | |
"\n", | |
"    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and\n", | |
"    ``device``. The logged loss is the mean of the criterion on the first and\n", | |
"    on the second model output; accuracy is computed on the first output.\n", | |
"    Running averages are shown on the progress bar.\n", | |
"\n", | |
"    Args:\n", | |
"        data_loader (DataLoader): data loader of train or val set\n", | |
"        is_train (bool, optional): flag of train or val. When False the model\n", | |
"            is switched to eval mode and no optimizer step is taken.\n", | |
"            Defaults to True.\n", | |
"    \"\"\"\n", | |
"    # Select train/eval mode to match the pass, so that evaluation does not\n", | |
"    # run the layers (e.g. BatchNorm) in training mode.\n", | |
"    model.train(is_train)\n", | |
"\n", | |
"    log_loss = AverageMeter()\n", | |
"    log_top1 = AverageMeter()\n", | |
"    log_top5 = AverageMeter()\n", | |
"\n", | |
"    with tqdm(enumerate(data_loader),\n", | |
"              total=len(data_loader),\n", | |
"              leave=True) as pbar_loss:\n", | |
"        # The description is the same for every batch, so set it once.\n", | |
"        pbar_loss.set_description(\"[{}]\".format('train' if is_train else 'val'))\n", | |
"\n", | |
"        for batch_idx, (image, label) in pbar_loss:\n", | |
"            # The last batch may be smaller, so weight the averages by its size.\n", | |
"            current_batch_size = image.size(0)\n", | |
"\n", | |
"            image, label = image.to(device), label.to(device)\n", | |
"\n", | |
"            y, ay, _ = model(image)\n", | |
"\n", | |
"            loss = criterion(y, label)\n", | |
"            loss_ay = criterion(ay, label)\n", | |
"            loss_all = (loss + loss_ay) / 2\n", | |
"            log_loss.update(loss_all, current_batch_size)\n", | |
"\n", | |
"            if is_train:\n", | |
"                optimizer.zero_grad()\n", | |
"                loss_all.backward()\n", | |
"                optimizer.step()\n", | |
"\n", | |
"            acc1, acc5 = accuracy(y, label, topk=(1, 5))\n", | |
"            log_top1.update(acc1, current_batch_size)\n", | |
"            log_top5.update(acc5, current_batch_size)\n", | |
"\n", | |
"            pbar_loss.set_postfix_str(\n", | |
"                ' | loss={:6.04f} acc top1={:6.04f} top5={:6.04f}'\n", | |
"                ' err top1={:6.04f} top5={:6.04f}'\n", | |
"                ''.format(\n", | |
"                    log_loss.avg,\n", | |
"                    log_top1.avg,\n", | |
"                    log_top5.avg,\n", | |
"                    100 - log_top1.avg,\n", | |
"                    100 - log_top5.avg,\n", | |
"                ))\n", | |
"\n", | |
"\n", | |
"def train():\n", | |
" one_epoch_process(train_loader, is_train=True)\n", | |
"\n", | |
"def val():\n", | |
" with torch.no_grad():\n", | |
" one_epoch_process(val_loader, is_train=False)\n", | |
"\n", | |
"\n", | |
"with tqdm(range(epoch_num)) as pbar_epoch:\n", | |
" for epoch in pbar_epoch:\n", | |
" pbar_epoch.set_description(\"[Epoch %d]\" % (epoch))\n", | |
"\n", | |
" train()\n", | |
"\n", | |
" val()\n" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"[train]: 100%|██████████| 782/782 [00:59<00:00, 13.10it/s, | loss=0.7773 acc top1=78.2380 top5=97.6020 err top1=21.7620 top5=2.3980]\n", | |
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.31it/s, | loss=0.5266 acc top1=82.5600 top5=99.2100 err top1=17.4400 top5=0.7900]\n", | |
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.99it/s, | loss=0.4180 acc top1=85.9420 top5=99.4820 err top1=14.0580 top5=0.5180]\n", | |
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.71it/s, | loss=0.4284 acc top1=85.2400 top5=99.4400 err top1=14.7600 top5=0.5600]\n", | |
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.82it/s, | loss=0.3465 acc top1=88.2740 top5=99.6420 err top1=11.7260 top5=0.3580]\n", | |
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.58it/s, | loss=0.4238 acc top1=85.8700 top5=99.4100 err top1=14.1300 top5=0.5900]\n", | |
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.98it/s, | loss=0.3014 acc top1=89.8680 top5=99.6780 err top1=10.1320 top5=0.3220]\n", | |
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.73it/s, | loss=0.3660 acc top1=87.6900 top5=99.6300 err top1=12.3100 top5=0.3700]\n", | |
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.99it/s, | loss=0.2664 acc top1=90.9200 top5=99.7340 err top1=9.0800 top5=0.2660]\n", | |
"[val]: 100%|██████████| 157/157 [00:05<00:00, 31.18it/s, | loss=0.3526 acc top1=88.3100 top5=99.6000 err top1=11.6900 top5=0.4000]\n", | |
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.88it/s, | loss=0.2407 acc top1=91.9240 top5=99.8040 err top1=8.0760 top5=0.1960]\n", | |
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.45it/s, | loss=0.3370 acc top1=89.0900 top5=99.6700 err top1=10.9100 top5=0.3300]\n", | |
"[train]: 100%|██████████| 782/782 [00:59<00:00, 13.20it/s, | loss=0.2209 acc top1=92.5220 top5=99.8440 err top1=7.4780 top5=0.1560]\n", | |
"[val]: 100%|██████████| 157/157 [00:05<00:00, 30.61it/s, | loss=0.2876 acc top1=90.2700 top5=99.7300 err top1=9.7300 top5=0.2700]\n", | |
"[train]: 100%|██████████| 782/782 [01:00<00:00, 12.88it/s, | loss=0.2077 acc top1=92.9760 top5=99.8880 err top1=7.0240 top5=0.1120]\n", | |
"[val]: 100%|██████████| 157/157 [00:05<00:00, 31.10it/s, | loss=0.3089 acc top1=90.0600 top5=99.7000 err top1=9.9400 top5=0.3000]\n", | |
"[Epoch 8]: 80%|████████ | 8/10 [08:45<02:11, 65.62s/it]" | |
] | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"source": [], | |
"outputs": [], | |
"metadata": {} | |
} | |
], | |
"metadata": { | |
"orig_nbformat": 4, | |
"language_info": { | |
"name": "python", | |
"version": "3.8.8", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3.8.8 64-bit ('base': conda)" | |
}, | |
"interpreter": { | |
"hash": "98b0a9b7b4eaaa670588a142fd0a9b87eaafe866f1db4228be72b4211d12040f" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment