Created: August 15, 2020 14:44
-
-
Save kdoodoo/9b1383e8c42ec1248bec5799facf5e40 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"from torch import nn\n", | |
"from torch import optim\n", | |
"import torch.nn.functional as F\n", | |
"from torchvision import datasets, transforms, models\n", | |
"from torchvision.models import resnet18,resnet34\n", | |
"from torch.utils.data import DataLoader\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Dataset locations (relative to the notebook's working directory).
train_data_dir = 'img/AffectNet/train_CNN'
test_data_dir = 'img/AffectNet/test_CNN'

# Preprocessing pipeline: resize every image to a fixed 128x128, convert to
# a float tensor, and normalize with the standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Folder-per-class image datasets; class labels come from subdirectory names.
train_data = datasets.ImageFolder(root=train_data_dir, transform=transform)
test_data = datasets.ImageFolder(root=test_data_dir, transform=transform)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Batch the datasets; only the training stream is shuffled.
_loader_kwargs = dict(batch_size=128, num_workers=16)
train_dataloader = DataLoader(dataset=train_data, shuffle=True, **_loader_kwargs)
test_dataloader = DataLoader(dataset=test_data, shuffle=False, **_loader_kwargs)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"['A', 'B', 'C', 'D', 'E', 'F', 'G']\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"7" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Inspect the label set discovered by ImageFolder and record how many there are.
class_names = train_dataloader.dataset.classes
print(class_names)
num_class = len(class_names)
num_class
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Image batch dimensions: torch.Size([128, 3, 128, 128])\n", | |
"Image label dimensions: tensor([0, 2, 1, 4, 3, 6, 5, 4, 6, 1, 0, 5, 6, 3, 3, 1, 0, 0, 1, 3, 0, 1, 0, 2,\n", | |
" 0, 3, 1, 5, 5, 6, 6, 5, 2, 5, 5, 6, 4, 6, 5, 3, 0, 3, 1, 5, 4, 3, 6, 3,\n", | |
" 2, 3, 1, 1, 4, 3, 3, 6, 1, 5, 6, 1, 2, 6, 4, 0, 1, 0, 1, 4, 2, 5, 2, 1,\n", | |
" 6, 1, 3, 5, 1, 5, 6, 2, 5, 3, 5, 3, 3, 4, 0, 1, 6, 3, 6, 0, 5, 0, 0, 6,\n", | |
" 4, 1, 2, 6, 5, 2, 5, 0, 1, 4, 3, 2, 6, 4, 1, 4, 4, 0, 6, 6, 0, 0, 4, 5,\n", | |
" 5, 6, 3, 6, 1, 5, 0, 5])\n" | |
] | |
} | |
], | |
"source": [ | |
# Sanity-check a single batch: (N, C, H, W) image tensor and integer labels.
images, labels = next(iter(train_dataloader))
print('Image batch dimensions:', images.shape)
print('Image label dimensions:', labels)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Dataset ImageFolder\n", | |
" Number of datapoints: 26250\n", | |
" Root location: img/AffectNet/train_CNN\n", | |
" StandardTransform\n", | |
"Transform: Compose(\n", | |
" Resize(size=(128, 128), interpolation=PIL.Image.BILINEAR)\n", | |
" ToTensor()\n", | |
" Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", | |
" )" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
train_dataloader.dataset  # rich repr: datapoint count, root path, transforms
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Dataset ImageFolder\n", | |
" Number of datapoints: 3500\n", | |
" Root location: img/AffectNet/test_CNN\n", | |
" StandardTransform\n", | |
"Transform: Compose(\n", | |
" Resize(size=(128, 128), interpolation=PIL.Image.BILINEAR)\n", | |
" ToTensor()\n", | |
" Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", | |
" )" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
test_dataloader.dataset  # rich repr of the held-out test set
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"ResNet(\n", | |
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", | |
" (layer1): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer2): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer3): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer4): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", | |
" (fc): Linear(in_features=512, out_features=1000, bias=True)\n", | |
")" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Select the compute device. BUG FIX: the original only assigned `device`
# when CUDA was available, so `model.to(device)` in a later cell raised
# NameError on CPU-only machines; always fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build a ResNet-18 backbone with randomly initialized weights.
# NOTE(review): a later cell freezes ALL backbone parameters. Freezing a
# randomly initialized backbone means the convolutional features are never
# trained, which caps accuracy near chance — if the backbone is to stay
# frozen, pretrained=True is almost certainly intended. TODO confirm.
model = models.resnet18(pretrained=False, progress=True)
model  # display the architecture
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"7" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
num_class  # display the class count computed earlier (7 per the outputs above)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Superseded draft: the head-replacement cell below defines the criterion
# and optimizer and moves the model to the device; kept only as history.
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
# model.to(device)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"ResNet(\n", | |
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", | |
" (layer1): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer2): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer3): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer4): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", | |
" (fc): Sequential(\n", | |
" (0): Linear(in_features=512, out_features=512, bias=True)\n", | |
" (1): ReLU()\n", | |
" (2): Linear(in_features=512, out_features=7, bias=True)\n", | |
" (3): LogSoftmax()\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Freeze the backbone so only the new classification head is trained.
for param in model.parameters():
    param.requires_grad = False

# Replace the 1000-way ImageNet head with a num_class-way head that emits
# log-probabilities.
model.fc = nn.Sequential(
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, num_class),
    nn.LogSoftmax(dim=1),
)

# BUG FIX: the head already ends in LogSoftmax, so CrossEntropyLoss (which
# applies log-softmax internally) double-applied it and distorted the loss
# surface. NLLLoss is the correct companion to a LogSoftmax output layer.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
model.to(device)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"device(type='cuda', index=0)" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"device" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Training-length configuration.
import time  # available for ad-hoc timing of epochs

num_epochs = 100  # total passes over the training set
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Draft evaluation helpers, left commented out by the author; kept as-is
# for history. NOTE(review): the drafts mix two forward-pass conventions
# (`logits, probas = model(features)` vs `predicted = model(features)`);
# only the single-output form matches this notebook's model.

# def compute_accuracy(model, data_loader):
#     model.eval()
#     correct_pred, num_examples = 0, 0
#     for i, (features, targets) in enumerate(data_loader):
#         features = features.to(device)
#         targets = targets.to(device)
#         # logits, probas = model(features)
#         predicted = model(features)
#         _, predicted_labels = torch.max(predicted, 1)
#         num_examples += targets.size(0)
#         correct_pred += (predicted_labels == targets).sum()
#     return correct_pred.float()/num_examples * 100

# def compute_epoch_loss(model, data_loader):
#     model.eval()
#     curr_loss, num_examples = 0., 0
#     with torch.no_grad():
#         for features, targets in data_loader:
#             features = features.to(device)
#             targets = targets.to(device)
#             logits, probas = model(features)
#             predicted = model(features)
#             loss = F.cross_entropy(logits, targets, reduction='sum')
#             loss = criterion(predicted, targets)
#             num_examples += targets.size(0)
#             curr_loss += loss
#         curr_loss = curr_loss / num_examples
#     return curr_loss
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [1/100], Step [100/206], Loss: 1.9217\n", | |
"Epoch [1/100], Step [200/206], Loss: 1.9400\n", | |
"Epoch [2/100], Step [100/206], Loss: 1.9089\n", | |
"Epoch [2/100], Step [200/206], Loss: 1.9466\n", | |
"Epoch [3/100], Step [100/206], Loss: 1.9415\n", | |
"Epoch [3/100], Step [200/206], Loss: 1.9512\n", | |
"Epoch [4/100], Step [100/206], Loss: 1.9334\n", | |
"Epoch [4/100], Step [200/206], Loss: 1.8898\n", | |
"Epoch [5/100], Step [100/206], Loss: 1.9213\n", | |
"Epoch [5/100], Step [200/206], Loss: 1.8805\n", | |
"Epoch [6/100], Step [100/206], Loss: 1.9121\n", | |
"Epoch [6/100], Step [200/206], Loss: 1.9151\n", | |
"Epoch [7/100], Step [100/206], Loss: 1.9224\n", | |
"Epoch [7/100], Step [200/206], Loss: 1.8909\n", | |
"Epoch [8/100], Step [100/206], Loss: 1.8988\n", | |
"Epoch [8/100], Step [200/206], Loss: 1.9335\n", | |
"Epoch [9/100], Step [100/206], Loss: 1.8870\n", | |
"Epoch [9/100], Step [200/206], Loss: 1.8923\n", | |
"Epoch [10/100], Step [100/206], Loss: 1.9120\n", | |
"Epoch [10/100], Step [200/206], Loss: 1.8732\n", | |
"Epoch [11/100], Step [100/206], Loss: 1.8551\n", | |
"Epoch [11/100], Step [200/206], Loss: 1.9099\n", | |
"Epoch [12/100], Step [100/206], Loss: 1.8975\n", | |
"Epoch [12/100], Step [200/206], Loss: 1.8868\n", | |
"Epoch [13/100], Step [100/206], Loss: 1.9112\n", | |
"Epoch [13/100], Step [200/206], Loss: 1.8963\n", | |
"Epoch [14/100], Step [100/206], Loss: 1.8264\n", | |
"Epoch [14/100], Step [200/206], Loss: 1.8338\n", | |
"Epoch [15/100], Step [100/206], Loss: 1.8355\n", | |
"Epoch [15/100], Step [200/206], Loss: 1.9304\n", | |
"Epoch [16/100], Step [100/206], Loss: 1.9264\n", | |
"Epoch [16/100], Step [200/206], Loss: 1.8295\n", | |
"Epoch [17/100], Step [100/206], Loss: 1.8843\n", | |
"Epoch [17/100], Step [200/206], Loss: 1.8461\n", | |
"Epoch [18/100], Step [100/206], Loss: 1.8738\n", | |
"Epoch [18/100], Step [200/206], Loss: 1.8952\n", | |
"Epoch [19/100], Step [100/206], Loss: 1.8616\n", | |
"Epoch [19/100], Step [200/206], Loss: 1.8769\n", | |
"Epoch [20/100], Step [100/206], Loss: 1.8121\n", | |
"Epoch [20/100], Step [200/206], Loss: 1.8620\n", | |
"Epoch [21/100], Step [100/206], Loss: 1.8006\n", | |
"Epoch [21/100], Step [200/206], Loss: 1.8902\n", | |
"Epoch [22/100], Step [100/206], Loss: 1.9003\n", | |
"Epoch [22/100], Step [200/206], Loss: 1.8980\n", | |
"Epoch [23/100], Step [100/206], Loss: 1.8741\n", | |
"Epoch [23/100], Step [200/206], Loss: 1.8395\n", | |
"Epoch [24/100], Step [100/206], Loss: 1.8697\n", | |
"Epoch [24/100], Step [200/206], Loss: 1.8561\n", | |
"Epoch [25/100], Step [100/206], Loss: 1.8814\n", | |
"Epoch [25/100], Step [200/206], Loss: 1.8965\n", | |
"Epoch [26/100], Step [100/206], Loss: 1.8482\n", | |
"Epoch [26/100], Step [200/206], Loss: 1.8450\n", | |
"Epoch [27/100], Step [100/206], Loss: 1.9332\n", | |
"Epoch [27/100], Step [200/206], Loss: 1.8982\n", | |
"Epoch [28/100], Step [100/206], Loss: 1.8137\n", | |
"Epoch [28/100], Step [200/206], Loss: 1.8262\n", | |
"Epoch [29/100], Step [100/206], Loss: 1.8102\n", | |
"Epoch [29/100], Step [200/206], Loss: 1.8241\n", | |
"Epoch [30/100], Step [100/206], Loss: 1.8278\n", | |
"Epoch [30/100], Step [200/206], Loss: 1.8998\n", | |
"Epoch [31/100], Step [100/206], Loss: 1.8288\n", | |
"Epoch [31/100], Step [200/206], Loss: 1.8316\n", | |
"Epoch [32/100], Step [100/206], Loss: 1.7759\n", | |
"Epoch [32/100], Step [200/206], Loss: 1.8147\n", | |
"Epoch [33/100], Step [100/206], Loss: 1.8950\n", | |
"Epoch [33/100], Step [200/206], Loss: 1.7506\n", | |
"Epoch [34/100], Step [100/206], Loss: 1.7319\n", | |
"Epoch [34/100], Step [200/206], Loss: 1.8149\n", | |
"Epoch [35/100], Step [100/206], Loss: 1.8667\n", | |
"Epoch [35/100], Step [200/206], Loss: 1.8774\n", | |
"Epoch [36/100], Step [100/206], Loss: 1.8703\n", | |
"Epoch [36/100], Step [200/206], Loss: 1.8300\n", | |
"Epoch [37/100], Step [100/206], Loss: 1.8249\n", | |
"Epoch [37/100], Step [200/206], Loss: 1.7675\n", | |
"Epoch [38/100], Step [100/206], Loss: 1.9187\n", | |
"Epoch [38/100], Step [200/206], Loss: 1.8346\n", | |
"Epoch [39/100], Step [100/206], Loss: 1.7732\n", | |
"Epoch [39/100], Step [200/206], Loss: 1.8172\n", | |
"Epoch [40/100], Step [100/206], Loss: 1.9019\n", | |
"Epoch [40/100], Step [200/206], Loss: 1.8095\n", | |
"Epoch [41/100], Step [100/206], Loss: 1.8056\n", | |
"Epoch [41/100], Step [200/206], Loss: 1.8398\n", | |
"Epoch [42/100], Step [100/206], Loss: 1.8673\n", | |
"Epoch [42/100], Step [200/206], Loss: 1.8623\n", | |
"Epoch [43/100], Step [100/206], Loss: 1.8089\n", | |
"Epoch [43/100], Step [200/206], Loss: 1.8264\n", | |
"Epoch [44/100], Step [100/206], Loss: 1.8061\n", | |
"Epoch [44/100], Step [200/206], Loss: 1.8330\n", | |
"Epoch [45/100], Step [100/206], Loss: 1.8320\n", | |
"Epoch [45/100], Step [200/206], Loss: 1.8466\n", | |
"Epoch [46/100], Step [100/206], Loss: 1.8238\n", | |
"Epoch [46/100], Step [200/206], Loss: 1.7436\n", | |
"Epoch [47/100], Step [100/206], Loss: 1.8641\n", | |
"Epoch [47/100], Step [200/206], Loss: 1.8573\n", | |
"Epoch [48/100], Step [100/206], Loss: 1.8456\n", | |
"Epoch [48/100], Step [200/206], Loss: 1.7787\n", | |
"Epoch [49/100], Step [100/206], Loss: 1.8252\n", | |
"Epoch [49/100], Step [200/206], Loss: 1.9172\n", | |
"Epoch [50/100], Step [100/206], Loss: 1.8456\n", | |
"Epoch [50/100], Step [200/206], Loss: 1.8413\n", | |
"Epoch [51/100], Step [100/206], Loss: 1.8063\n", | |
"Epoch [51/100], Step [200/206], Loss: 1.8427\n", | |
"Epoch [52/100], Step [100/206], Loss: 1.7799\n", | |
"Epoch [52/100], Step [200/206], Loss: 1.7938\n", | |
"Epoch [53/100], Step [100/206], Loss: 1.8103\n", | |
"Epoch [53/100], Step [200/206], Loss: 1.8889\n", | |
"Epoch [54/100], Step [100/206], Loss: 1.8157\n", | |
"Epoch [54/100], Step [200/206], Loss: 1.7702\n", | |
"Epoch [55/100], Step [100/206], Loss: 1.8287\n", | |
"Epoch [55/100], Step [200/206], Loss: 1.8358\n", | |
"Epoch [56/100], Step [100/206], Loss: 1.8687\n", | |
"Epoch [56/100], Step [200/206], Loss: 1.8552\n", | |
"Epoch [57/100], Step [100/206], Loss: 1.7932\n", | |
"Epoch [57/100], Step [200/206], Loss: 1.8077\n", | |
"Epoch [58/100], Step [100/206], Loss: 1.8008\n", | |
"Epoch [58/100], Step [200/206], Loss: 1.8095\n", | |
"Epoch [59/100], Step [100/206], Loss: 1.8237\n", | |
"Epoch [59/100], Step [200/206], Loss: 1.8513\n", | |
"Epoch [60/100], Step [100/206], Loss: 1.8412\n", | |
"Epoch [60/100], Step [200/206], Loss: 1.8951\n", | |
"Epoch [61/100], Step [100/206], Loss: 1.7768\n", | |
"Epoch [61/100], Step [200/206], Loss: 1.8098\n", | |
"Epoch [62/100], Step [100/206], Loss: 1.8031\n", | |
"Epoch [62/100], Step [200/206], Loss: 1.8305\n", | |
"Epoch [63/100], Step [100/206], Loss: 1.7894\n", | |
"Epoch [63/100], Step [200/206], Loss: 1.7866\n", | |
"Epoch [64/100], Step [100/206], Loss: 1.8474\n", | |
"Epoch [64/100], Step [200/206], Loss: 1.8278\n", | |
"Epoch [65/100], Step [100/206], Loss: 1.7913\n", | |
"Epoch [65/100], Step [200/206], Loss: 1.7857\n", | |
"Epoch [66/100], Step [100/206], Loss: 1.7911\n", | |
"Epoch [66/100], Step [200/206], Loss: 1.7780\n", | |
"Epoch [67/100], Step [100/206], Loss: 1.7642\n", | |
"Epoch [67/100], Step [200/206], Loss: 1.8193\n", | |
"Epoch [68/100], Step [100/206], Loss: 1.7941\n", | |
"Epoch [68/100], Step [200/206], Loss: 1.7852\n", | |
"Epoch [69/100], Step [100/206], Loss: 1.8234\n", | |
"Epoch [69/100], Step [200/206], Loss: 1.8224\n", | |
"Epoch [70/100], Step [100/206], Loss: 1.8101\n", | |
"Epoch [70/100], Step [200/206], Loss: 1.8094\n", | |
"Epoch [71/100], Step [100/206], Loss: 1.7375\n", | |
"Epoch [71/100], Step [200/206], Loss: 1.8148\n", | |
"Epoch [72/100], Step [100/206], Loss: 1.7975\n", | |
"Epoch [72/100], Step [200/206], Loss: 1.8560\n", | |
"Epoch [73/100], Step [100/206], Loss: 1.8892\n", | |
"Epoch [73/100], Step [200/206], Loss: 1.8031\n", | |
"Epoch [74/100], Step [100/206], Loss: 1.8147\n", | |
"Epoch [74/100], Step [200/206], Loss: 1.7783\n", | |
"Epoch [75/100], Step [100/206], Loss: 1.9001\n", | |
"Epoch [75/100], Step [200/206], Loss: 1.8443\n", | |
"Epoch [76/100], Step [100/206], Loss: 1.8809\n", | |
"Epoch [76/100], Step [200/206], Loss: 1.7478\n", | |
"Epoch [77/100], Step [100/206], Loss: 1.7588\n", | |
"Epoch [77/100], Step [200/206], Loss: 1.9256\n", | |
"Epoch [78/100], Step [100/206], Loss: 1.7766\n", | |
"Epoch [78/100], Step [200/206], Loss: 1.8041\n", | |
"Epoch [79/100], Step [100/206], Loss: 1.7835\n", | |
"Epoch [79/100], Step [200/206], Loss: 1.7466\n", | |
"Epoch [80/100], Step [100/206], Loss: 1.8114\n", | |
"Epoch [80/100], Step [200/206], Loss: 1.7969\n", | |
"Epoch [81/100], Step [100/206], Loss: 1.7738\n", | |
"Epoch [81/100], Step [200/206], Loss: 1.8206\n", | |
"Epoch [82/100], Step [100/206], Loss: 1.7776\n", | |
"Epoch [82/100], Step [200/206], Loss: 1.8304\n", | |
"Epoch [83/100], Step [100/206], Loss: 1.8344\n", | |
"Epoch [83/100], Step [200/206], Loss: 1.8109\n", | |
"Epoch [84/100], Step [100/206], Loss: 1.8179\n", | |
"Epoch [84/100], Step [200/206], Loss: 1.8043\n", | |
"Epoch [85/100], Step [100/206], Loss: 1.8179\n", | |
"Epoch [85/100], Step [200/206], Loss: 1.8325\n", | |
"Epoch [86/100], Step [100/206], Loss: 1.8183\n", | |
"Epoch [86/100], Step [200/206], Loss: 1.7328\n", | |
"Epoch [87/100], Step [100/206], Loss: 1.8620\n", | |
"Epoch [87/100], Step [200/206], Loss: 1.8000\n", | |
"Epoch [88/100], Step [100/206], Loss: 1.8483\n", | |
"Epoch [88/100], Step [200/206], Loss: 1.8043\n", | |
"Epoch [89/100], Step [100/206], Loss: 1.7693\n", | |
"Epoch [89/100], Step [200/206], Loss: 1.7481\n", | |
"Epoch [90/100], Step [100/206], Loss: 1.7677\n", | |
"Epoch [90/100], Step [200/206], Loss: 1.8045\n", | |
"Epoch [91/100], Step [100/206], Loss: 1.8099\n", | |
"Epoch [91/100], Step [200/206], Loss: 1.8099\n", | |
"Epoch [92/100], Step [100/206], Loss: 1.7886\n", | |
"Epoch [92/100], Step [200/206], Loss: 1.7793\n", | |
"Epoch [93/100], Step [100/206], Loss: 1.8118\n", | |
"Epoch [93/100], Step [200/206], Loss: 1.7652\n", | |
"Epoch [94/100], Step [100/206], Loss: 1.8959\n", | |
"Epoch [94/100], Step [200/206], Loss: 1.7397\n", | |
"Epoch [95/100], Step [100/206], Loss: 1.8318\n", | |
"Epoch [95/100], Step [200/206], Loss: 1.7495\n", | |
"Epoch [96/100], Step [100/206], Loss: 1.7748\n", | |
"Epoch [96/100], Step [200/206], Loss: 1.8398\n", | |
"Epoch [97/100], Step [100/206], Loss: 1.7984\n", | |
"Epoch [97/100], Step [200/206], Loss: 1.8364\n", | |
"Epoch [98/100], Step [100/206], Loss: 1.8306\n", | |
"Epoch [98/100], Step [200/206], Loss: 1.8240\n", | |
"Epoch [99/100], Step [100/206], Loss: 1.7435\n", | |
"Epoch [99/100], Step [200/206], Loss: 1.8296\n", | |
"Epoch [100/100], Step [100/206], Loss: 1.8072\n", | |
"Epoch [100/100], Step [200/206], Loss: 1.8313\n" | |
] | |
} | |
], | |
"source": [ | |
# Train the classification head.
# FIX: explicitly enter train mode — a later cell calls model.eval(), so on
# a re-run of this cell batch-norm would otherwise keep using its running
# statistics instead of batch statistics.
model.train()
total_step = len(train_dataloader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_dataloader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass: head emits log-probabilities, criterion is NLL-style.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodic progress report (every 100 mini-batches).
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Test Accuracy : 22.62857142857143 %\n" | |
] | |
} | |
], | |
"source": [ | |
# Evaluate top-1 accuracy on the held-out test set.
model.eval()  # batch-norm uses running mean/variance instead of batch stats
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_dataloader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Test Accuracy : {} %'.format(100 * correct / total))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"torch.save(model.state_dict(), './checkpoints/weight_resnet18_100.pth')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"ResNet(\n", | |
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", | |
" (layer1): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer2): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer3): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer4): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", | |
" (fc): Sequential(\n", | |
" (0): Linear(in_features=512, out_features=512, bias=True)\n", | |
" (1): ReLU()\n", | |
" (2): Linear(in_features=512, out_features=7, bias=True)\n", | |
" (3): LogSoftmax()\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Display the model architecture (removed the no-op self-assignment `model = model`)\n",
"model"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<All keys matched successfully>" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.load_state_dict(torch.load('./checkpoints/weight_resnet18_100.pth'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [1/100], Step [100/206], Loss: 1.7892\n", | |
"Epoch [1/100], Step [200/206], Loss: 1.7876\n", | |
"Epoch [2/100], Step [100/206], Loss: 1.8238\n", | |
"Epoch [2/100], Step [200/206], Loss: 1.8310\n", | |
"Epoch [3/100], Step [100/206], Loss: 1.8300\n", | |
"Epoch [3/100], Step [200/206], Loss: 1.7602\n", | |
"Epoch [4/100], Step [100/206], Loss: 1.7329\n", | |
"Epoch [4/100], Step [200/206], Loss: 1.8731\n", | |
"Epoch [5/100], Step [100/206], Loss: 1.8559\n", | |
"Epoch [5/100], Step [200/206], Loss: 1.7999\n", | |
"Epoch [6/100], Step [100/206], Loss: 1.7815\n", | |
"Epoch [6/100], Step [200/206], Loss: 1.8430\n", | |
"Epoch [7/100], Step [100/206], Loss: 1.8338\n", | |
"Epoch [7/100], Step [200/206], Loss: 1.7600\n", | |
"Epoch [8/100], Step [100/206], Loss: 1.8517\n", | |
"Epoch [8/100], Step [200/206], Loss: 1.7685\n", | |
"Epoch [9/100], Step [100/206], Loss: 1.8468\n", | |
"Epoch [9/100], Step [200/206], Loss: 1.7668\n", | |
"Epoch [10/100], Step [100/206], Loss: 1.7533\n", | |
"Epoch [10/100], Step [200/206], Loss: 1.8353\n", | |
"Epoch [11/100], Step [100/206], Loss: 1.6990\n", | |
"Epoch [11/100], Step [200/206], Loss: 1.7079\n", | |
"Epoch [12/100], Step [100/206], Loss: 1.7996\n", | |
"Epoch [12/100], Step [200/206], Loss: 1.6895\n", | |
"Epoch [13/100], Step [100/206], Loss: 1.8325\n", | |
"Epoch [13/100], Step [200/206], Loss: 1.8564\n", | |
"Epoch [14/100], Step [100/206], Loss: 1.8215\n", | |
"Epoch [14/100], Step [200/206], Loss: 1.7708\n", | |
"Epoch [15/100], Step [100/206], Loss: 1.7893\n", | |
"Epoch [15/100], Step [200/206], Loss: 1.8371\n", | |
"Epoch [16/100], Step [100/206], Loss: 1.8978\n", | |
"Epoch [16/100], Step [200/206], Loss: 1.8038\n", | |
"Epoch [17/100], Step [100/206], Loss: 1.7541\n", | |
"Epoch [17/100], Step [200/206], Loss: 1.8134\n", | |
"Epoch [18/100], Step [100/206], Loss: 1.8766\n", | |
"Epoch [18/100], Step [200/206], Loss: 1.7279\n", | |
"Epoch [19/100], Step [100/206], Loss: 1.8450\n", | |
"Epoch [19/100], Step [200/206], Loss: 1.8058\n", | |
"Epoch [20/100], Step [100/206], Loss: 1.7848\n", | |
"Epoch [20/100], Step [200/206], Loss: 1.8017\n", | |
"Epoch [21/100], Step [100/206], Loss: 1.8013\n", | |
"Epoch [21/100], Step [200/206], Loss: 1.7861\n", | |
"Epoch [22/100], Step [100/206], Loss: 1.7494\n", | |
"Epoch [22/100], Step [200/206], Loss: 1.9154\n", | |
"Epoch [23/100], Step [100/206], Loss: 1.8243\n", | |
"Epoch [23/100], Step [200/206], Loss: 1.7085\n", | |
"Epoch [24/100], Step [100/206], Loss: 1.8059\n", | |
"Epoch [24/100], Step [200/206], Loss: 1.7298\n", | |
"Epoch [25/100], Step [100/206], Loss: 1.8759\n", | |
"Epoch [25/100], Step [200/206], Loss: 1.8032\n", | |
"Epoch [26/100], Step [100/206], Loss: 1.8111\n", | |
"Epoch [26/100], Step [200/206], Loss: 1.8276\n", | |
"Epoch [27/100], Step [100/206], Loss: 1.8080\n", | |
"Epoch [27/100], Step [200/206], Loss: 1.8107\n", | |
"Epoch [28/100], Step [100/206], Loss: 1.8482\n", | |
"Epoch [28/100], Step [200/206], Loss: 1.7860\n", | |
"Epoch [29/100], Step [100/206], Loss: 1.8291\n", | |
"Epoch [29/100], Step [200/206], Loss: 1.8582\n", | |
"Epoch [30/100], Step [100/206], Loss: 1.8089\n", | |
"Epoch [30/100], Step [200/206], Loss: 1.8188\n", | |
"Epoch [31/100], Step [100/206], Loss: 1.7331\n", | |
"Epoch [31/100], Step [200/206], Loss: 1.7531\n", | |
"Epoch [32/100], Step [100/206], Loss: 1.7825\n", | |
"Epoch [32/100], Step [200/206], Loss: 1.8719\n", | |
"Epoch [33/100], Step [100/206], Loss: 1.6788\n", | |
"Epoch [33/100], Step [200/206], Loss: 1.7896\n", | |
"Epoch [34/100], Step [100/206], Loss: 1.8505\n", | |
"Epoch [34/100], Step [200/206], Loss: 1.7788\n", | |
"Epoch [35/100], Step [100/206], Loss: 1.7491\n", | |
"Epoch [35/100], Step [200/206], Loss: 1.7202\n", | |
"Epoch [36/100], Step [100/206], Loss: 1.7676\n", | |
"Epoch [36/100], Step [200/206], Loss: 1.8237\n", | |
"Epoch [37/100], Step [100/206], Loss: 1.6978\n", | |
"Epoch [37/100], Step [200/206], Loss: 1.8277\n", | |
"Epoch [38/100], Step [100/206], Loss: 1.8590\n", | |
"Epoch [38/100], Step [200/206], Loss: 1.8558\n", | |
"Epoch [39/100], Step [100/206], Loss: 1.7079\n", | |
"Epoch [39/100], Step [200/206], Loss: 1.8212\n", | |
"Epoch [40/100], Step [100/206], Loss: 1.8058\n", | |
"Epoch [40/100], Step [200/206], Loss: 1.7185\n", | |
"Epoch [41/100], Step [100/206], Loss: 1.7763\n", | |
"Epoch [41/100], Step [200/206], Loss: 1.8191\n", | |
"Epoch [42/100], Step [100/206], Loss: 1.8503\n", | |
"Epoch [42/100], Step [200/206], Loss: 1.7860\n", | |
"Epoch [43/100], Step [100/206], Loss: 1.8735\n", | |
"Epoch [43/100], Step [200/206], Loss: 1.8357\n", | |
"Epoch [44/100], Step [100/206], Loss: 1.8115\n", | |
"Epoch [44/100], Step [200/206], Loss: 1.7708\n", | |
"Epoch [45/100], Step [100/206], Loss: 1.7588\n", | |
"Epoch [45/100], Step [200/206], Loss: 1.8141\n", | |
"Epoch [46/100], Step [100/206], Loss: 1.7575\n", | |
"Epoch [46/100], Step [200/206], Loss: 1.7897\n", | |
"Epoch [47/100], Step [100/206], Loss: 1.8437\n", | |
"Epoch [47/100], Step [200/206], Loss: 1.7795\n", | |
"Epoch [48/100], Step [100/206], Loss: 1.7964\n", | |
"Epoch [48/100], Step [200/206], Loss: 1.7802\n", | |
"Epoch [49/100], Step [100/206], Loss: 1.7255\n", | |
"Epoch [49/100], Step [200/206], Loss: 1.8108\n", | |
"Epoch [50/100], Step [100/206], Loss: 1.7756\n", | |
"Epoch [50/100], Step [200/206], Loss: 1.8098\n", | |
"Epoch [51/100], Step [100/206], Loss: 1.7472\n", | |
"Epoch [51/100], Step [200/206], Loss: 1.8115\n", | |
"Epoch [52/100], Step [100/206], Loss: 1.8203\n", | |
"Epoch [52/100], Step [200/206], Loss: 1.7780\n", | |
"Epoch [53/100], Step [100/206], Loss: 1.7752\n", | |
"Epoch [53/100], Step [200/206], Loss: 1.8196\n", | |
"Epoch [54/100], Step [100/206], Loss: 1.8122\n", | |
"Epoch [54/100], Step [200/206], Loss: 1.6767\n", | |
"Epoch [55/100], Step [100/206], Loss: 1.7651\n", | |
"Epoch [55/100], Step [200/206], Loss: 1.7838\n", | |
"Epoch [56/100], Step [100/206], Loss: 1.8838\n", | |
"Epoch [56/100], Step [200/206], Loss: 1.7270\n", | |
"Epoch [57/100], Step [100/206], Loss: 1.7036\n", | |
"Epoch [57/100], Step [200/206], Loss: 1.7990\n", | |
"Epoch [58/100], Step [100/206], Loss: 1.8059\n", | |
"Epoch [58/100], Step [200/206], Loss: 1.7053\n", | |
"Epoch [59/100], Step [100/206], Loss: 1.7308\n", | |
"Epoch [59/100], Step [200/206], Loss: 1.7947\n", | |
"Epoch [60/100], Step [100/206], Loss: 1.7685\n", | |
"Epoch [60/100], Step [200/206], Loss: 1.7190\n", | |
"Epoch [61/100], Step [100/206], Loss: 1.7363\n", | |
"Epoch [61/100], Step [200/206], Loss: 1.8444\n", | |
"Epoch [62/100], Step [100/206], Loss: 1.7901\n", | |
"Epoch [62/100], Step [200/206], Loss: 1.7741\n", | |
"Epoch [63/100], Step [100/206], Loss: 1.7831\n", | |
"Epoch [63/100], Step [200/206], Loss: 1.7854\n", | |
"Epoch [64/100], Step [100/206], Loss: 1.7566\n", | |
"Epoch [64/100], Step [200/206], Loss: 1.8495\n", | |
"Epoch [65/100], Step [100/206], Loss: 1.8012\n", | |
"Epoch [65/100], Step [200/206], Loss: 1.9272\n", | |
"Epoch [66/100], Step [100/206], Loss: 1.8067\n", | |
"Epoch [66/100], Step [200/206], Loss: 1.7367\n", | |
"Epoch [67/100], Step [100/206], Loss: 1.7847\n", | |
"Epoch [67/100], Step [200/206], Loss: 1.7979\n", | |
"Epoch [68/100], Step [100/206], Loss: 1.7676\n", | |
"Epoch [68/100], Step [200/206], Loss: 1.7234\n", | |
"Epoch [69/100], Step [100/206], Loss: 1.8200\n", | |
"Epoch [69/100], Step [200/206], Loss: 1.7297\n", | |
"Epoch [70/100], Step [100/206], Loss: 1.7916\n", | |
"Epoch [70/100], Step [200/206], Loss: 1.7731\n", | |
"Epoch [71/100], Step [100/206], Loss: 1.7446\n", | |
"Epoch [71/100], Step [200/206], Loss: 1.7522\n", | |
"Epoch [72/100], Step [100/206], Loss: 1.8136\n", | |
"Epoch [72/100], Step [200/206], Loss: 1.7659\n", | |
"Epoch [73/100], Step [100/206], Loss: 1.7140\n", | |
"Epoch [73/100], Step [200/206], Loss: 1.8071\n", | |
"Epoch [74/100], Step [100/206], Loss: 1.7331\n", | |
"Epoch [74/100], Step [200/206], Loss: 1.7997\n", | |
"Epoch [75/100], Step [100/206], Loss: 1.8192\n", | |
"Epoch [75/100], Step [200/206], Loss: 1.6801\n", | |
"Epoch [76/100], Step [100/206], Loss: 1.7936\n", | |
"Epoch [76/100], Step [200/206], Loss: 1.7079\n", | |
"Epoch [77/100], Step [100/206], Loss: 1.8367\n", | |
"Epoch [77/100], Step [200/206], Loss: 1.6948\n", | |
"Epoch [78/100], Step [100/206], Loss: 1.7912\n", | |
"Epoch [78/100], Step [200/206], Loss: 1.8324\n", | |
"Epoch [79/100], Step [100/206], Loss: 1.7412\n", | |
"Epoch [79/100], Step [200/206], Loss: 1.7192\n", | |
"Epoch [80/100], Step [100/206], Loss: 1.7997\n", | |
"Epoch [80/100], Step [200/206], Loss: 1.7341\n", | |
"Epoch [81/100], Step [100/206], Loss: 1.8008\n", | |
"Epoch [81/100], Step [200/206], Loss: 1.8107\n", | |
"Epoch [82/100], Step [100/206], Loss: 1.7701\n", | |
"Epoch [82/100], Step [200/206], Loss: 1.7862\n", | |
"Epoch [83/100], Step [100/206], Loss: 1.7311\n", | |
"Epoch [83/100], Step [200/206], Loss: 1.8078\n", | |
"Epoch [84/100], Step [100/206], Loss: 1.7365\n", | |
"Epoch [84/100], Step [200/206], Loss: 1.7675\n", | |
"Epoch [85/100], Step [100/206], Loss: 1.8161\n", | |
"Epoch [85/100], Step [200/206], Loss: 1.8275\n", | |
"Epoch [86/100], Step [100/206], Loss: 1.7562\n", | |
"Epoch [86/100], Step [200/206], Loss: 1.7723\n", | |
"Epoch [87/100], Step [100/206], Loss: 1.7087\n", | |
"Epoch [87/100], Step [200/206], Loss: 1.7859\n", | |
"Epoch [88/100], Step [100/206], Loss: 1.7897\n", | |
"Epoch [88/100], Step [200/206], Loss: 1.7671\n", | |
"Epoch [89/100], Step [100/206], Loss: 1.7756\n", | |
"Epoch [89/100], Step [200/206], Loss: 1.7493\n", | |
"Epoch [90/100], Step [100/206], Loss: 1.7414\n", | |
"Epoch [90/100], Step [200/206], Loss: 1.7378\n", | |
"Epoch [91/100], Step [100/206], Loss: 1.6846\n", | |
"Epoch [91/100], Step [200/206], Loss: 1.7964\n", | |
"Epoch [92/100], Step [100/206], Loss: 1.7420\n", | |
"Epoch [92/100], Step [200/206], Loss: 1.7729\n", | |
"Epoch [93/100], Step [100/206], Loss: 1.8605\n", | |
"Epoch [93/100], Step [200/206], Loss: 1.7876\n", | |
"Epoch [94/100], Step [100/206], Loss: 1.7263\n", | |
"Epoch [94/100], Step [200/206], Loss: 1.7625\n", | |
"Epoch [95/100], Step [100/206], Loss: 1.7476\n", | |
"Epoch [95/100], Step [200/206], Loss: 1.8162\n", | |
"Epoch [96/100], Step [100/206], Loss: 1.6556\n", | |
"Epoch [96/100], Step [200/206], Loss: 1.7610\n", | |
"Epoch [97/100], Step [100/206], Loss: 1.7237\n", | |
"Epoch [97/100], Step [200/206], Loss: 1.7982\n", | |
"Epoch [98/100], Step [100/206], Loss: 1.7569\n", | |
"Epoch [98/100], Step [200/206], Loss: 1.7443\n", | |
"Epoch [99/100], Step [100/206], Loss: 1.8233\n", | |
"Epoch [99/100], Step [200/206], Loss: 1.7540\n", | |
"Epoch [100/100], Step [100/206], Loss: 1.7758\n", | |
"Epoch [100/100], Step [200/206], Loss: 1.7109\n" | |
] | |
} | |
], | |
"source": [ | |
"# Resume training the model\n",
"# BUG FIX: the previous evaluation cell left the model in eval mode, so this\n",
"# resumed run was training with frozen batchnorm statistics; switch back first.\n",
"model.train()\n",
"total_step = len(train_dataloader)\n",
"for epoch in range(num_epochs):\n",
"    for i, (images, labels) in enumerate(train_dataloader):\n",
"        # Move the batch to the training device\n",
"        images = images.to(device)\n",
"        labels = labels.to(device)\n",
"\n",
"        # Forward pass\n",
"        outputs = model(images)\n",
"        loss = criterion(outputs, labels)\n",
"\n",
"        # Backward and optimize\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        # Log the running loss every 100 steps\n",
"        if (i + 1) % 100 == 0:\n",
"            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'\n",
"                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Test Accuracy : 21.914285714285715 %\n" | |
] | |
} | |
], | |
"source": [ | |
"model.eval()  # eval mode: batchnorm uses running mean/variance instead of mini-batch statistics\n",
"with torch.no_grad():\n",
"    correct = 0\n",
"    total = 0\n",
"    for images, labels in test_dataloader:\n",
"        images = images.to(device)\n",
"        labels = labels.to(device)\n",
"        outputs = model(images)\n",
"        # Predicted class = argmax over class dimension; `.data` is a legacy\n",
"        # pre-0.4 idiom and unnecessary inside torch.no_grad()\n",
"        _, predicted = torch.max(outputs, 1)\n",
"        total += labels.size(0)\n",
"        correct += (predicted == labels).sum().item()\n",
"\n",
"    print('Test Accuracy : {} %'.format(100 * correct / total))"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"torch.save(model.state_dict(), './checkpoints/weight_resnet18_200.pth')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<All keys matched successfully>" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.load_state_dict(torch.load('./checkpoints/weight_resnet18_200.pth'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [1/100], Step [100/206], Loss: 1.7775\n", | |
"Epoch [1/100], Step [200/206], Loss: 1.7994\n", | |
"Epoch [2/100], Step [100/206], Loss: 1.7386\n", | |
"Epoch [2/100], Step [200/206], Loss: 1.7325\n", | |
"Epoch [3/100], Step [100/206], Loss: 1.7318\n", | |
"Epoch [3/100], Step [200/206], Loss: 1.7492\n", | |
"Epoch [4/100], Step [100/206], Loss: 1.6084\n", | |
"Epoch [4/100], Step [200/206], Loss: 1.7708\n", | |
"Epoch [5/100], Step [100/206], Loss: 1.7332\n", | |
"Epoch [5/100], Step [200/206], Loss: 1.7659\n", | |
"Epoch [6/100], Step [100/206], Loss: 1.7814\n", | |
"Epoch [6/100], Step [200/206], Loss: 1.7406\n", | |
"Epoch [7/100], Step [100/206], Loss: 1.7844\n", | |
"Epoch [7/100], Step [200/206], Loss: 1.7175\n", | |
"Epoch [8/100], Step [100/206], Loss: 1.8441\n", | |
"Epoch [8/100], Step [200/206], Loss: 1.7527\n", | |
"Epoch [9/100], Step [100/206], Loss: 1.8416\n", | |
"Epoch [9/100], Step [200/206], Loss: 1.8365\n", | |
"Epoch [10/100], Step [100/206], Loss: 1.6929\n", | |
"Epoch [10/100], Step [200/206], Loss: 1.7640\n", | |
"Epoch [11/100], Step [100/206], Loss: 1.7789\n", | |
"Epoch [11/100], Step [200/206], Loss: 1.8365\n", | |
"Epoch [12/100], Step [100/206], Loss: 1.7304\n", | |
"Epoch [12/100], Step [200/206], Loss: 1.7912\n", | |
"Epoch [13/100], Step [100/206], Loss: 1.8394\n", | |
"Epoch [13/100], Step [200/206], Loss: 1.6970\n", | |
"Epoch [14/100], Step [100/206], Loss: 1.8486\n", | |
"Epoch [14/100], Step [200/206], Loss: 1.8152\n", | |
"Epoch [15/100], Step [100/206], Loss: 1.8517\n", | |
"Epoch [15/100], Step [200/206], Loss: 1.7902\n", | |
"Epoch [16/100], Step [100/206], Loss: 1.7725\n", | |
"Epoch [16/100], Step [200/206], Loss: 1.7615\n", | |
"Epoch [17/100], Step [100/206], Loss: 1.7559\n", | |
"Epoch [17/100], Step [200/206], Loss: 1.7292\n", | |
"Epoch [18/100], Step [100/206], Loss: 1.8214\n", | |
"Epoch [18/100], Step [200/206], Loss: 1.7762\n", | |
"Epoch [19/100], Step [100/206], Loss: 1.7168\n", | |
"Epoch [19/100], Step [200/206], Loss: 1.7787\n", | |
"Epoch [20/100], Step [100/206], Loss: 1.7660\n", | |
"Epoch [20/100], Step [200/206], Loss: 1.7539\n", | |
"Epoch [21/100], Step [100/206], Loss: 1.7864\n", | |
"Epoch [21/100], Step [200/206], Loss: 1.7234\n", | |
"Epoch [22/100], Step [100/206], Loss: 1.8251\n", | |
"Epoch [22/100], Step [200/206], Loss: 1.7325\n", | |
"Epoch [23/100], Step [100/206], Loss: 1.7693\n", | |
"Epoch [23/100], Step [200/206], Loss: 1.7419\n", | |
"Epoch [24/100], Step [100/206], Loss: 1.6545\n", | |
"Epoch [24/100], Step [200/206], Loss: 1.7839\n", | |
"Epoch [25/100], Step [100/206], Loss: 1.7508\n", | |
"Epoch [25/100], Step [200/206], Loss: 1.7846\n", | |
"Epoch [26/100], Step [100/206], Loss: 1.8689\n", | |
"Epoch [26/100], Step [200/206], Loss: 1.7939\n", | |
"Epoch [27/100], Step [100/206], Loss: 1.8019\n", | |
"Epoch [27/100], Step [200/206], Loss: 1.8169\n", | |
"Epoch [28/100], Step [100/206], Loss: 1.7630\n", | |
"Epoch [28/100], Step [200/206], Loss: 1.7747\n", | |
"Epoch [29/100], Step [100/206], Loss: 1.8262\n", | |
"Epoch [29/100], Step [200/206], Loss: 1.8054\n", | |
"Epoch [30/100], Step [100/206], Loss: 1.8358\n", | |
"Epoch [30/100], Step [200/206], Loss: 1.7541\n", | |
"Epoch [31/100], Step [100/206], Loss: 1.7286\n", | |
"Epoch [31/100], Step [200/206], Loss: 1.7669\n", | |
"Epoch [32/100], Step [100/206], Loss: 1.7877\n", | |
"Epoch [32/100], Step [200/206], Loss: 1.7734\n", | |
"Epoch [33/100], Step [100/206], Loss: 1.6754\n", | |
"Epoch [33/100], Step [200/206], Loss: 1.7035\n", | |
"Epoch [34/100], Step [100/206], Loss: 1.8467\n", | |
"Epoch [34/100], Step [200/206], Loss: 1.6976\n", | |
"Epoch [35/100], Step [100/206], Loss: 1.7958\n", | |
"Epoch [35/100], Step [200/206], Loss: 1.8057\n", | |
"Epoch [36/100], Step [100/206], Loss: 1.7282\n", | |
"Epoch [36/100], Step [200/206], Loss: 1.7345\n", | |
"Epoch [37/100], Step [100/206], Loss: 1.6798\n", | |
"Epoch [37/100], Step [200/206], Loss: 1.7780\n", | |
"Epoch [38/100], Step [100/206], Loss: 1.7703\n", | |
"Epoch [38/100], Step [200/206], Loss: 1.7031\n", | |
"Epoch [39/100], Step [100/206], Loss: 1.7540\n", | |
"Epoch [39/100], Step [200/206], Loss: 1.8360\n", | |
"Epoch [40/100], Step [100/206], Loss: 1.7261\n", | |
"Epoch [40/100], Step [200/206], Loss: 1.7430\n", | |
"Epoch [41/100], Step [100/206], Loss: 1.8863\n", | |
"Epoch [41/100], Step [200/206], Loss: 1.7337\n", | |
"Epoch [42/100], Step [100/206], Loss: 1.7018\n", | |
"Epoch [42/100], Step [200/206], Loss: 1.6827\n", | |
"Epoch [43/100], Step [100/206], Loss: 1.7575\n", | |
"Epoch [43/100], Step [200/206], Loss: 1.8266\n", | |
"Epoch [44/100], Step [100/206], Loss: 1.6754\n", | |
"Epoch [44/100], Step [200/206], Loss: 1.7436\n", | |
"Epoch [45/100], Step [100/206], Loss: 1.7745\n", | |
"Epoch [45/100], Step [200/206], Loss: 1.7723\n", | |
"Epoch [46/100], Step [100/206], Loss: 1.7355\n", | |
"Epoch [46/100], Step [200/206], Loss: 1.7777\n", | |
"Epoch [47/100], Step [100/206], Loss: 1.8403\n", | |
"Epoch [47/100], Step [200/206], Loss: 1.7781\n", | |
"Epoch [48/100], Step [100/206], Loss: 1.8090\n", | |
"Epoch [48/100], Step [200/206], Loss: 1.7719\n", | |
"Epoch [49/100], Step [100/206], Loss: 1.8550\n", | |
"Epoch [49/100], Step [200/206], Loss: 1.6970\n", | |
"Epoch [50/100], Step [100/206], Loss: 1.6394\n", | |
"Epoch [50/100], Step [200/206], Loss: 1.8137\n", | |
"Epoch [51/100], Step [100/206], Loss: 1.7478\n", | |
"Epoch [51/100], Step [200/206], Loss: 1.7640\n", | |
"Epoch [52/100], Step [100/206], Loss: 1.8412\n", | |
"Epoch [52/100], Step [200/206], Loss: 1.8284\n", | |
"Epoch [53/100], Step [100/206], Loss: 1.7049\n", | |
"Epoch [53/100], Step [200/206], Loss: 1.7652\n", | |
"Epoch [54/100], Step [100/206], Loss: 1.7729\n", | |
"Epoch [54/100], Step [200/206], Loss: 1.8220\n", | |
"Epoch [55/100], Step [100/206], Loss: 1.8077\n", | |
"Epoch [55/100], Step [200/206], Loss: 1.7978\n", | |
"Epoch [56/100], Step [100/206], Loss: 1.7188\n", | |
"Epoch [56/100], Step [200/206], Loss: 1.6348\n", | |
"Epoch [57/100], Step [100/206], Loss: 1.7600\n", | |
"Epoch [57/100], Step [200/206], Loss: 1.7314\n", | |
"Epoch [58/100], Step [100/206], Loss: 1.7389\n", | |
"Epoch [58/100], Step [200/206], Loss: 1.8066\n", | |
"Epoch [59/100], Step [100/206], Loss: 1.7768\n", | |
"Epoch [59/100], Step [200/206], Loss: 1.8010\n", | |
"Epoch [60/100], Step [100/206], Loss: 1.7392\n", | |
"Epoch [60/100], Step [200/206], Loss: 1.8117\n", | |
"Epoch [61/100], Step [100/206], Loss: 1.6771\n", | |
"Epoch [61/100], Step [200/206], Loss: 1.7958\n", | |
"Epoch [62/100], Step [100/206], Loss: 1.7026\n", | |
"Epoch [62/100], Step [200/206], Loss: 1.7667\n", | |
"Epoch [63/100], Step [100/206], Loss: 1.8149\n", | |
"Epoch [63/100], Step [200/206], Loss: 1.6892\n", | |
"Epoch [64/100], Step [100/206], Loss: 1.7962\n", | |
"Epoch [64/100], Step [200/206], Loss: 1.7722\n", | |
"Epoch [65/100], Step [100/206], Loss: 1.6924\n", | |
"Epoch [65/100], Step [200/206], Loss: 1.7129\n", | |
"Epoch [66/100], Step [100/206], Loss: 1.7348\n", | |
"Epoch [66/100], Step [200/206], Loss: 1.7210\n", | |
"Epoch [67/100], Step [100/206], Loss: 1.6795\n", | |
"Epoch [67/100], Step [200/206], Loss: 1.7293\n", | |
"Epoch [68/100], Step [100/206], Loss: 1.7481\n", | |
"Epoch [68/100], Step [200/206], Loss: 1.7611\n", | |
"Epoch [69/100], Step [100/206], Loss: 1.7859\n", | |
"Epoch [69/100], Step [200/206], Loss: 1.7177\n", | |
"Epoch [70/100], Step [100/206], Loss: 1.7398\n", | |
"Epoch [70/100], Step [200/206], Loss: 1.8317\n", | |
"Epoch [71/100], Step [100/206], Loss: 1.8460\n", | |
"Epoch [71/100], Step [200/206], Loss: 1.7464\n", | |
"Epoch [72/100], Step [100/206], Loss: 1.7416\n", | |
"Epoch [72/100], Step [200/206], Loss: 1.6381\n", | |
"Epoch [73/100], Step [100/206], Loss: 1.8355\n", | |
"Epoch [73/100], Step [200/206], Loss: 1.7312\n", | |
"Epoch [74/100], Step [100/206], Loss: 1.7605\n", | |
"Epoch [74/100], Step [200/206], Loss: 1.7273\n", | |
"Epoch [75/100], Step [100/206], Loss: 1.7278\n", | |
"Epoch [75/100], Step [200/206], Loss: 1.7584\n", | |
"Epoch [76/100], Step [100/206], Loss: 1.7839\n", | |
"Epoch [76/100], Step [200/206], Loss: 1.8740\n", | |
"Epoch [77/100], Step [100/206], Loss: 1.7364\n", | |
"Epoch [77/100], Step [200/206], Loss: 1.7574\n", | |
"Epoch [78/100], Step [100/206], Loss: 1.7734\n", | |
"Epoch [78/100], Step [200/206], Loss: 1.7331\n", | |
"Epoch [79/100], Step [100/206], Loss: 1.7919\n", | |
"Epoch [79/100], Step [200/206], Loss: 1.8142\n", | |
"Epoch [80/100], Step [100/206], Loss: 1.6907\n", | |
"Epoch [80/100], Step [200/206], Loss: 1.7179\n", | |
"Epoch [81/100], Step [100/206], Loss: 1.8401\n", | |
"Epoch [81/100], Step [200/206], Loss: 1.7304\n", | |
"Epoch [82/100], Step [100/206], Loss: 1.6853\n", | |
"Epoch [82/100], Step [200/206], Loss: 1.7725\n", | |
"Epoch [83/100], Step [100/206], Loss: 1.6901\n", | |
"Epoch [83/100], Step [200/206], Loss: 1.7018\n", | |
"Epoch [84/100], Step [100/206], Loss: 1.7844\n", | |
"Epoch [84/100], Step [200/206], Loss: 1.8158\n", | |
"Epoch [85/100], Step [100/206], Loss: 1.7045\n", | |
"Epoch [85/100], Step [200/206], Loss: 1.8302\n", | |
"Epoch [86/100], Step [100/206], Loss: 1.7720\n", | |
"Epoch [86/100], Step [200/206], Loss: 1.8571\n", | |
"Epoch [87/100], Step [100/206], Loss: 1.7617\n", | |
"Epoch [87/100], Step [200/206], Loss: 1.6629\n", | |
"Epoch [88/100], Step [100/206], Loss: 1.7431\n", | |
"Epoch [88/100], Step [200/206], Loss: 1.7406\n", | |
"Epoch [89/100], Step [100/206], Loss: 1.8519\n", | |
"Epoch [89/100], Step [200/206], Loss: 1.7814\n", | |
"Epoch [90/100], Step [100/206], Loss: 1.7153\n", | |
"Epoch [90/100], Step [200/206], Loss: 1.7120\n", | |
"Epoch [91/100], Step [100/206], Loss: 1.7705\n", | |
"Epoch [91/100], Step [200/206], Loss: 1.7705\n", | |
"Epoch [92/100], Step [100/206], Loss: 1.6527\n", | |
"Epoch [92/100], Step [200/206], Loss: 1.7528\n", | |
"Epoch [93/100], Step [100/206], Loss: 1.7260\n", | |
"Epoch [93/100], Step [200/206], Loss: 1.7796\n", | |
"Epoch [94/100], Step [100/206], Loss: 1.6830\n", | |
"Epoch [94/100], Step [200/206], Loss: 1.7714\n", | |
"Epoch [95/100], Step [100/206], Loss: 1.6935\n", | |
"Epoch [95/100], Step [200/206], Loss: 1.7616\n", | |
"Epoch [96/100], Step [100/206], Loss: 1.8042\n", | |
"Epoch [96/100], Step [200/206], Loss: 1.7847\n", | |
"Epoch [97/100], Step [100/206], Loss: 1.7990\n", | |
"Epoch [97/100], Step [200/206], Loss: 1.7717\n", | |
"Epoch [98/100], Step [100/206], Loss: 1.7043\n", | |
"Epoch [98/100], Step [200/206], Loss: 1.7463\n", | |
"Epoch [99/100], Step [100/206], Loss: 1.8212\n", | |
"Epoch [99/100], Step [200/206], Loss: 1.7586\n", | |
"Epoch [100/100], Step [100/206], Loss: 1.7591\n", | |
"Epoch [100/100], Step [200/206], Loss: 1.7037\n" | |
] | |
} | |
], | |
"source": [ | |
"# Resume training the model\n",
"# BUG FIX: the previous evaluation cell left the model in eval mode, so this\n",
"# resumed run was training with frozen batchnorm statistics; switch back first.\n",
"model.train()\n",
"total_step = len(train_dataloader)\n",
"for epoch in range(num_epochs):\n",
"    for i, (images, labels) in enumerate(train_dataloader):\n",
"        # Move the batch to the training device\n",
"        images = images.to(device)\n",
"        labels = labels.to(device)\n",
"\n",
"        # Forward pass\n",
"        outputs = model(images)\n",
"        loss = criterion(outputs, labels)\n",
"\n",
"        # Backward and optimize\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        # Log the running loss every 100 steps\n",
"        if (i + 1) % 100 == 0:\n",
"            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'\n",
"                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"torch.save(model.state_dict(), './checkpoints/weight_resnet18_300.pth')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Test Accuracy : 21.714285714285715 %\n" | |
] | |
} | |
], | |
"source": [ | |
"model.eval()  # eval mode: batchnorm uses running mean/variance instead of mini-batch statistics\n",
"with torch.no_grad():\n",
"    correct = 0\n",
"    total = 0\n",
"    for images, labels in test_dataloader:\n",
"        images = images.to(device)\n",
"        labels = labels.to(device)\n",
"        outputs = model(images)\n",
"        # Predicted class = argmax over class dimension; `.data` is a legacy\n",
"        # pre-0.4 idiom and unnecessary inside torch.no_grad()\n",
"        _, predicted = torch.max(outputs, 1)\n",
"        total += labels.size(0)\n",
"        correct += (predicted == labels).sum().item()\n",
"\n",
"    print('Test Accuracy : {} %'.format(100 * correct / total))"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# model.save_weights('./checkpoints')\n", | |
"# Save current weights; this file is reloaded by the load_state_dict cells below.\n", | |
"torch.save(model.state_dict(), './checkpoints/weight.pth')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"ResNet(\n", | |
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", | |
" (layer1): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer2): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer3): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer4): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", | |
" (fc): Sequential(\n", | |
" (0): Linear(in_features=512, out_features=512, bias=True)\n", | |
" (1): ReLU()\n", | |
" (2): Linear(in_features=512, out_features=7, bias=True)\n", | |
" (3): LogSoftmax()\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Display the model architecture; the bare expression on the last line renders the repr.\n", | |
"# (The previous `model = model` line was a no-op self-assignment and has been removed.)\n", | |
"model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<All keys matched successfully>" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Restore the weights saved earlier; the '<All keys matched successfully>' output\n", | |
"# confirms the checkpoint matches the current architecture.\n", | |
"# NOTE(review): torch.load(..., map_location=device) would make this robust on\n", | |
"# CPU-only machines -- not changed here.\n", | |
"model.load_state_dict(torch.load('./checkpoints/weight.pth'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 001/100 | Train: 30.198% | Loss: 0.014\n", | |
"Time elapsed: 1.94 min\n", | |
"Epoch: 002/100 | Train: 30.053% | Loss: 0.014\n", | |
"Time elapsed: 3.98 min\n", | |
"Epoch: 003/100 | Train: 30.008% | Loss: 0.014\n", | |
"Time elapsed: 6.02 min\n", | |
"Epoch: 004/100 | Train: 29.116% | Loss: 0.014\n", | |
"Time elapsed: 7.98 min\n", | |
"Epoch: 005/100 | Train: 29.398% | Loss: 0.014\n", | |
"Time elapsed: 9.91 min\n", | |
"Epoch: 006/100 | Train: 29.246% | Loss: 0.014\n", | |
"Time elapsed: 11.87 min\n", | |
"Epoch: 007/100 | Train: 30.590% | Loss: 0.014\n", | |
"Time elapsed: 13.86 min\n", | |
"Epoch: 008/100 | Train: 29.497% | Loss: 0.014\n", | |
"Time elapsed: 15.88 min\n", | |
"Epoch: 009/100 | Train: 28.693% | Loss: 0.014\n", | |
"Time elapsed: 17.95 min\n", | |
"Epoch: 010/100 | Train: 30.141% | Loss: 0.014\n", | |
"Time elapsed: 20.11 min\n", | |
"Epoch: 011/100 | Train: 29.025% | Loss: 0.014\n", | |
"Time elapsed: 22.04 min\n", | |
"Epoch: 012/100 | Train: 29.318% | Loss: 0.014\n", | |
"Time elapsed: 23.84 min\n", | |
"Epoch: 013/100 | Train: 30.545% | Loss: 0.014\n", | |
"Time elapsed: 25.67 min\n", | |
"Epoch: 014/100 | Train: 30.366% | Loss: 0.014\n", | |
"Time elapsed: 27.61 min\n", | |
"Epoch: 015/100 | Train: 30.278% | Loss: 0.014\n", | |
"Time elapsed: 29.42 min\n", | |
"Epoch: 016/100 | Train: 29.863% | Loss: 0.014\n", | |
"Time elapsed: 31.23 min\n", | |
"Epoch: 017/100 | Train: 29.745% | Loss: 0.014\n", | |
"Time elapsed: 33.00 min\n", | |
"Epoch: 018/100 | Train: 29.291% | Loss: 0.014\n", | |
"Time elapsed: 34.78 min\n", | |
"Epoch: 019/100 | Train: 30.731% | Loss: 0.014\n", | |
"Time elapsed: 36.57 min\n", | |
"Epoch: 020/100 | Train: 29.832% | Loss: 0.014\n", | |
"Time elapsed: 38.35 min\n", | |
"Epoch: 021/100 | Train: 30.830% | Loss: 0.014\n", | |
"Time elapsed: 40.13 min\n", | |
"Epoch: 022/100 | Train: 29.288% | Loss: 0.014\n", | |
"Time elapsed: 41.90 min\n", | |
"Epoch: 023/100 | Train: 29.630% | Loss: 0.014\n", | |
"Time elapsed: 43.67 min\n", | |
"Epoch: 024/100 | Train: 31.269% | Loss: 0.014\n", | |
"Time elapsed: 45.45 min\n", | |
"Epoch: 025/100 | Train: 31.044% | Loss: 0.014\n", | |
"Time elapsed: 47.21 min\n", | |
"Epoch: 026/100 | Train: 30.564% | Loss: 0.014\n", | |
"Time elapsed: 49.01 min\n", | |
"Epoch: 027/100 | Train: 30.568% | Loss: 0.014\n", | |
"Time elapsed: 50.80 min\n", | |
"Epoch: 028/100 | Train: 28.514% | Loss: 0.014\n", | |
"Time elapsed: 52.57 min\n", | |
"Epoch: 029/100 | Train: 31.093% | Loss: 0.014\n", | |
"Time elapsed: 54.37 min\n", | |
"Epoch: 030/100 | Train: 30.842% | Loss: 0.014\n", | |
"Time elapsed: 56.21 min\n", | |
"Epoch: 031/100 | Train: 29.550% | Loss: 0.014\n", | |
"Time elapsed: 58.02 min\n", | |
"Epoch: 032/100 | Train: 30.027% | Loss: 0.014\n", | |
"Time elapsed: 59.81 min\n", | |
"Epoch: 033/100 | Train: 30.293% | Loss: 0.014\n", | |
"Time elapsed: 61.60 min\n", | |
"Epoch: 034/100 | Train: 30.869% | Loss: 0.014\n", | |
"Time elapsed: 63.40 min\n", | |
"Epoch: 035/100 | Train: 30.838% | Loss: 0.014\n", | |
"Time elapsed: 65.18 min\n", | |
"Epoch: 036/100 | Train: 30.899% | Loss: 0.014\n", | |
"Time elapsed: 66.99 min\n", | |
"Epoch: 037/100 | Train: 30.450% | Loss: 0.014\n", | |
"Time elapsed: 68.80 min\n", | |
"Epoch: 038/100 | Train: 30.758% | Loss: 0.014\n", | |
"Time elapsed: 70.60 min\n", | |
"Epoch: 039/100 | Train: 30.907% | Loss: 0.014\n", | |
"Time elapsed: 72.40 min\n", | |
"Epoch: 040/100 | Train: 30.488% | Loss: 0.014\n", | |
"Time elapsed: 74.21 min\n", | |
"Epoch: 041/100 | Train: 29.303% | Loss: 0.014\n", | |
"Time elapsed: 76.01 min\n", | |
"Epoch: 042/100 | Train: 29.657% | Loss: 0.014\n", | |
"Time elapsed: 77.79 min\n", | |
"Epoch: 043/100 | Train: 30.549% | Loss: 0.014\n", | |
"Time elapsed: 79.59 min\n", | |
"Epoch: 044/100 | Train: 30.922% | Loss: 0.014\n", | |
"Time elapsed: 81.39 min\n", | |
"Epoch: 045/100 | Train: 30.663% | Loss: 0.014\n", | |
"Time elapsed: 83.17 min\n", | |
"Epoch: 046/100 | Train: 30.930% | Loss: 0.014\n", | |
"Time elapsed: 84.97 min\n", | |
"Epoch: 047/100 | Train: 29.642% | Loss: 0.014\n", | |
"Time elapsed: 86.73 min\n", | |
"Epoch: 048/100 | Train: 31.090% | Loss: 0.014\n", | |
"Time elapsed: 88.53 min\n", | |
"Epoch: 049/100 | Train: 31.166% | Loss: 0.014\n", | |
"Time elapsed: 90.31 min\n", | |
"Epoch: 050/100 | Train: 31.192% | Loss: 0.014\n", | |
"Time elapsed: 92.09 min\n", | |
"Epoch: 051/100 | Train: 30.263% | Loss: 0.014\n", | |
"Time elapsed: 93.85 min\n", | |
"Epoch: 052/100 | Train: 29.901% | Loss: 0.014\n", | |
"Time elapsed: 95.63 min\n", | |
"Epoch: 053/100 | Train: 31.059% | Loss: 0.014\n", | |
"Time elapsed: 97.44 min\n", | |
"Epoch: 054/100 | Train: 31.326% | Loss: 0.014\n", | |
"Time elapsed: 99.22 min\n", | |
"Epoch: 055/100 | Train: 30.507% | Loss: 0.014\n", | |
"Time elapsed: 101.04 min\n", | |
"Epoch: 056/100 | Train: 30.613% | Loss: 0.014\n", | |
"Time elapsed: 102.89 min\n", | |
"Epoch: 057/100 | Train: 29.848% | Loss: 0.014\n", | |
"Time elapsed: 104.72 min\n", | |
"Epoch: 058/100 | Train: 30.712% | Loss: 0.014\n", | |
"Time elapsed: 106.52 min\n", | |
"Epoch: 059/100 | Train: 29.440% | Loss: 0.014\n", | |
"Time elapsed: 108.29 min\n", | |
"Epoch: 060/100 | Train: 30.011% | Loss: 0.014\n", | |
"Time elapsed: 110.11 min\n", | |
"Epoch: 061/100 | Train: 31.230% | Loss: 0.014\n", | |
"Time elapsed: 111.89 min\n", | |
"Epoch: 062/100 | Train: 30.747% | Loss: 0.014\n", | |
"Time elapsed: 113.67 min\n", | |
"Epoch: 063/100 | Train: 29.230% | Loss: 0.014\n", | |
"Time elapsed: 115.44 min\n", | |
"Epoch: 064/100 | Train: 29.570% | Loss: 0.014\n", | |
"Time elapsed: 117.24 min\n", | |
"Epoch: 065/100 | Train: 29.973% | Loss: 0.014\n", | |
"Time elapsed: 119.02 min\n", | |
"Epoch: 066/100 | Train: 30.526% | Loss: 0.014\n", | |
"Time elapsed: 120.77 min\n", | |
"Epoch: 067/100 | Train: 30.663% | Loss: 0.014\n", | |
"Time elapsed: 122.54 min\n", | |
"Epoch: 068/100 | Train: 30.766% | Loss: 0.014\n", | |
"Time elapsed: 124.30 min\n", | |
"Epoch: 069/100 | Train: 31.196% | Loss: 0.014\n", | |
"Time elapsed: 126.08 min\n", | |
"Epoch: 070/100 | Train: 30.716% | Loss: 0.014\n", | |
"Time elapsed: 127.86 min\n", | |
"Epoch: 071/100 | Train: 30.983% | Loss: 0.014\n", | |
"Time elapsed: 129.66 min\n", | |
"Epoch: 072/100 | Train: 31.474% | Loss: 0.014\n", | |
"Time elapsed: 131.47 min\n", | |
"Epoch: 073/100 | Train: 31.173% | Loss: 0.014\n", | |
"Time elapsed: 133.27 min\n", | |
"Epoch: 074/100 | Train: 30.617% | Loss: 0.014\n", | |
"Time elapsed: 135.06 min\n", | |
"Epoch: 075/100 | Train: 31.451% | Loss: 0.014\n", | |
"Time elapsed: 136.89 min\n", | |
"Epoch: 076/100 | Train: 30.674% | Loss: 0.014\n", | |
"Time elapsed: 138.69 min\n", | |
"Epoch: 077/100 | Train: 30.678% | Loss: 0.014\n", | |
"Time elapsed: 140.51 min\n", | |
"Epoch: 078/100 | Train: 31.295% | Loss: 0.014\n", | |
"Time elapsed: 142.31 min\n", | |
"Epoch: 079/100 | Train: 30.583% | Loss: 0.014\n", | |
"Time elapsed: 144.11 min\n", | |
"Epoch: 080/100 | Train: 31.341% | Loss: 0.014\n", | |
"Time elapsed: 145.90 min\n", | |
"Epoch: 081/100 | Train: 29.307% | Loss: 0.014\n", | |
"Time elapsed: 147.70 min\n", | |
"Epoch: 082/100 | Train: 31.048% | Loss: 0.014\n", | |
"Time elapsed: 149.53 min\n", | |
"Epoch: 083/100 | Train: 31.128% | Loss: 0.014\n", | |
"Time elapsed: 151.34 min\n", | |
"Epoch: 084/100 | Train: 30.682% | Loss: 0.014\n", | |
"Time elapsed: 153.14 min\n", | |
"Epoch: 085/100 | Train: 31.589% | Loss: 0.014\n", | |
"Time elapsed: 154.94 min\n", | |
"Epoch: 086/100 | Train: 31.832% | Loss: 0.014\n", | |
"Time elapsed: 156.73 min\n", | |
"Epoch: 087/100 | Train: 31.676% | Loss: 0.014\n", | |
"Time elapsed: 158.54 min\n", | |
"Epoch: 088/100 | Train: 31.535% | Loss: 0.014\n", | |
"Time elapsed: 160.40 min\n", | |
"Epoch: 089/100 | Train: 30.206% | Loss: 0.014\n", | |
"Time elapsed: 162.21 min\n", | |
"Epoch: 090/100 | Train: 30.686% | Loss: 0.014\n", | |
"Time elapsed: 164.03 min\n", | |
"Epoch: 091/100 | Train: 30.880% | Loss: 0.014\n", | |
"Time elapsed: 165.84 min\n", | |
"Epoch: 092/100 | Train: 31.509% | Loss: 0.014\n", | |
"Time elapsed: 167.65 min\n", | |
"Epoch: 093/100 | Train: 30.434% | Loss: 0.014\n", | |
"Time elapsed: 169.46 min\n", | |
"Epoch: 094/100 | Train: 30.712% | Loss: 0.014\n", | |
"Time elapsed: 171.28 min\n", | |
"Epoch: 095/100 | Train: 30.945% | Loss: 0.014\n", | |
"Time elapsed: 173.10 min\n", | |
"Epoch: 096/100 | Train: 31.547% | Loss: 0.014\n", | |
"Time elapsed: 174.89 min\n", | |
"Epoch: 097/100 | Train: 31.680% | Loss: 0.014\n", | |
"Time elapsed: 176.69 min\n", | |
"Epoch: 098/100 | Train: 31.710% | Loss: 0.014\n", | |
"Time elapsed: 178.49 min\n", | |
"Epoch: 099/100 | Train: 29.208% | Loss: 0.014\n", | |
"Time elapsed: 180.32 min\n", | |
"Epoch: 100/100 | Train: 30.259% | Loss: 0.014\n", | |
"Time elapsed: 182.14 min\n", | |
"Total Training Time: 182.14 min\n" | |
] | |
} | |
], | |
"source": [ | |
"# Train for num_epochs epochs; after each epoch, report full train-set accuracy and\n", | |
"# loss via compute_accuracy / compute_epoch_loss (defined elsewhere in this notebook)\n", | |
"# -- note this re-runs a full pass over train_dataloader per epoch, which roughly\n", | |
"# doubles epoch time.\n", | |
"num_epochs = 100\n", | |
"import time \n", | |
"\n", | |
"start_time = time.time()\n", | |
"for epoch in range(num_epochs):\n", | |
" \n", | |
" model.train()\n", | |
" for batch_idx, (features, targets) in enumerate(train_dataloader):\n", | |
" \n", | |
" features = features.to(device)\n", | |
"# print(features.shape)\n", | |
"# break\n", | |
" targets = targets.to(device)\n", | |
" \n", | |
" ### FORWARD AND BACK PROP\n", | |
"# logits, probas = model(features)\n", | |
"# cost = F.cross_entropy(logits, targets)\n", | |
"\n", | |
" predicted = model(features)\n", | |
" loss = criterion(predicted,targets)\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
"# cost.backward()\n", | |
" \n", | |
" ### UPDATE MODEL PARAMETERS\n", | |
" optimizer.step()\n", | |
" \n", | |
" ### LOGGING\n", | |
"# if not batch_idx % 50:\n", | |
"# print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n", | |
"# %(epoch+1, num_epochs, batch_idx, \n", | |
"# len(train_dataloader), cost))\n", | |
"\n", | |
" # model stays in eval mode until model.train() at the top of the next epoch\n", | |
" model.eval()\n", | |
" # torch.set_grad_enabled(False) used as a context manager is equivalent to torch.no_grad()\n", | |
" with torch.set_grad_enabled(False): # save memory during inference\n", | |
" print('Epoch: %03d/%03d | Train: %.3f%% | Loss: %.3f' % (\n", | |
" epoch+1, num_epochs, \n", | |
" compute_accuracy(model, train_dataloader),\n", | |
" compute_epoch_loss(model, train_dataloader)))\n", | |
"\n", | |
"\n", | |
" print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", | |
" \n", | |
"print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"ResNet(\n", | |
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", | |
" (layer1): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer2): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer3): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer4): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", | |
" (fc): Sequential(\n", | |
" (0): Linear(in_features=512, out_features=512, bias=True)\n", | |
" (1): ReLU()\n", | |
" (2): Linear(in_features=512, out_features=7, bias=True)\n", | |
" (3): LogSoftmax()\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Display the model architecture; the bare expression on the last line renders the repr.\n", | |
"# (The previous `model = model` line was a no-op self-assignment and has been removed.)\n", | |
"model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<All keys matched successfully>" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Re-load the same checkpoint before the second training run below.\n", | |
"# NOTE(review): torch.load(..., map_location=device) would make this robust on\n", | |
"# CPU-only machines -- not changed here.\n", | |
"model.load_state_dict(torch.load('./checkpoints/weight.pth'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 001/100 | Train: 30.678% | Loss: 0.014\n", | |
"Time elapsed: 1.81 min\n", | |
"Epoch: 002/100 | Train: 30.590% | Loss: 0.014\n", | |
"Time elapsed: 3.67 min\n", | |
"Epoch: 003/100 | Train: 32.015% | Loss: 0.014\n", | |
"Time elapsed: 5.49 min\n", | |
"Epoch: 004/100 | Train: 31.326% | Loss: 0.014\n", | |
"Time elapsed: 7.30 min\n", | |
"Epoch: 005/100 | Train: 31.718% | Loss: 0.014\n", | |
"Time elapsed: 9.07 min\n", | |
"Epoch: 006/100 | Train: 31.265% | Loss: 0.014\n", | |
"Time elapsed: 10.83 min\n", | |
"Epoch: 007/100 | Train: 31.688% | Loss: 0.014\n", | |
"Time elapsed: 12.61 min\n", | |
"Epoch: 008/100 | Train: 31.451% | Loss: 0.014\n", | |
"Time elapsed: 14.39 min\n", | |
"Epoch: 009/100 | Train: 31.512% | Loss: 0.014\n", | |
"Time elapsed: 16.22 min\n", | |
"Epoch: 010/100 | Train: 31.764% | Loss: 0.014\n", | |
"Time elapsed: 18.05 min\n", | |
"Epoch: 011/100 | Train: 31.040% | Loss: 0.014\n", | |
"Time elapsed: 19.99 min\n", | |
"Epoch: 012/100 | Train: 29.840% | Loss: 0.014\n", | |
"Time elapsed: 21.95 min\n", | |
"Epoch: 013/100 | Train: 31.920% | Loss: 0.014\n", | |
"Time elapsed: 23.88 min\n", | |
"Epoch: 014/100 | Train: 29.722% | Loss: 0.014\n", | |
"Time elapsed: 25.71 min\n", | |
"Epoch: 015/100 | Train: 31.931% | Loss: 0.014\n", | |
"Time elapsed: 27.67 min\n", | |
"Epoch: 016/100 | Train: 31.653% | Loss: 0.014\n", | |
"Time elapsed: 29.49 min\n", | |
"Epoch: 017/100 | Train: 30.358% | Loss: 0.014\n", | |
"Time elapsed: 31.43 min\n", | |
"Epoch: 018/100 | Train: 31.322% | Loss: 0.014\n", | |
"Time elapsed: 33.31 min\n", | |
"Epoch: 019/100 | Train: 31.604% | Loss: 0.014\n", | |
"Time elapsed: 35.20 min\n", | |
"Epoch: 020/100 | Train: 31.608% | Loss: 0.014\n", | |
"Time elapsed: 37.25 min\n", | |
"Epoch: 021/100 | Train: 32.156% | Loss: 0.014\n", | |
"Time elapsed: 39.25 min\n", | |
"Epoch: 022/100 | Train: 31.893% | Loss: 0.014\n", | |
"Time elapsed: 41.29 min\n", | |
"Epoch: 023/100 | Train: 31.840% | Loss: 0.014\n", | |
"Time elapsed: 43.31 min\n", | |
"Epoch: 024/100 | Train: 32.370% | Loss: 0.014\n", | |
"Time elapsed: 45.35 min\n", | |
"Epoch: 025/100 | Train: 31.855% | Loss: 0.014\n", | |
"Time elapsed: 47.39 min\n", | |
"Epoch: 026/100 | Train: 31.783% | Loss: 0.014\n", | |
"Time elapsed: 49.39 min\n", | |
"Epoch: 027/100 | Train: 31.840% | Loss: 0.014\n", | |
"Time elapsed: 51.40 min\n", | |
"Epoch: 028/100 | Train: 31.566% | Loss: 0.014\n", | |
"Time elapsed: 53.46 min\n", | |
"Epoch: 029/100 | Train: 31.669% | Loss: 0.014\n", | |
"Time elapsed: 55.52 min\n", | |
"Epoch: 030/100 | Train: 31.162% | Loss: 0.014\n", | |
"Time elapsed: 57.37 min\n", | |
"Epoch: 031/100 | Train: 31.802% | Loss: 0.014\n", | |
"Time elapsed: 59.27 min\n", | |
"Epoch: 032/100 | Train: 31.421% | Loss: 0.014\n", | |
"Time elapsed: 61.28 min\n", | |
"Epoch: 033/100 | Train: 31.703% | Loss: 0.014\n", | |
"Time elapsed: 63.29 min\n", | |
"Epoch: 034/100 | Train: 31.067% | Loss: 0.014\n", | |
"Time elapsed: 65.37 min\n", | |
"Epoch: 035/100 | Train: 31.181% | Loss: 0.014\n", | |
"Time elapsed: 67.46 min\n", | |
"Epoch: 036/100 | Train: 31.653% | Loss: 0.014\n", | |
"Time elapsed: 69.43 min\n", | |
"Epoch: 037/100 | Train: 30.693% | Loss: 0.014\n", | |
"Time elapsed: 71.40 min\n", | |
"Epoch: 038/100 | Train: 31.307% | Loss: 0.014\n", | |
"Time elapsed: 73.42 min\n", | |
"Epoch: 039/100 | Train: 29.703% | Loss: 0.014\n", | |
"Time elapsed: 75.49 min\n", | |
"Epoch: 040/100 | Train: 32.419% | Loss: 0.014\n", | |
"Time elapsed: 77.51 min\n", | |
"Epoch: 041/100 | Train: 30.350% | Loss: 0.014\n", | |
"Time elapsed: 79.58 min\n", | |
"Epoch: 042/100 | Train: 32.019% | Loss: 0.014\n", | |
"Time elapsed: 81.64 min\n", | |
"Epoch: 043/100 | Train: 31.973% | Loss: 0.014\n", | |
"Time elapsed: 83.67 min\n", | |
"Epoch: 044/100 | Train: 32.541% | Loss: 0.014\n", | |
"Time elapsed: 85.66 min\n", | |
"Epoch: 045/100 | Train: 31.341% | Loss: 0.014\n", | |
"Time elapsed: 87.62 min\n", | |
"Epoch: 046/100 | Train: 31.756% | Loss: 0.014\n", | |
"Time elapsed: 89.55 min\n", | |
"Epoch: 047/100 | Train: 29.806% | Loss: 0.014\n", | |
"Time elapsed: 91.37 min\n", | |
"Epoch: 048/100 | Train: 32.099% | Loss: 0.014\n", | |
"Time elapsed: 93.20 min\n", | |
"Epoch: 049/100 | Train: 31.238% | Loss: 0.014\n", | |
"Time elapsed: 95.00 min\n", | |
"Epoch: 050/100 | Train: 31.208% | Loss: 0.014\n", | |
"Time elapsed: 96.84 min\n", | |
"Epoch: 051/100 | Train: 32.248% | Loss: 0.014\n", | |
"Time elapsed: 98.75 min\n", | |
"Epoch: 052/100 | Train: 30.526% | Loss: 0.014\n", | |
"Time elapsed: 100.69 min\n", | |
"Epoch: 053/100 | Train: 31.855% | Loss: 0.014\n", | |
"Time elapsed: 102.54 min\n", | |
"Epoch: 054/100 | Train: 32.301% | Loss: 0.014\n", | |
"Time elapsed: 104.42 min\n", | |
"Epoch: 055/100 | Train: 31.425% | Loss: 0.014\n", | |
"Time elapsed: 106.43 min\n", | |
"Epoch: 056/100 | Train: 32.000% | Loss: 0.014\n", | |
"Time elapsed: 108.49 min\n", | |
"Epoch: 057/100 | Train: 31.246% | Loss: 0.014\n", | |
"Time elapsed: 110.58 min\n", | |
"Epoch: 058/100 | Train: 31.177% | Loss: 0.014\n", | |
"Time elapsed: 112.64 min\n", | |
"Epoch: 059/100 | Train: 32.152% | Loss: 0.014\n", | |
"Time elapsed: 114.69 min\n", | |
"Epoch: 060/100 | Train: 30.103% | Loss: 0.014\n", | |
"Time elapsed: 116.84 min\n", | |
"Epoch: 061/100 | Train: 31.897% | Loss: 0.014\n", | |
"Time elapsed: 118.97 min\n", | |
"Epoch: 062/100 | Train: 30.796% | Loss: 0.014\n", | |
"Time elapsed: 120.96 min\n", | |
"Epoch: 063/100 | Train: 31.653% | Loss: 0.014\n", | |
"Time elapsed: 122.73 min\n", | |
"Epoch: 064/100 | Train: 31.501% | Loss: 0.014\n", | |
"Time elapsed: 124.55 min\n", | |
"Epoch: 065/100 | Train: 31.722% | Loss: 0.014\n", | |
"Time elapsed: 126.34 min\n", | |
"Epoch: 066/100 | Train: 31.448% | Loss: 0.014\n", | |
"Time elapsed: 128.16 min\n", | |
"Epoch: 067/100 | Train: 31.349% | Loss: 0.014\n", | |
"Time elapsed: 129.94 min\n", | |
"Epoch: 068/100 | Train: 30.453% | Loss: 0.014\n", | |
"Time elapsed: 131.75 min\n", | |
"Epoch: 069/100 | Train: 31.238% | Loss: 0.014\n", | |
"Time elapsed: 133.54 min\n", | |
"Epoch: 070/100 | Train: 32.095% | Loss: 0.014\n", | |
"Time elapsed: 135.32 min\n", | |
"Epoch: 071/100 | Train: 30.949% | Loss: 0.014\n", | |
"Time elapsed: 137.11 min\n", | |
"Epoch: 072/100 | Train: 32.488% | Loss: 0.014\n", | |
"Time elapsed: 138.88 min\n", | |
"Epoch: 073/100 | Train: 31.718% | Loss: 0.014\n", | |
"Time elapsed: 140.64 min\n", | |
"Epoch: 074/100 | Train: 31.806% | Loss: 0.014\n", | |
"Time elapsed: 142.44 min\n", | |
"Epoch: 075/100 | Train: 31.478% | Loss: 0.014\n", | |
"Time elapsed: 144.22 min\n", | |
"Epoch: 076/100 | Train: 31.192% | Loss: 0.014\n", | |
"Time elapsed: 146.02 min\n", | |
"Epoch: 077/100 | Train: 32.099% | Loss: 0.014\n", | |
"Time elapsed: 147.80 min\n", | |
"Epoch: 078/100 | Train: 32.549% | Loss: 0.014\n", | |
"Time elapsed: 149.58 min\n", | |
"Epoch: 079/100 | Train: 30.777% | Loss: 0.014\n", | |
"Time elapsed: 151.41 min\n", | |
"Epoch: 080/100 | Train: 32.149% | Loss: 0.014\n", | |
"Time elapsed: 153.16 min\n", | |
"Epoch: 081/100 | Train: 32.331% | Loss: 0.013\n", | |
"Time elapsed: 154.93 min\n", | |
"Epoch: 082/100 | Train: 32.392% | Loss: 0.014\n", | |
"Time elapsed: 156.72 min\n", | |
"Epoch: 083/100 | Train: 32.571% | Loss: 0.014\n", | |
"Time elapsed: 158.50 min\n", | |
"Epoch: 084/100 | Train: 31.794% | Loss: 0.014\n", | |
"Time elapsed: 160.31 min\n", | |
"Epoch: 085/100 | Train: 32.213% | Loss: 0.014\n", | |
"Time elapsed: 162.08 min\n", | |
"Epoch: 086/100 | Train: 31.112% | Loss: 0.014\n", | |
"Time elapsed: 163.87 min\n", | |
"Epoch: 087/100 | Train: 30.842% | Loss: 0.014\n", | |
"Time elapsed: 165.65 min\n", | |
"Epoch: 088/100 | Train: 30.850% | Loss: 0.014\n", | |
"Time elapsed: 167.42 min\n", | |
"Epoch: 089/100 | Train: 31.977% | Loss: 0.014\n", | |
"Time elapsed: 169.22 min\n", | |
"Epoch: 090/100 | Train: 31.337% | Loss: 0.014\n", | |
"Time elapsed: 171.01 min\n", | |
"Epoch: 091/100 | Train: 32.019% | Loss: 0.014\n", | |
"Time elapsed: 172.81 min\n", | |
"Epoch: 092/100 | Train: 31.768% | Loss: 0.014\n", | |
"Time elapsed: 174.60 min\n", | |
"Epoch: 093/100 | Train: 32.415% | Loss: 0.013\n", | |
"Time elapsed: 176.37 min\n", | |
"Epoch: 094/100 | Train: 31.730% | Loss: 0.014\n", | |
"Time elapsed: 178.17 min\n", | |
"Epoch: 095/100 | Train: 32.038% | Loss: 0.014\n", | |
"Time elapsed: 179.99 min\n", | |
"Epoch: 096/100 | Train: 32.465% | Loss: 0.013\n", | |
"Time elapsed: 181.79 min\n", | |
"Epoch: 097/100 | Train: 31.242% | Loss: 0.014\n", | |
"Time elapsed: 183.56 min\n", | |
"Epoch: 098/100 | Train: 31.634% | Loss: 0.014\n", | |
"Time elapsed: 185.35 min\n", | |
"Epoch: 099/100 | Train: 32.301% | Loss: 0.013\n", | |
"Time elapsed: 187.13 min\n", | |
"Epoch: 100/100 | Train: 31.787% | Loss: 0.014\n", | |
"Time elapsed: 188.91 min\n", | |
"Total Training Time: 188.91 min\n" | |
] | |
} | |
], | |
"source": [ | |
"num_epochs = 100\n", | |
"import time \n", | |
"\n", | |
"start_time = time.time()\n", | |
"for epoch in range(num_epochs):\n", | |
" \n", | |
" model.train()\n", | |
" for batch_idx, (features, targets) in enumerate(train_dataloader):\n", | |
" \n", | |
" features = features.to(device)\n", | |
"# print(features.shape)\n", | |
"# break\n", | |
" targets = targets.to(device)\n", | |
" \n", | |
" ### FORWARD AND BACK PROP\n", | |
"# logits, probas = model(features)\n", | |
"# cost = F.cross_entropy(logits, targets)\n", | |
"\n", | |
" predicted = model(features)\n", | |
" loss = criterion(predicted,targets)\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
"# cost.backward()\n", | |
" \n", | |
" ### UPDATE MODEL PARAMETERS\n", | |
" optimizer.step()\n", | |
" \n", | |
" ### LOGGING\n", | |
"# if not batch_idx % 50:\n", | |
"# print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n", | |
"# %(epoch+1, num_epochs, batch_idx, \n", | |
"# len(train_dataloader), cost))\n", | |
"\n", | |
" model.eval()\n", | |
" with torch.set_grad_enabled(False): # save memory during inference\n", | |
" print('Epoch: %03d/%03d | Train: %.3f%% | Loss: %.3f' % (\n", | |
" epoch+1, num_epochs, \n", | |
" compute_accuracy(model, train_dataloader),\n", | |
" compute_epoch_loss(model, train_dataloader)))\n", | |
"\n", | |
"\n", | |
" print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", | |
" \n", | |
"print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"torch.save(model.state_dict(), './checkpoints/weight18-300.pth')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "TypeError", | |
"evalue": "can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[1;32m<ipython-input-27-421d3408d711>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpredicted\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'Training loss'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtest_losses\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'Validation loss'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlegend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mframeon\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshow\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;31mTypeError\u001b[0m: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first." | |
] | |
} | |
], | |
"source": [ | |
"b = predicted.numpy()\n", | |
"plt.plot(b, label='Training loss')\n", | |
"plt.plot(test_losses, label='Validation loss')\n", | |
"plt.legend(frameon=False)\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment