{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"from torch import nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"from torchvision import datasets, transforms, models\n",
"from torchvision.models import resnet18,resnet34\n",
"from torch.utils.data import DataLoader\n"
]
},
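{
"cell_type": "markdown",
"metadata": {},
"source": [
"For repeatable runs it can help to seed the random number generators before the loaders and model are built. A minimal sketch; the runs recorded below were not seeded, and `SEED` is an arbitrary choice:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: seed Python, NumPy and PyTorch RNGs for repeatability.\n",
"import random\n",
"\n",
"SEED = 0\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"torch.cuda.manual_seed_all(SEED)"
]
},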
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"train_data_dir = 'img/AffectNet/train_CNN'\n",
"test_data_dir = 'img/AffectNet/test_CNN'\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.Resize((128,128)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
"])"
]
},
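{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same deterministic transform is applied to both splits. If train-time augmentation is wanted later, the usual pattern is a separate training transform; a sketch with standard torchvision ops, not used in the runs below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional train-only augmentation; the test split keeps the deterministic transform.\n",
"train_transform = transforms.Compose([\n",
"    transforms.Resize((128, 128)),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.RandomRotation(10),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
"])"
]
},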
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_data = datasets.ImageFolder( root=train_data_dir, transform=transform)\n",
"test_data = datasets.ImageFolder(root=test_data_dir, transform=transform)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train_dataloader = DataLoader(dataset=train_data,batch_size=128, num_workers=16, shuffle=True)\n",
"test_dataloader = DataLoader(dataset=test_data, batch_size=128,num_workers=16, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['A', 'B', 'C', 'D', 'E', 'F', 'G']\n"
]
},
{
"data": {
"text/plain": [
"7"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(train_dataloader.dataset.classes)\n",
"num_class = len(train_dataloader.dataset.classes)\n",
"num_class"
]
},
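{
"cell_type": "markdown",
"metadata": {},
"source": [
"With seven classes, chance accuracy is about 14%, and accuracy is only meaningful against the class balance. A quick check using the per-sample targets that `ImageFolder` stores:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count training samples per class (ImageFolder keeps one class index per sample).\n",
"from collections import Counter\n",
"\n",
"counts = Counter(train_data.targets)\n",
"for idx, name in enumerate(train_data.classes):\n",
"    print(name, counts[idx])"
]
},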
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Image batch dimensions: torch.Size([128, 3, 128, 128])\n",
"Image label dimensions: tensor([0, 2, 1, 4, 3, 6, 5, 4, 6, 1, 0, 5, 6, 3, 3, 1, 0, 0, 1, 3, 0, 1, 0, 2,\n",
" 0, 3, 1, 5, 5, 6, 6, 5, 2, 5, 5, 6, 4, 6, 5, 3, 0, 3, 1, 5, 4, 3, 6, 3,\n",
" 2, 3, 1, 1, 4, 3, 3, 6, 1, 5, 6, 1, 2, 6, 4, 0, 1, 0, 1, 4, 2, 5, 2, 1,\n",
" 6, 1, 3, 5, 1, 5, 6, 2, 5, 3, 5, 3, 3, 4, 0, 1, 6, 3, 6, 0, 5, 0, 0, 6,\n",
" 4, 1, 2, 6, 5, 2, 5, 0, 1, 4, 3, 2, 6, 4, 1, 4, 4, 0, 6, 6, 0, 0, 4, 5,\n",
" 5, 6, 3, 6, 1, 5, 0, 5])\n"
]
}
],
"source": [
"for images, labels in train_dataloader: \n",
" print('Image batch dimensions:', images.shape)\n",
" print('Image label dimensions:', labels)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset ImageFolder\n",
" Number of datapoints: 26250\n",
" Root location: img/AffectNet/train_CNN\n",
" StandardTransform\n",
"Transform: Compose(\n",
" Resize(size=(128, 128), interpolation=PIL.Image.BILINEAR)\n",
" ToTensor()\n",
" Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
" )"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_dataloader.dataset\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset ImageFolder\n",
" Number of datapoints: 3500\n",
" Root location: img/AffectNet/test_CNN\n",
" StandardTransform\n",
"Transform: Compose(\n",
" Resize(size=(128, 128), interpolation=PIL.Image.BILINEAR)\n",
" ToTensor()\n",
" Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
" )"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_dataloader.dataset"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
" (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
")"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"\n",
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda:0\")\n",
"\n",
"\n",
"model = models.resnet18(pretrained=False,progress=True, )\n",
"model\n",
"# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"# model = resnet18(pretrained=False)\n",
"# is_cuda = torch.cuda.is_available()\n",
"\n",
"\n",
"# if is_cuda:\n",
"# model = model.cuda()\n",
" \n",
" \n",
"# # print(model)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"num_class"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# criterion = nn.CrossEntropyLoss()\n",
"# optimizer = optim.Adam(model.fc.parameters(), lr=0.001)\n",
"# model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
" (fc): Sequential(\n",
" (0): Linear(in_features=512, out_features=512, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=512, out_features=7, bias=True)\n",
" (3): LogSoftmax()\n",
" )\n",
")"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"for param in model.parameters():\n",
" param.requires_grad = False\n",
" \n",
"model.fc = nn.Sequential(nn.Linear(512, 512),\n",
" nn.ReLU(),\n",
"# nn.Dropout(0.2),\n",
" nn.Linear(512, num_class),\n",
" nn.LogSoftmax(dim=1))\n",
"criterion = nn.CrossEntropyLoss()\n",
"# CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.fc.parameters(), lr=0.001)\n",
"model.to(device)"
]
},
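{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the backbone is frozen while it still holds the random weights from `pretrained=False`, so the new head is trained on top of random features. That largely explains why the losses below plateau near log(7) ≈ 1.95 and test accuracy stays close to chance. The usual transfer-learning setup starts from ImageNet weights; a sketch, assuming the torchvision weights can be downloaded (not the configuration used in the recorded runs):\n",
"\n",
"```python\n",
"# Transfer-learning variant: frozen ImageNet backbone, trainable head.\n",
"model = models.resnet18(pretrained=True)\n",
"for param in model.parameters():\n",
"    param.requires_grad = False\n",
"model.fc = nn.Sequential(nn.Linear(512, 512),\n",
"                         nn.ReLU(),\n",
"                         nn.Linear(512, num_class),\n",
"                         nn.LogSoftmax(dim=1))\n",
"criterion = nn.NLLLoss()\n",
"optimizer = optim.Adam(model.fc.parameters(), lr=0.001)\n",
"model = model.to(device)\n",
"```"
]
},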
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda', index=0)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"num_epochs = 100\n",
"import time\n"
]
},
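{
"cell_type": "markdown",
"metadata": {},
"source": [
"`time` is imported above but never used; a natural use is timing each epoch. A sketch of the pattern (the recorded training loops below do not include it):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pattern for timing an epoch with the time module:\n",
"start = time.time()\n",
"# ... one epoch of training would go here ...\n",
"elapsed = time.time() - start\n",
"print('epoch took {:.1f}s'.format(elapsed))"
]
},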
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# def compute_accuracy(model, data_loader):\n",
"# model.eval()\n",
"# correct_pred, num_examples = 0, 0\n",
"# for i, (features, targets) in enumerate(data_loader):\n",
" \n",
"# features = features.to(device)\n",
"# targets = targets.to(device)\n",
"\n",
"# # logits, probas = model(features)\n",
"\n",
"# predicted= model(features)\n",
"# _, predicted_labels = torch.max(predicted, 1)\n",
"# num_examples += targets.size(0)\n",
"# correct_pred += (predicted_labels == targets).sum()\n",
"# return correct_pred.float()/num_examples * 100\n",
"\n",
"\n",
"# def compute_epoch_loss(model, data_loader):\n",
"# model.eval()\n",
"# curr_loss, num_examples = 0., 0\n",
"# with torch.no_grad():\n",
"# for features, targets in data_loader:\n",
"# features = features.to(device)\n",
"# targets = targets.to(device)\n",
"# logits, probas = model(features)\n",
"# predicted= model(features)\n",
"# loss = F.cross_entropy(logits, targets, reduction='sum')\n",
"# loss = criterion(predicted,targets)\n",
"# num_examples += targets.size(0)\n",
"# curr_loss += loss\n",
"\n",
"# curr_loss = curr_loss / num_examples\n",
"# return curr_loss\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/100], Step [100/206], Loss: 1.9217\n",
"Epoch [1/100], Step [200/206], Loss: 1.9400\n",
"Epoch [2/100], Step [100/206], Loss: 1.9089\n",
"Epoch [2/100], Step [200/206], Loss: 1.9466\n",
"Epoch [3/100], Step [100/206], Loss: 1.9415\n",
"Epoch [3/100], Step [200/206], Loss: 1.9512\n",
"Epoch [4/100], Step [100/206], Loss: 1.9334\n",
"Epoch [4/100], Step [200/206], Loss: 1.8898\n",
"Epoch [5/100], Step [100/206], Loss: 1.9213\n",
"Epoch [5/100], Step [200/206], Loss: 1.8805\n",
"Epoch [6/100], Step [100/206], Loss: 1.9121\n",
"Epoch [6/100], Step [200/206], Loss: 1.9151\n",
"Epoch [7/100], Step [100/206], Loss: 1.9224\n",
"Epoch [7/100], Step [200/206], Loss: 1.8909\n",
"Epoch [8/100], Step [100/206], Loss: 1.8988\n",
"Epoch [8/100], Step [200/206], Loss: 1.9335\n",
"Epoch [9/100], Step [100/206], Loss: 1.8870\n",
"Epoch [9/100], Step [200/206], Loss: 1.8923\n",
"Epoch [10/100], Step [100/206], Loss: 1.9120\n",
"Epoch [10/100], Step [200/206], Loss: 1.8732\n",
"Epoch [11/100], Step [100/206], Loss: 1.8551\n",
"Epoch [11/100], Step [200/206], Loss: 1.9099\n",
"Epoch [12/100], Step [100/206], Loss: 1.8975\n",
"Epoch [12/100], Step [200/206], Loss: 1.8868\n",
"Epoch [13/100], Step [100/206], Loss: 1.9112\n",
"Epoch [13/100], Step [200/206], Loss: 1.8963\n",
"Epoch [14/100], Step [100/206], Loss: 1.8264\n",
"Epoch [14/100], Step [200/206], Loss: 1.8338\n",
"Epoch [15/100], Step [100/206], Loss: 1.8355\n",
"Epoch [15/100], Step [200/206], Loss: 1.9304\n",
"Epoch [16/100], Step [100/206], Loss: 1.9264\n",
"Epoch [16/100], Step [200/206], Loss: 1.8295\n",
"Epoch [17/100], Step [100/206], Loss: 1.8843\n",
"Epoch [17/100], Step [200/206], Loss: 1.8461\n",
"Epoch [18/100], Step [100/206], Loss: 1.8738\n",
"Epoch [18/100], Step [200/206], Loss: 1.8952\n",
"Epoch [19/100], Step [100/206], Loss: 1.8616\n",
"Epoch [19/100], Step [200/206], Loss: 1.8769\n",
"Epoch [20/100], Step [100/206], Loss: 1.8121\n",
"Epoch [20/100], Step [200/206], Loss: 1.8620\n",
"Epoch [21/100], Step [100/206], Loss: 1.8006\n",
"Epoch [21/100], Step [200/206], Loss: 1.8902\n",
"Epoch [22/100], Step [100/206], Loss: 1.9003\n",
"Epoch [22/100], Step [200/206], Loss: 1.8980\n",
"Epoch [23/100], Step [100/206], Loss: 1.8741\n",
"Epoch [23/100], Step [200/206], Loss: 1.8395\n",
"Epoch [24/100], Step [100/206], Loss: 1.8697\n",
"Epoch [24/100], Step [200/206], Loss: 1.8561\n",
"Epoch [25/100], Step [100/206], Loss: 1.8814\n",
"Epoch [25/100], Step [200/206], Loss: 1.8965\n",
"Epoch [26/100], Step [100/206], Loss: 1.8482\n",
"Epoch [26/100], Step [200/206], Loss: 1.8450\n",
"Epoch [27/100], Step [100/206], Loss: 1.9332\n",
"Epoch [27/100], Step [200/206], Loss: 1.8982\n",
"Epoch [28/100], Step [100/206], Loss: 1.8137\n",
"Epoch [28/100], Step [200/206], Loss: 1.8262\n",
"Epoch [29/100], Step [100/206], Loss: 1.8102\n",
"Epoch [29/100], Step [200/206], Loss: 1.8241\n",
"Epoch [30/100], Step [100/206], Loss: 1.8278\n",
"Epoch [30/100], Step [200/206], Loss: 1.8998\n",
"Epoch [31/100], Step [100/206], Loss: 1.8288\n",
"Epoch [31/100], Step [200/206], Loss: 1.8316\n",
"Epoch [32/100], Step [100/206], Loss: 1.7759\n",
"Epoch [32/100], Step [200/206], Loss: 1.8147\n",
"Epoch [33/100], Step [100/206], Loss: 1.8950\n",
"Epoch [33/100], Step [200/206], Loss: 1.7506\n",
"Epoch [34/100], Step [100/206], Loss: 1.7319\n",
"Epoch [34/100], Step [200/206], Loss: 1.8149\n",
"Epoch [35/100], Step [100/206], Loss: 1.8667\n",
"Epoch [35/100], Step [200/206], Loss: 1.8774\n",
"Epoch [36/100], Step [100/206], Loss: 1.8703\n",
"Epoch [36/100], Step [200/206], Loss: 1.8300\n",
"Epoch [37/100], Step [100/206], Loss: 1.8249\n",
"Epoch [37/100], Step [200/206], Loss: 1.7675\n",
"Epoch [38/100], Step [100/206], Loss: 1.9187\n",
"Epoch [38/100], Step [200/206], Loss: 1.8346\n",
"Epoch [39/100], Step [100/206], Loss: 1.7732\n",
"Epoch [39/100], Step [200/206], Loss: 1.8172\n",
"Epoch [40/100], Step [100/206], Loss: 1.9019\n",
"Epoch [40/100], Step [200/206], Loss: 1.8095\n",
"Epoch [41/100], Step [100/206], Loss: 1.8056\n",
"Epoch [41/100], Step [200/206], Loss: 1.8398\n",
"Epoch [42/100], Step [100/206], Loss: 1.8673\n",
"Epoch [42/100], Step [200/206], Loss: 1.8623\n",
"Epoch [43/100], Step [100/206], Loss: 1.8089\n",
"Epoch [43/100], Step [200/206], Loss: 1.8264\n",
"Epoch [44/100], Step [100/206], Loss: 1.8061\n",
"Epoch [44/100], Step [200/206], Loss: 1.8330\n",
"Epoch [45/100], Step [100/206], Loss: 1.8320\n",
"Epoch [45/100], Step [200/206], Loss: 1.8466\n",
"Epoch [46/100], Step [100/206], Loss: 1.8238\n",
"Epoch [46/100], Step [200/206], Loss: 1.7436\n",
"Epoch [47/100], Step [100/206], Loss: 1.8641\n",
"Epoch [47/100], Step [200/206], Loss: 1.8573\n",
"Epoch [48/100], Step [100/206], Loss: 1.8456\n",
"Epoch [48/100], Step [200/206], Loss: 1.7787\n",
"Epoch [49/100], Step [100/206], Loss: 1.8252\n",
"Epoch [49/100], Step [200/206], Loss: 1.9172\n",
"Epoch [50/100], Step [100/206], Loss: 1.8456\n",
"Epoch [50/100], Step [200/206], Loss: 1.8413\n",
"Epoch [51/100], Step [100/206], Loss: 1.8063\n",
"Epoch [51/100], Step [200/206], Loss: 1.8427\n",
"Epoch [52/100], Step [100/206], Loss: 1.7799\n",
"Epoch [52/100], Step [200/206], Loss: 1.7938\n",
"Epoch [53/100], Step [100/206], Loss: 1.8103\n",
"Epoch [53/100], Step [200/206], Loss: 1.8889\n",
"Epoch [54/100], Step [100/206], Loss: 1.8157\n",
"Epoch [54/100], Step [200/206], Loss: 1.7702\n",
"Epoch [55/100], Step [100/206], Loss: 1.8287\n",
"Epoch [55/100], Step [200/206], Loss: 1.8358\n",
"Epoch [56/100], Step [100/206], Loss: 1.8687\n",
"Epoch [56/100], Step [200/206], Loss: 1.8552\n",
"Epoch [57/100], Step [100/206], Loss: 1.7932\n",
"Epoch [57/100], Step [200/206], Loss: 1.8077\n",
"Epoch [58/100], Step [100/206], Loss: 1.8008\n",
"Epoch [58/100], Step [200/206], Loss: 1.8095\n",
"Epoch [59/100], Step [100/206], Loss: 1.8237\n",
"Epoch [59/100], Step [200/206], Loss: 1.8513\n",
"Epoch [60/100], Step [100/206], Loss: 1.8412\n",
"Epoch [60/100], Step [200/206], Loss: 1.8951\n",
"Epoch [61/100], Step [100/206], Loss: 1.7768\n",
"Epoch [61/100], Step [200/206], Loss: 1.8098\n",
"Epoch [62/100], Step [100/206], Loss: 1.8031\n",
"Epoch [62/100], Step [200/206], Loss: 1.8305\n",
"Epoch [63/100], Step [100/206], Loss: 1.7894\n",
"Epoch [63/100], Step [200/206], Loss: 1.7866\n",
"Epoch [64/100], Step [100/206], Loss: 1.8474\n",
"Epoch [64/100], Step [200/206], Loss: 1.8278\n",
"Epoch [65/100], Step [100/206], Loss: 1.7913\n",
"Epoch [65/100], Step [200/206], Loss: 1.7857\n",
"Epoch [66/100], Step [100/206], Loss: 1.7911\n",
"Epoch [66/100], Step [200/206], Loss: 1.7780\n",
"Epoch [67/100], Step [100/206], Loss: 1.7642\n",
"Epoch [67/100], Step [200/206], Loss: 1.8193\n",
"Epoch [68/100], Step [100/206], Loss: 1.7941\n",
"Epoch [68/100], Step [200/206], Loss: 1.7852\n",
"Epoch [69/100], Step [100/206], Loss: 1.8234\n",
"Epoch [69/100], Step [200/206], Loss: 1.8224\n",
"Epoch [70/100], Step [100/206], Loss: 1.8101\n",
"Epoch [70/100], Step [200/206], Loss: 1.8094\n",
"Epoch [71/100], Step [100/206], Loss: 1.7375\n",
"Epoch [71/100], Step [200/206], Loss: 1.8148\n",
"Epoch [72/100], Step [100/206], Loss: 1.7975\n",
"Epoch [72/100], Step [200/206], Loss: 1.8560\n",
"Epoch [73/100], Step [100/206], Loss: 1.8892\n",
"Epoch [73/100], Step [200/206], Loss: 1.8031\n",
"Epoch [74/100], Step [100/206], Loss: 1.8147\n",
"Epoch [74/100], Step [200/206], Loss: 1.7783\n",
"Epoch [75/100], Step [100/206], Loss: 1.9001\n",
"Epoch [75/100], Step [200/206], Loss: 1.8443\n",
"Epoch [76/100], Step [100/206], Loss: 1.8809\n",
"Epoch [76/100], Step [200/206], Loss: 1.7478\n",
"Epoch [77/100], Step [100/206], Loss: 1.7588\n",
"Epoch [77/100], Step [200/206], Loss: 1.9256\n",
"Epoch [78/100], Step [100/206], Loss: 1.7766\n",
"Epoch [78/100], Step [200/206], Loss: 1.8041\n",
"Epoch [79/100], Step [100/206], Loss: 1.7835\n",
"Epoch [79/100], Step [200/206], Loss: 1.7466\n",
"Epoch [80/100], Step [100/206], Loss: 1.8114\n",
"Epoch [80/100], Step [200/206], Loss: 1.7969\n",
"Epoch [81/100], Step [100/206], Loss: 1.7738\n",
"Epoch [81/100], Step [200/206], Loss: 1.8206\n",
"Epoch [82/100], Step [100/206], Loss: 1.7776\n",
"Epoch [82/100], Step [200/206], Loss: 1.8304\n",
"Epoch [83/100], Step [100/206], Loss: 1.8344\n",
"Epoch [83/100], Step [200/206], Loss: 1.8109\n",
"Epoch [84/100], Step [100/206], Loss: 1.8179\n",
"Epoch [84/100], Step [200/206], Loss: 1.8043\n",
"Epoch [85/100], Step [100/206], Loss: 1.8179\n",
"Epoch [85/100], Step [200/206], Loss: 1.8325\n",
"Epoch [86/100], Step [100/206], Loss: 1.8183\n",
"Epoch [86/100], Step [200/206], Loss: 1.7328\n",
"Epoch [87/100], Step [100/206], Loss: 1.8620\n",
"Epoch [87/100], Step [200/206], Loss: 1.8000\n",
"Epoch [88/100], Step [100/206], Loss: 1.8483\n",
"Epoch [88/100], Step [200/206], Loss: 1.8043\n",
"Epoch [89/100], Step [100/206], Loss: 1.7693\n",
"Epoch [89/100], Step [200/206], Loss: 1.7481\n",
"Epoch [90/100], Step [100/206], Loss: 1.7677\n",
"Epoch [90/100], Step [200/206], Loss: 1.8045\n",
"Epoch [91/100], Step [100/206], Loss: 1.8099\n",
"Epoch [91/100], Step [200/206], Loss: 1.8099\n",
"Epoch [92/100], Step [100/206], Loss: 1.7886\n",
"Epoch [92/100], Step [200/206], Loss: 1.7793\n",
"Epoch [93/100], Step [100/206], Loss: 1.8118\n",
"Epoch [93/100], Step [200/206], Loss: 1.7652\n",
"Epoch [94/100], Step [100/206], Loss: 1.8959\n",
"Epoch [94/100], Step [200/206], Loss: 1.7397\n",
"Epoch [95/100], Step [100/206], Loss: 1.8318\n",
"Epoch [95/100], Step [200/206], Loss: 1.7495\n",
"Epoch [96/100], Step [100/206], Loss: 1.7748\n",
"Epoch [96/100], Step [200/206], Loss: 1.8398\n",
"Epoch [97/100], Step [100/206], Loss: 1.7984\n",
"Epoch [97/100], Step [200/206], Loss: 1.8364\n",
"Epoch [98/100], Step [100/206], Loss: 1.8306\n",
"Epoch [98/100], Step [200/206], Loss: 1.8240\n",
"Epoch [99/100], Step [100/206], Loss: 1.7435\n",
"Epoch [99/100], Step [200/206], Loss: 1.8296\n",
"Epoch [100/100], Step [100/206], Loss: 1.8072\n",
"Epoch [100/100], Step [200/206], Loss: 1.8313\n"
]
}
],
"source": [
"# Train the model\n",
"total_step = len(train_dataloader)\n",
"for epoch in range(num_epochs):\n",
" for i, (images, labels) in enumerate(train_dataloader):\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" \n",
" # Forward pass\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
" \n",
" # Backward and optimize\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" if (i+1) % 100 == 0:\n",
" print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' \n",
" .format(epoch+1, num_epochs, i+1, total_step, loss.item()))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Accuracy : 22.62857142857143 %\n"
]
}
],
"source": [
"model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)\n",
"with torch.no_grad():\n",
" correct = 0\n",
" total = 0\n",
" for images, labels in test_dataloader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" outputs = model(images)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
"\n",
" print('Test Accuracy : {} %'.format(100 * correct / total))"
]
},
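{
"cell_type": "markdown",
"metadata": {},
"source": [
"22.6% is only modestly above the ~14% chance level for seven classes, and an overall figure can hide a model that collapses onto a few classes. A per-class breakdown makes this visible; a sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Per-class accuracy on the test set.\n",
"model.eval()\n",
"class_correct = [0] * num_class\n",
"class_total = [0] * num_class\n",
"with torch.no_grad():\n",
"    for images, labels in test_dataloader:\n",
"        images = images.to(device)\n",
"        labels = labels.to(device)\n",
"        _, predicted = torch.max(model(images), 1)\n",
"        for label, pred in zip(labels, predicted):\n",
"            class_total[label.item()] += 1\n",
"            class_correct[label.item()] += int(label == pred)\n",
"\n",
"for idx, name in enumerate(test_data.classes):\n",
"    if class_total[idx] > 0:\n",
"        print('{}: {:.1f} %'.format(name, 100 * class_correct[idx] / class_total[idx]))"
]
},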
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), './checkpoints/weight_resnet18_100.pth')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
" (fc): Sequential(\n",
" (0): Linear(in_features=512, out_features=512, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=512, out_features=7, bias=True)\n",
" (3): LogSoftmax()\n",
" )\n",
")"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = model\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.load_state_dict(torch.load('./checkpoints/weight_resnet18_100.pth'))"
]
},
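{
"cell_type": "markdown",
"metadata": {},
"source": [
"`torch.load` restores tensors onto the devices they were saved from, so loading this GPU-trained checkpoint on a CPU-only machine needs `map_location`. Resuming training faithfully also requires the optimizer state. A sketch of both, left commented so it does not disturb the run:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load onto the current device (useful on CPU-only machines):\n",
"# state = torch.load('./checkpoints/weight_resnet18_100.pth', map_location=device)\n",
"# model.load_state_dict(state)\n",
"\n",
"# Fuller checkpoint that lets the optimizer resume as well:\n",
"# torch.save({'model': model.state_dict(),\n",
"#             'optimizer': optimizer.state_dict()},\n",
"#            './checkpoints/checkpoint.pth')"
]
},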
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/100], Step [100/206], Loss: 1.7892\n",
"Epoch [1/100], Step [200/206], Loss: 1.7876\n",
"Epoch [2/100], Step [100/206], Loss: 1.8238\n",
"Epoch [2/100], Step [200/206], Loss: 1.8310\n",
"Epoch [3/100], Step [100/206], Loss: 1.8300\n",
"Epoch [3/100], Step [200/206], Loss: 1.7602\n",
"Epoch [4/100], Step [100/206], Loss: 1.7329\n",
"Epoch [4/100], Step [200/206], Loss: 1.8731\n",
"Epoch [5/100], Step [100/206], Loss: 1.8559\n",
"Epoch [5/100], Step [200/206], Loss: 1.7999\n",
"Epoch [6/100], Step [100/206], Loss: 1.7815\n",
"Epoch [6/100], Step [200/206], Loss: 1.8430\n",
"Epoch [7/100], Step [100/206], Loss: 1.8338\n",
"Epoch [7/100], Step [200/206], Loss: 1.7600\n",
"Epoch [8/100], Step [100/206], Loss: 1.8517\n",
"Epoch [8/100], Step [200/206], Loss: 1.7685\n",
"Epoch [9/100], Step [100/206], Loss: 1.8468\n",
"Epoch [9/100], Step [200/206], Loss: 1.7668\n",
"Epoch [10/100], Step [100/206], Loss: 1.7533\n",
"Epoch [10/100], Step [200/206], Loss: 1.8353\n",
"Epoch [11/100], Step [100/206], Loss: 1.6990\n",
"Epoch [11/100], Step [200/206], Loss: 1.7079\n",
"Epoch [12/100], Step [100/206], Loss: 1.7996\n",
"Epoch [12/100], Step [200/206], Loss: 1.6895\n",
"Epoch [13/100], Step [100/206], Loss: 1.8325\n",
"Epoch [13/100], Step [200/206], Loss: 1.8564\n",
"Epoch [14/100], Step [100/206], Loss: 1.8215\n",
"Epoch [14/100], Step [200/206], Loss: 1.7708\n",
"Epoch [15/100], Step [100/206], Loss: 1.7893\n",
"Epoch [15/100], Step [200/206], Loss: 1.8371\n",
"Epoch [16/100], Step [100/206], Loss: 1.8978\n",
"Epoch [16/100], Step [200/206], Loss: 1.8038\n",
"Epoch [17/100], Step [100/206], Loss: 1.7541\n",
"Epoch [17/100], Step [200/206], Loss: 1.8134\n",
"Epoch [18/100], Step [100/206], Loss: 1.8766\n",
"Epoch [18/100], Step [200/206], Loss: 1.7279\n",
"Epoch [19/100], Step [100/206], Loss: 1.8450\n",
"Epoch [19/100], Step [200/206], Loss: 1.8058\n",
"Epoch [20/100], Step [100/206], Loss: 1.7848\n",
"Epoch [20/100], Step [200/206], Loss: 1.8017\n",
"Epoch [21/100], Step [100/206], Loss: 1.8013\n",
"Epoch [21/100], Step [200/206], Loss: 1.7861\n",
"Epoch [22/100], Step [100/206], Loss: 1.7494\n",
"Epoch [22/100], Step [200/206], Loss: 1.9154\n",
"Epoch [23/100], Step [100/206], Loss: 1.8243\n",
"Epoch [23/100], Step [200/206], Loss: 1.7085\n",
"Epoch [24/100], Step [100/206], Loss: 1.8059\n",
"Epoch [24/100], Step [200/206], Loss: 1.7298\n",
"Epoch [25/100], Step [100/206], Loss: 1.8759\n",
"Epoch [25/100], Step [200/206], Loss: 1.8032\n",
"Epoch [26/100], Step [100/206], Loss: 1.8111\n",
"Epoch [26/100], Step [200/206], Loss: 1.8276\n",
"Epoch [27/100], Step [100/206], Loss: 1.8080\n",
"Epoch [27/100], Step [200/206], Loss: 1.8107\n",
"Epoch [28/100], Step [100/206], Loss: 1.8482\n",
"Epoch [28/100], Step [200/206], Loss: 1.7860\n",
"Epoch [29/100], Step [100/206], Loss: 1.8291\n",
"Epoch [29/100], Step [200/206], Loss: 1.8582\n",
"Epoch [30/100], Step [100/206], Loss: 1.8089\n",
"Epoch [30/100], Step [200/206], Loss: 1.8188\n",
"Epoch [31/100], Step [100/206], Loss: 1.7331\n",
"Epoch [31/100], Step [200/206], Loss: 1.7531\n",
"Epoch [32/100], Step [100/206], Loss: 1.7825\n",
"Epoch [32/100], Step [200/206], Loss: 1.8719\n",
"Epoch [33/100], Step [100/206], Loss: 1.6788\n",
"Epoch [33/100], Step [200/206], Loss: 1.7896\n",
"Epoch [34/100], Step [100/206], Loss: 1.8505\n",
"Epoch [34/100], Step [200/206], Loss: 1.7788\n",
"Epoch [35/100], Step [100/206], Loss: 1.7491\n",
"Epoch [35/100], Step [200/206], Loss: 1.7202\n",
"Epoch [36/100], Step [100/206], Loss: 1.7676\n",
"Epoch [36/100], Step [200/206], Loss: 1.8237\n",
"Epoch [37/100], Step [100/206], Loss: 1.6978\n",
"Epoch [37/100], Step [200/206], Loss: 1.8277\n",
"Epoch [38/100], Step [100/206], Loss: 1.8590\n",
"Epoch [38/100], Step [200/206], Loss: 1.8558\n",
"Epoch [39/100], Step [100/206], Loss: 1.7079\n",
"Epoch [39/100], Step [200/206], Loss: 1.8212\n",
"Epoch [40/100], Step [100/206], Loss: 1.8058\n",
"Epoch [40/100], Step [200/206], Loss: 1.7185\n",
"Epoch [41/100], Step [100/206], Loss: 1.7763\n",
"Epoch [41/100], Step [200/206], Loss: 1.8191\n",
"Epoch [42/100], Step [100/206], Loss: 1.8503\n",
"Epoch [42/100], Step [200/206], Loss: 1.7860\n",
"Epoch [43/100], Step [100/206], Loss: 1.8735\n",
"Epoch [43/100], Step [200/206], Loss: 1.8357\n",
"Epoch [44/100], Step [100/206], Loss: 1.8115\n",
"Epoch [44/100], Step [200/206], Loss: 1.7708\n",
"Epoch [45/100], Step [100/206], Loss: 1.7588\n",
"Epoch [45/100], Step [200/206], Loss: 1.8141\n",
"Epoch [46/100], Step [100/206], Loss: 1.7575\n",
"Epoch [46/100], Step [200/206], Loss: 1.7897\n",
"Epoch [47/100], Step [100/206], Loss: 1.8437\n",
"Epoch [47/100], Step [200/206], Loss: 1.7795\n",
"Epoch [48/100], Step [100/206], Loss: 1.7964\n",
"Epoch [48/100], Step [200/206], Loss: 1.7802\n",
"Epoch [49/100], Step [100/206], Loss: 1.7255\n",
"Epoch [49/100], Step [200/206], Loss: 1.8108\n",
"Epoch [50/100], Step [100/206], Loss: 1.7756\n",
"Epoch [50/100], Step [200/206], Loss: 1.8098\n",
"Epoch [51/100], Step [100/206], Loss: 1.7472\n",
"Epoch [51/100], Step [200/206], Loss: 1.8115\n",
"Epoch [52/100], Step [100/206], Loss: 1.8203\n",
"Epoch [52/100], Step [200/206], Loss: 1.7780\n",
"Epoch [53/100], Step [100/206], Loss: 1.7752\n",
"Epoch [53/100], Step [200/206], Loss: 1.8196\n",
"Epoch [54/100], Step [100/206], Loss: 1.8122\n",
"Epoch [54/100], Step [200/206], Loss: 1.6767\n",
"Epoch [55/100], Step [100/206], Loss: 1.7651\n",
"Epoch [55/100], Step [200/206], Loss: 1.7838\n",
"Epoch [56/100], Step [100/206], Loss: 1.8838\n",
"Epoch [56/100], Step [200/206], Loss: 1.7270\n",
"Epoch [57/100], Step [100/206], Loss: 1.7036\n",
"Epoch [57/100], Step [200/206], Loss: 1.7990\n",
"Epoch [58/100], Step [100/206], Loss: 1.8059\n",
"Epoch [58/100], Step [200/206], Loss: 1.7053\n",
"Epoch [59/100], Step [100/206], Loss: 1.7308\n",
"Epoch [59/100], Step [200/206], Loss: 1.7947\n",
"Epoch [60/100], Step [100/206], Loss: 1.7685\n",
"Epoch [60/100], Step [200/206], Loss: 1.7190\n",
"Epoch [61/100], Step [100/206], Loss: 1.7363\n",
"Epoch [61/100], Step [200/206], Loss: 1.8444\n",
"Epoch [62/100], Step [100/206], Loss: 1.7901\n",
"Epoch [62/100], Step [200/206], Loss: 1.7741\n",
"Epoch [63/100], Step [100/206], Loss: 1.7831\n",
"Epoch [63/100], Step [200/206], Loss: 1.7854\n",
"Epoch [64/100], Step [100/206], Loss: 1.7566\n",
"Epoch [64/100], Step [200/206], Loss: 1.8495\n",
"Epoch [65/100], Step [100/206], Loss: 1.8012\n",
"Epoch [65/100], Step [200/206], Loss: 1.9272\n",
"Epoch [66/100], Step [100/206], Loss: 1.8067\n",
"Epoch [66/100], Step [200/206], Loss: 1.7367\n",
"Epoch [67/100], Step [100/206], Loss: 1.7847\n",
"Epoch [67/100], Step [200/206], Loss: 1.7979\n",
"Epoch [68/100], Step [100/206], Loss: 1.7676\n",
"Epoch [68/100], Step [200/206], Loss: 1.7234\n",
"Epoch [69/100], Step [100/206], Loss: 1.8200\n",
"Epoch [69/100], Step [200/206], Loss: 1.7297\n",
"Epoch [70/100], Step [100/206], Loss: 1.7916\n",
"Epoch [70/100], Step [200/206], Loss: 1.7731\n",
"Epoch [71/100], Step [100/206], Loss: 1.7446\n",
"Epoch [71/100], Step [200/206], Loss: 1.7522\n",
"Epoch [72/100], Step [100/206], Loss: 1.8136\n",
"Epoch [72/100], Step [200/206], Loss: 1.7659\n",
"Epoch [73/100], Step [100/206], Loss: 1.7140\n",
"Epoch [73/100], Step [200/206], Loss: 1.8071\n",
"Epoch [74/100], Step [100/206], Loss: 1.7331\n",
"Epoch [74/100], Step [200/206], Loss: 1.7997\n",
"Epoch [75/100], Step [100/206], Loss: 1.8192\n",
"Epoch [75/100], Step [200/206], Loss: 1.6801\n",
"Epoch [76/100], Step [100/206], Loss: 1.7936\n",
"Epoch [76/100], Step [200/206], Loss: 1.7079\n",
"Epoch [77/100], Step [100/206], Loss: 1.8367\n",
"Epoch [77/100], Step [200/206], Loss: 1.6948\n",
"Epoch [78/100], Step [100/206], Loss: 1.7912\n",
"Epoch [78/100], Step [200/206], Loss: 1.8324\n",
"Epoch [79/100], Step [100/206], Loss: 1.7412\n",
"Epoch [79/100], Step [200/206], Loss: 1.7192\n",
"Epoch [80/100], Step [100/206], Loss: 1.7997\n",
"Epoch [80/100], Step [200/206], Loss: 1.7341\n",
"Epoch [81/100], Step [100/206], Loss: 1.8008\n",
"Epoch [81/100], Step [200/206], Loss: 1.8107\n",
"Epoch [82/100], Step [100/206], Loss: 1.7701\n",
"Epoch [82/100], Step [200/206], Loss: 1.7862\n",
"Epoch [83/100], Step [100/206], Loss: 1.7311\n",
"Epoch [83/100], Step [200/206], Loss: 1.8078\n",
"Epoch [84/100], Step [100/206], Loss: 1.7365\n",
"Epoch [84/100], Step [200/206], Loss: 1.7675\n",
"Epoch [85/100], Step [100/206], Loss: 1.8161\n",
"Epoch [85/100], Step [200/206], Loss: 1.8275\n",
"Epoch [86/100], Step [100/206], Loss: 1.7562\n",
"Epoch [86/100], Step [200/206], Loss: 1.7723\n",
"Epoch [87/100], Step [100/206], Loss: 1.7087\n",
"Epoch [87/100], Step [200/206], Loss: 1.7859\n",
"Epoch [88/100], Step [100/206], Loss: 1.7897\n",
"Epoch [88/100], Step [200/206], Loss: 1.7671\n",
"Epoch [89/100], Step [100/206], Loss: 1.7756\n",
"Epoch [89/100], Step [200/206], Loss: 1.7493\n",
"Epoch [90/100], Step [100/206], Loss: 1.7414\n",
"Epoch [90/100], Step [200/206], Loss: 1.7378\n",
"Epoch [91/100], Step [100/206], Loss: 1.6846\n",
"Epoch [91/100], Step [200/206], Loss: 1.7964\n",
"Epoch [92/100], Step [100/206], Loss: 1.7420\n",
"Epoch [92/100], Step [200/206], Loss: 1.7729\n",
"Epoch [93/100], Step [100/206], Loss: 1.8605\n",
"Epoch [93/100], Step [200/206], Loss: 1.7876\n",
"Epoch [94/100], Step [100/206], Loss: 1.7263\n",
"Epoch [94/100], Step [200/206], Loss: 1.7625\n",
"Epoch [95/100], Step [100/206], Loss: 1.7476\n",
"Epoch [95/100], Step [200/206], Loss: 1.8162\n",
"Epoch [96/100], Step [100/206], Loss: 1.6556\n",
"Epoch [96/100], Step [200/206], Loss: 1.7610\n",
"Epoch [97/100], Step [100/206], Loss: 1.7237\n",
"Epoch [97/100], Step [200/206], Loss: 1.7982\n",
"Epoch [98/100], Step [100/206], Loss: 1.7569\n",
"Epoch [98/100], Step [200/206], Loss: 1.7443\n",
"Epoch [99/100], Step [100/206], Loss: 1.8233\n",
"Epoch [99/100], Step [200/206], Loss: 1.7540\n",
"Epoch [100/100], Step [100/206], Loss: 1.7758\n",
"Epoch [100/100], Step [200/206], Loss: 1.7109\n"
]
}
],
"source": [
"# Train the model\n",
"total_step = len(train_dataloader)\n",
"for epoch in range(num_epochs):\n",
" for i, (images, labels) in enumerate(train_dataloader):\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" \n",
" # Forward pass\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
" \n",
" # Backward and optimize\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" if (i+1) % 100 == 0:\n",
" print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' \n",
" .format(epoch+1, num_epochs, i+1, total_step, loss.item()))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Accuracy : 21.914285714285715 %\n"
]
}
],
"source": [
"model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)\n",
"with torch.no_grad():\n",
" correct = 0\n",
" total = 0\n",
" for images, labels in test_dataloader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" outputs = model(images)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
"\n",
" print('Test Accuracy : {} %'.format(100 * correct / total))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), './checkpoints/weight_resnet18_200.pth')"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.load_state_dict(torch.load('./checkpoints/weight_resnet18_200.pth'))"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/100], Step [100/206], Loss: 1.7775\n",
"Epoch [1/100], Step [200/206], Loss: 1.7994\n",
"Epoch [2/100], Step [100/206], Loss: 1.7386\n",
"Epoch [2/100], Step [200/206], Loss: 1.7325\n",
"Epoch [3/100], Step [100/206], Loss: 1.7318\n",
"Epoch [3/100], Step [200/206], Loss: 1.7492\n",
"Epoch [4/100], Step [100/206], Loss: 1.6084\n",
"Epoch [4/100], Step [200/206], Loss: 1.7708\n",
"Epoch [5/100], Step [100/206], Loss: 1.7332\n",
"Epoch [5/100], Step [200/206], Loss: 1.7659\n",
"Epoch [6/100], Step [100/206], Loss: 1.7814\n",
"Epoch [6/100], Step [200/206], Loss: 1.7406\n",
"Epoch [7/100], Step [100/206], Loss: 1.7844\n",
"Epoch [7/100], Step [200/206], Loss: 1.7175\n",
"Epoch [8/100], Step [100/206], Loss: 1.8441\n",
"Epoch [8/100], Step [200/206], Loss: 1.7527\n",
"Epoch [9/100], Step [100/206], Loss: 1.8416\n",
"Epoch [9/100], Step [200/206], Loss: 1.8365\n",
"Epoch [10/100], Step [100/206], Loss: 1.6929\n",
"Epoch [10/100], Step [200/206], Loss: 1.7640\n",
"Epoch [11/100], Step [100/206], Loss: 1.7789\n",
"Epoch [11/100], Step [200/206], Loss: 1.8365\n",
"Epoch [12/100], Step [100/206], Loss: 1.7304\n",
"Epoch [12/100], Step [200/206], Loss: 1.7912\n",
"Epoch [13/100], Step [100/206], Loss: 1.8394\n",
"Epoch [13/100], Step [200/206], Loss: 1.6970\n",
"Epoch [14/100], Step [100/206], Loss: 1.8486\n",
"Epoch [14/100], Step [200/206], Loss: 1.8152\n",
"Epoch [15/100], Step [100/206], Loss: 1.8517\n",
"Epoch [15/100], Step [200/206], Loss: 1.7902\n",
"Epoch [16/100], Step [100/206], Loss: 1.7725\n",
"Epoch [16/100], Step [200/206], Loss: 1.7615\n",
"Epoch [17/100], Step [100/206], Loss: 1.7559\n",
"Epoch [17/100], Step [200/206], Loss: 1.7292\n",
"Epoch [18/100], Step [100/206], Loss: 1.8214\n",
"Epoch [18/100], Step [200/206], Loss: 1.7762\n",
"Epoch [19/100], Step [100/206], Loss: 1.7168\n",
"Epoch [19/100], Step [200/206], Loss: 1.7787\n",
"Epoch [20/100], Step [100/206], Loss: 1.7660\n",
"Epoch [20/100], Step [200/206], Loss: 1.7539\n",
"Epoch [21/100], Step [100/206], Loss: 1.7864\n",
"Epoch [21/100], Step [200/206], Loss: 1.7234\n",
"Epoch [22/100], Step [100/206], Loss: 1.8251\n",
"Epoch [22/100], Step [200/206], Loss: 1.7325\n",
"Epoch [23/100], Step [100/206], Loss: 1.7693\n",
"Epoch [23/100], Step [200/206], Loss: 1.7419\n",
"Epoch [24/100], Step [100/206], Loss: 1.6545\n",
"Epoch [24/100], Step [200/206], Loss: 1.7839\n",
"Epoch [25/100], Step [100/206], Loss: 1.7508\n",
"Epoch [25/100], Step [200/206], Loss: 1.7846\n",
"Epoch [26/100], Step [100/206], Loss: 1.8689\n",
"Epoch [26/100], Step [200/206], Loss: 1.7939\n",
"Epoch [27/100], Step [100/206], Loss: 1.8019\n",
"Epoch [27/100], Step [200/206], Loss: 1.8169\n",
"Epoch [28/100], Step [100/206], Loss: 1.7630\n",
"Epoch [28/100], Step [200/206], Loss: 1.7747\n",
"Epoch [29/100], Step [100/206], Loss: 1.8262\n",
"Epoch [29/100], Step [200/206], Loss: 1.8054\n",
"Epoch [30/100], Step [100/206], Loss: 1.8358\n",
"Epoch [30/100], Step [200/206], Loss: 1.7541\n",
"Epoch [31/100], Step [100/206], Loss: 1.7286\n",
"Epoch [31/100], Step [200/206], Loss: 1.7669\n",
"Epoch [32/100], Step [100/206], Loss: 1.7877\n",
"Epoch [32/100], Step [200/206], Loss: 1.7734\n",
"Epoch [33/100], Step [100/206], Loss: 1.6754\n",
"Epoch [33/100], Step [200/206], Loss: 1.7035\n",
"Epoch [34/100], Step [100/206], Loss: 1.8467\n",
"Epoch [34/100], Step [200/206], Loss: 1.6976\n",
"Epoch [35/100], Step [100/206], Loss: 1.7958\n",
"Epoch [35/100], Step [200/206], Loss: 1.8057\n",
"Epoch [36/100], Step [100/206], Loss: 1.7282\n",
"Epoch [36/100], Step [200/206], Loss: 1.7345\n",
"Epoch [37/100], Step [100/206], Loss: 1.6798\n",
"Epoch [37/100], Step [200/206], Loss: 1.7780\n",
"Epoch [38/100], Step [100/206], Loss: 1.7703\n",
"Epoch [38/100], Step [200/206], Loss: 1.7031\n",
"Epoch [39/100], Step [100/206], Loss: 1.7540\n",
"Epoch [39/100], Step [200/206], Loss: 1.8360\n",
"Epoch [40/100], Step [100/206], Loss: 1.7261\n",
"Epoch [40/100], Step [200/206], Loss: 1.7430\n",
"Epoch [41/100], Step [100/206], Loss: 1.8863\n",
"Epoch [41/100], Step [200/206], Loss: 1.7337\n",
"Epoch [42/100], Step [100/206], Loss: 1.7018\n",
"Epoch [42/100], Step [200/206], Loss: 1.6827\n",
"Epoch [43/100], Step [100/206], Loss: 1.7575\n",
"Epoch [43/100], Step [200/206], Loss: 1.8266\n",
"Epoch [44/100], Step [100/206], Loss: 1.6754\n",
"Epoch [44/100], Step [200/206], Loss: 1.7436\n",
"Epoch [45/100], Step [100/206], Loss: 1.7745\n",
"Epoch [45/100], Step [200/206], Loss: 1.7723\n",
"Epoch [46/100], Step [100/206], Loss: 1.7355\n",
"Epoch [46/100], Step [200/206], Loss: 1.7777\n",
"Epoch [47/100], Step [100/206], Loss: 1.8403\n",
"Epoch [47/100], Step [200/206], Loss: 1.7781\n",
"Epoch [48/100], Step [100/206], Loss: 1.8090\n",
"Epoch [48/100], Step [200/206], Loss: 1.7719\n",
"Epoch [49/100], Step [100/206], Loss: 1.8550\n",
"Epoch [49/100], Step [200/206], Loss: 1.6970\n",
"Epoch [50/100], Step [100/206], Loss: 1.6394\n",
"Epoch [50/100], Step [200/206], Loss: 1.8137\n",
"Epoch [51/100], Step [100/206], Loss: 1.7478\n",
"Epoch [51/100], Step [200/206], Loss: 1.7640\n",
"Epoch [52/100], Step [100/206], Loss: 1.8412\n",
"Epoch [52/100], Step [200/206], Loss: 1.8284\n",
"Epoch [53/100], Step [100/206], Loss: 1.7049\n",
"Epoch [53/100], Step [200/206], Loss: 1.7652\n",
"Epoch [54/100], Step [100/206], Loss: 1.7729\n",
"Epoch [54/100], Step [200/206], Loss: 1.8220\n",
"Epoch [55/100], Step [100/206], Loss: 1.8077\n",
"Epoch [55/100], Step [200/206], Loss: 1.7978\n",
"Epoch [56/100], Step [100/206], Loss: 1.7188\n",
"Epoch [56/100], Step [200/206], Loss: 1.6348\n",
"Epoch [57/100], Step [100/206], Loss: 1.7600\n",
"Epoch [57/100], Step [200/206], Loss: 1.7314\n",
"Epoch [58/100], Step [100/206], Loss: 1.7389\n",
"Epoch [58/100], Step [200/206], Loss: 1.8066\n",
"Epoch [59/100], Step [100/206], Loss: 1.7768\n",
"Epoch [59/100], Step [200/206], Loss: 1.8010\n",
"Epoch [60/100], Step [100/206], Loss: 1.7392\n",
"Epoch [60/100], Step [200/206], Loss: 1.8117\n",
"Epoch [61/100], Step [100/206], Loss: 1.6771\n",
"Epoch [61/100], Step [200/206], Loss: 1.7958\n",
"Epoch [62/100], Step [100/206], Loss: 1.7026\n",
"Epoch [62/100], Step [200/206], Loss: 1.7667\n",
"Epoch [63/100], Step [100/206], Loss: 1.8149\n",
"Epoch [63/100], Step [200/206], Loss: 1.6892\n",
"Epoch [64/100], Step [100/206], Loss: 1.7962\n",
"Epoch [64/100], Step [200/206], Loss: 1.7722\n",
"Epoch [65/100], Step [100/206], Loss: 1.6924\n",
"Epoch [65/100], Step [200/206], Loss: 1.7129\n",
"Epoch [66/100], Step [100/206], Loss: 1.7348\n",
"Epoch [66/100], Step [200/206], Loss: 1.7210\n",
"Epoch [67/100], Step [100/206], Loss: 1.6795\n",
"Epoch [67/100], Step [200/206], Loss: 1.7293\n",
"Epoch [68/100], Step [100/206], Loss: 1.7481\n",
"Epoch [68/100], Step [200/206], Loss: 1.7611\n",
"Epoch [69/100], Step [100/206], Loss: 1.7859\n",
"Epoch [69/100], Step [200/206], Loss: 1.7177\n",
"Epoch [70/100], Step [100/206], Loss: 1.7398\n",
"Epoch [70/100], Step [200/206], Loss: 1.8317\n",
"Epoch [71/100], Step [100/206], Loss: 1.8460\n",
"Epoch [71/100], Step [200/206], Loss: 1.7464\n",
"Epoch [72/100], Step [100/206], Loss: 1.7416\n",
"Epoch [72/100], Step [200/206], Loss: 1.6381\n",
"Epoch [73/100], Step [100/206], Loss: 1.8355\n",
"Epoch [73/100], Step [200/206], Loss: 1.7312\n",
"Epoch [74/100], Step [100/206], Loss: 1.7605\n",
"Epoch [74/100], Step [200/206], Loss: 1.7273\n",
"Epoch [75/100], Step [100/206], Loss: 1.7278\n",
"Epoch [75/100], Step [200/206], Loss: 1.7584\n",
"Epoch [76/100], Step [100/206], Loss: 1.7839\n",
"Epoch [76/100], Step [200/206], Loss: 1.8740\n",
"Epoch [77/100], Step [100/206], Loss: 1.7364\n",
"Epoch [77/100], Step [200/206], Loss: 1.7574\n",
"Epoch [78/100], Step [100/206], Loss: 1.7734\n",
"Epoch [78/100], Step [200/206], Loss: 1.7331\n",
"Epoch [79/100], Step [100/206], Loss: 1.7919\n",
"Epoch [79/100], Step [200/206], Loss: 1.8142\n",
"Epoch [80/100], Step [100/206], Loss: 1.6907\n",
"Epoch [80/100], Step [200/206], Loss: 1.7179\n",
"Epoch [81/100], Step [100/206], Loss: 1.8401\n",
"Epoch [81/100], Step [200/206], Loss: 1.7304\n",
"Epoch [82/100], Step [100/206], Loss: 1.6853\n",
"Epoch [82/100], Step [200/206], Loss: 1.7725\n",
"Epoch [83/100], Step [100/206], Loss: 1.6901\n",
"Epoch [83/100], Step [200/206], Loss: 1.7018\n",
"Epoch [84/100], Step [100/206], Loss: 1.7844\n",
"Epoch [84/100], Step [200/206], Loss: 1.8158\n",
"Epoch [85/100], Step [100/206], Loss: 1.7045\n",
"Epoch [85/100], Step [200/206], Loss: 1.8302\n",
"Epoch [86/100], Step [100/206], Loss: 1.7720\n",
"Epoch [86/100], Step [200/206], Loss: 1.8571\n",
"Epoch [87/100], Step [100/206], Loss: 1.7617\n",
"Epoch [87/100], Step [200/206], Loss: 1.6629\n",
"Epoch [88/100], Step [100/206], Loss: 1.7431\n",
"Epoch [88/100], Step [200/206], Loss: 1.7406\n",
"Epoch [89/100], Step [100/206], Loss: 1.8519\n",
"Epoch [89/100], Step [200/206], Loss: 1.7814\n",
"Epoch [90/100], Step [100/206], Loss: 1.7153\n",
"Epoch [90/100], Step [200/206], Loss: 1.7120\n",
"Epoch [91/100], Step [100/206], Loss: 1.7705\n",
"Epoch [91/100], Step [200/206], Loss: 1.7705\n",
"Epoch [92/100], Step [100/206], Loss: 1.6527\n",
"Epoch [92/100], Step [200/206], Loss: 1.7528\n",
"Epoch [93/100], Step [100/206], Loss: 1.7260\n",
"Epoch [93/100], Step [200/206], Loss: 1.7796\n",
"Epoch [94/100], Step [100/206], Loss: 1.6830\n",
"Epoch [94/100], Step [200/206], Loss: 1.7714\n",
"Epoch [95/100], Step [100/206], Loss: 1.6935\n",
"Epoch [95/100], Step [200/206], Loss: 1.7616\n",
"Epoch [96/100], Step [100/206], Loss: 1.8042\n",
"Epoch [96/100], Step [200/206], Loss: 1.7847\n",
"Epoch [97/100], Step [100/206], Loss: 1.7990\n",
"Epoch [97/100], Step [200/206], Loss: 1.7717\n",
"Epoch [98/100], Step [100/206], Loss: 1.7043\n",
"Epoch [98/100], Step [200/206], Loss: 1.7463\n",
"Epoch [99/100], Step [100/206], Loss: 1.8212\n",
"Epoch [99/100], Step [200/206], Loss: 1.7586\n",
"Epoch [100/100], Step [100/206], Loss: 1.7591\n",
"Epoch [100/100], Step [200/206], Loss: 1.7037\n"
]
}
],
"source": [
"# Train the model\n",
"total_step = len(train_dataloader)\n",
"for epoch in range(num_epochs):\n",
" for i, (images, labels) in enumerate(train_dataloader):\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" \n",
" # Forward pass\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
" \n",
" # Backward and optimize\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" if (i+1) % 100 == 0:\n",
" print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' \n",
" .format(epoch+1, num_epochs, i+1, total_step, loss.item()))"
]
},
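{
"cell_type": "markdown",
"metadata": {},
"source": [
"The step losses hover around 1.75 on a 7-class problem, barely below the chance-level cross-entropy of ln(7) (about 1.95), so the model improves only slowly. A minimal sketch of the same loop with a step learning-rate schedule and a per-epoch loss history, assuming the `model`, `criterion`, `optimizer`, `device`, and `num_epochs` defined above; the `StepLR` settings are illustrative:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: training loop with an LR schedule and a loss history for plotting.\n",
"# Assumes model, criterion, optimizer, device, num_epochs from the cells above;\n",
"# the StepLR settings are illustrative, not the notebook's originals.\n",
"from torch.optim.lr_scheduler import StepLR\n",
"\n",
"scheduler = StepLR(optimizer, step_size=30, gamma=0.1)\n",
"train_losses = []  # average training loss per epoch\n",
"\n",
"for epoch in range(num_epochs):\n",
"    model.train()\n",
"    running_loss = 0.0\n",
"    for images, labels in train_dataloader:\n",
"        images, labels = images.to(device), labels.to(device)\n",
"        loss = criterion(model(images), labels)\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        running_loss += loss.item() * images.size(0)\n",
"    scheduler.step()  # decay the learning rate on schedule\n",
"    train_losses.append(running_loss / len(train_dataloader.dataset))"
]
},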
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), './checkpoints/weight_resnet18_300.pth')"
]
},
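{
"cell_type": "markdown",
"metadata": {},
"source": [
"Saving only the `state_dict` drops the optimizer state and the epoch counter, so a run cannot be resumed exactly where it stopped. A sketch of a fuller checkpoint; the filename is illustrative:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: checkpoint that also captures optimizer state so training can resume.\n",
"# The path below is illustrative.\n",
"import os\n",
"\n",
"os.makedirs('./checkpoints', exist_ok=True)\n",
"torch.save({\n",
"    'epoch': num_epochs,\n",
"    'model_state_dict': model.state_dict(),\n",
"    'optimizer_state_dict': optimizer.state_dict(),\n",
"}, './checkpoints/resnet18_full.pth')"
]
},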
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Accuracy : 21.714285714285715 %\n"
]
}
],
"source": [
"model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)\n",
"with torch.no_grad():\n",
" correct = 0\n",
" total = 0\n",
" for images, labels in test_dataloader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" outputs = model(images)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
"\n",
" print('Test Accuracy : {} %'.format(100 * correct / total))"
]
},
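{
"cell_type": "markdown",
"metadata": {},
"source": [
"21.7 % over 7 classes is barely above the 14.3 % chance level. A per-class breakdown shows whether the predictions have collapsed onto a few classes; a minimal sketch reusing `model`, `device`, and `test_dataloader`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: per-class test accuracy, to spot collapsed predictions.\n",
"class_names = test_dataloader.dataset.classes\n",
"correct_per_class = [0] * len(class_names)\n",
"total_per_class = [0] * len(class_names)\n",
"\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    for images, labels in test_dataloader:\n",
"        images, labels = images.to(device), labels.to(device)\n",
"        preds = model(images).argmax(dim=1)\n",
"        for label, pred in zip(labels.tolist(), preds.tolist()):\n",
"            total_per_class[label] += 1\n",
"            correct_per_class[label] += int(label == pred)\n",
"\n",
"for name, c, t in zip(class_names, correct_per_class, total_per_class):\n",
"    print('{}: {:.1f} %'.format(name, 100 * c / max(t, 1)))"
]
},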
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# model.save_weights('./checkpoints')\n",
"torch.save(model.state_dict(), './checkpoints/weight.pth')"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
" (fc): Sequential(\n",
" (0): Linear(in_features=512, out_features=512, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=512, out_features=7, bias=True)\n",
" (3): LogSoftmax()\n",
" )\n",
")"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = model\n",
"model"
]
},
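{
"cell_type": "markdown",
"metadata": {},
"source": [
"The classifier head ends in `LogSoftmax`, so the model already emits log-probabilities. The matching criterion is `nn.NLLLoss`; `nn.CrossEntropyLoss` expects raw logits and would apply log-softmax a second time. The `criterion` used in this notebook is defined in an earlier cell, so this is only a consistency check:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# With a LogSoftmax head, NLLLoss is the matching loss;\n",
"# CrossEntropyLoss expects raw logits and would log-softmax twice.\n",
"criterion = nn.NLLLoss()"
]
},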
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.load_state_dict(torch.load('./checkpoints/weight.pth'))"
]
},
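{
"cell_type": "markdown",
"metadata": {},
"source": [
"`torch.load` restores tensors onto the device they were saved from, so weights saved on a GPU fail to load on a CPU-only machine unless `map_location` is given. A device-portable loading sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: device-portable loading; map_location remaps GPU-saved tensors.\n",
"state_dict = torch.load('./checkpoints/weight.pth', map_location=device)\n",
"model.load_state_dict(state_dict)\n",
"model.to(device)\n",
"model.eval()  # BatchNorm then uses running statistics during inference"
]
},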
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 001/100 | Train: 30.198% | Loss: 0.014\n",
"Time elapsed: 1.94 min\n",
"Epoch: 002/100 | Train: 30.053% | Loss: 0.014\n",
"Time elapsed: 3.98 min\n",
"Epoch: 003/100 | Train: 30.008% | Loss: 0.014\n",
"Time elapsed: 6.02 min\n",
"Epoch: 004/100 | Train: 29.116% | Loss: 0.014\n",
"Time elapsed: 7.98 min\n",
"Epoch: 005/100 | Train: 29.398% | Loss: 0.014\n",
"Time elapsed: 9.91 min\n",
"Epoch: 006/100 | Train: 29.246% | Loss: 0.014\n",
"Time elapsed: 11.87 min\n",
"Epoch: 007/100 | Train: 30.590% | Loss: 0.014\n",
"Time elapsed: 13.86 min\n",
"Epoch: 008/100 | Train: 29.497% | Loss: 0.014\n",
"Time elapsed: 15.88 min\n",
"Epoch: 009/100 | Train: 28.693% | Loss: 0.014\n",
"Time elapsed: 17.95 min\n",
"Epoch: 010/100 | Train: 30.141% | Loss: 0.014\n",
"Time elapsed: 20.11 min\n",
"Epoch: 011/100 | Train: 29.025% | Loss: 0.014\n",
"Time elapsed: 22.04 min\n",
"Epoch: 012/100 | Train: 29.318% | Loss: 0.014\n",
"Time elapsed: 23.84 min\n",
"Epoch: 013/100 | Train: 30.545% | Loss: 0.014\n",
"Time elapsed: 25.67 min\n",
"Epoch: 014/100 | Train: 30.366% | Loss: 0.014\n",
"Time elapsed: 27.61 min\n",
"Epoch: 015/100 | Train: 30.278% | Loss: 0.014\n",
"Time elapsed: 29.42 min\n",
"Epoch: 016/100 | Train: 29.863% | Loss: 0.014\n",
"Time elapsed: 31.23 min\n",
"Epoch: 017/100 | Train: 29.745% | Loss: 0.014\n",
"Time elapsed: 33.00 min\n",
"Epoch: 018/100 | Train: 29.291% | Loss: 0.014\n",
"Time elapsed: 34.78 min\n",
"Epoch: 019/100 | Train: 30.731% | Loss: 0.014\n",
"Time elapsed: 36.57 min\n",
"Epoch: 020/100 | Train: 29.832% | Loss: 0.014\n",
"Time elapsed: 38.35 min\n",
"Epoch: 021/100 | Train: 30.830% | Loss: 0.014\n",
"Time elapsed: 40.13 min\n",
"Epoch: 022/100 | Train: 29.288% | Loss: 0.014\n",
"Time elapsed: 41.90 min\n",
"Epoch: 023/100 | Train: 29.630% | Loss: 0.014\n",
"Time elapsed: 43.67 min\n",
"Epoch: 024/100 | Train: 31.269% | Loss: 0.014\n",
"Time elapsed: 45.45 min\n",
"Epoch: 025/100 | Train: 31.044% | Loss: 0.014\n",
"Time elapsed: 47.21 min\n",
"Epoch: 026/100 | Train: 30.564% | Loss: 0.014\n",
"Time elapsed: 49.01 min\n",
"Epoch: 027/100 | Train: 30.568% | Loss: 0.014\n",
"Time elapsed: 50.80 min\n",
"Epoch: 028/100 | Train: 28.514% | Loss: 0.014\n",
"Time elapsed: 52.57 min\n",
"Epoch: 029/100 | Train: 31.093% | Loss: 0.014\n",
"Time elapsed: 54.37 min\n",
"Epoch: 030/100 | Train: 30.842% | Loss: 0.014\n",
"Time elapsed: 56.21 min\n",
"Epoch: 031/100 | Train: 29.550% | Loss: 0.014\n",
"Time elapsed: 58.02 min\n",
"Epoch: 032/100 | Train: 30.027% | Loss: 0.014\n",
"Time elapsed: 59.81 min\n",
"Epoch: 033/100 | Train: 30.293% | Loss: 0.014\n",
"Time elapsed: 61.60 min\n",
"Epoch: 034/100 | Train: 30.869% | Loss: 0.014\n",
"Time elapsed: 63.40 min\n",
"Epoch: 035/100 | Train: 30.838% | Loss: 0.014\n",
"Time elapsed: 65.18 min\n",
"Epoch: 036/100 | Train: 30.899% | Loss: 0.014\n",
"Time elapsed: 66.99 min\n",
"Epoch: 037/100 | Train: 30.450% | Loss: 0.014\n",
"Time elapsed: 68.80 min\n",
"Epoch: 038/100 | Train: 30.758% | Loss: 0.014\n",
"Time elapsed: 70.60 min\n",
"Epoch: 039/100 | Train: 30.907% | Loss: 0.014\n",
"Time elapsed: 72.40 min\n",
"Epoch: 040/100 | Train: 30.488% | Loss: 0.014\n",
"Time elapsed: 74.21 min\n",
"Epoch: 041/100 | Train: 29.303% | Loss: 0.014\n",
"Time elapsed: 76.01 min\n",
"Epoch: 042/100 | Train: 29.657% | Loss: 0.014\n",
"Time elapsed: 77.79 min\n",
"Epoch: 043/100 | Train: 30.549% | Loss: 0.014\n",
"Time elapsed: 79.59 min\n",
"Epoch: 044/100 | Train: 30.922% | Loss: 0.014\n",
"Time elapsed: 81.39 min\n",
"Epoch: 045/100 | Train: 30.663% | Loss: 0.014\n",
"Time elapsed: 83.17 min\n",
"Epoch: 046/100 | Train: 30.930% | Loss: 0.014\n",
"Time elapsed: 84.97 min\n",
"Epoch: 047/100 | Train: 29.642% | Loss: 0.014\n",
"Time elapsed: 86.73 min\n",
"Epoch: 048/100 | Train: 31.090% | Loss: 0.014\n",
"Time elapsed: 88.53 min\n",
"Epoch: 049/100 | Train: 31.166% | Loss: 0.014\n",
"Time elapsed: 90.31 min\n",
"Epoch: 050/100 | Train: 31.192% | Loss: 0.014\n",
"Time elapsed: 92.09 min\n",
"Epoch: 051/100 | Train: 30.263% | Loss: 0.014\n",
"Time elapsed: 93.85 min\n",
"Epoch: 052/100 | Train: 29.901% | Loss: 0.014\n",
"Time elapsed: 95.63 min\n",
"Epoch: 053/100 | Train: 31.059% | Loss: 0.014\n",
"Time elapsed: 97.44 min\n",
"Epoch: 054/100 | Train: 31.326% | Loss: 0.014\n",
"Time elapsed: 99.22 min\n",
"Epoch: 055/100 | Train: 30.507% | Loss: 0.014\n",
"Time elapsed: 101.04 min\n",
"Epoch: 056/100 | Train: 30.613% | Loss: 0.014\n",
"Time elapsed: 102.89 min\n",
"Epoch: 057/100 | Train: 29.848% | Loss: 0.014\n",
"Time elapsed: 104.72 min\n",
"Epoch: 058/100 | Train: 30.712% | Loss: 0.014\n",
"Time elapsed: 106.52 min\n",
"Epoch: 059/100 | Train: 29.440% | Loss: 0.014\n",
"Time elapsed: 108.29 min\n",
"Epoch: 060/100 | Train: 30.011% | Loss: 0.014\n",
"Time elapsed: 110.11 min\n",
"Epoch: 061/100 | Train: 31.230% | Loss: 0.014\n",
"Time elapsed: 111.89 min\n",
"Epoch: 062/100 | Train: 30.747% | Loss: 0.014\n",
"Time elapsed: 113.67 min\n",
"Epoch: 063/100 | Train: 29.230% | Loss: 0.014\n",
"Time elapsed: 115.44 min\n",
"Epoch: 064/100 | Train: 29.570% | Loss: 0.014\n",
"Time elapsed: 117.24 min\n",
"Epoch: 065/100 | Train: 29.973% | Loss: 0.014\n",
"Time elapsed: 119.02 min\n",
"Epoch: 066/100 | Train: 30.526% | Loss: 0.014\n",
"Time elapsed: 120.77 min\n",
"Epoch: 067/100 | Train: 30.663% | Loss: 0.014\n",
"Time elapsed: 122.54 min\n",
"Epoch: 068/100 | Train: 30.766% | Loss: 0.014\n",
"Time elapsed: 124.30 min\n",
"Epoch: 069/100 | Train: 31.196% | Loss: 0.014\n",
"Time elapsed: 126.08 min\n",
"Epoch: 070/100 | Train: 30.716% | Loss: 0.014\n",
"Time elapsed: 127.86 min\n",
"Epoch: 071/100 | Train: 30.983% | Loss: 0.014\n",
"Time elapsed: 129.66 min\n",
"Epoch: 072/100 | Train: 31.474% | Loss: 0.014\n",
"Time elapsed: 131.47 min\n",
"Epoch: 073/100 | Train: 31.173% | Loss: 0.014\n",
"Time elapsed: 133.27 min\n",
"Epoch: 074/100 | Train: 30.617% | Loss: 0.014\n",
"Time elapsed: 135.06 min\n",
"Epoch: 075/100 | Train: 31.451% | Loss: 0.014\n",
"Time elapsed: 136.89 min\n",
"Epoch: 076/100 | Train: 30.674% | Loss: 0.014\n",
"Time elapsed: 138.69 min\n",
"Epoch: 077/100 | Train: 30.678% | Loss: 0.014\n",
"Time elapsed: 140.51 min\n",
"Epoch: 078/100 | Train: 31.295% | Loss: 0.014\n",
"Time elapsed: 142.31 min\n",
"Epoch: 079/100 | Train: 30.583% | Loss: 0.014\n",
"Time elapsed: 144.11 min\n",
"Epoch: 080/100 | Train: 31.341% | Loss: 0.014\n",
"Time elapsed: 145.90 min\n",
"Epoch: 081/100 | Train: 29.307% | Loss: 0.014\n",
"Time elapsed: 147.70 min\n",
"Epoch: 082/100 | Train: 31.048% | Loss: 0.014\n",
"Time elapsed: 149.53 min\n",
"Epoch: 083/100 | Train: 31.128% | Loss: 0.014\n",
"Time elapsed: 151.34 min\n",
"Epoch: 084/100 | Train: 30.682% | Loss: 0.014\n",
"Time elapsed: 153.14 min\n",
"Epoch: 085/100 | Train: 31.589% | Loss: 0.014\n",
"Time elapsed: 154.94 min\n",
"Epoch: 086/100 | Train: 31.832% | Loss: 0.014\n",
"Time elapsed: 156.73 min\n",
"Epoch: 087/100 | Train: 31.676% | Loss: 0.014\n",
"Time elapsed: 158.54 min\n",
"Epoch: 088/100 | Train: 31.535% | Loss: 0.014\n",
"Time elapsed: 160.40 min\n",
"Epoch: 089/100 | Train: 30.206% | Loss: 0.014\n",
"Time elapsed: 162.21 min\n",
"Epoch: 090/100 | Train: 30.686% | Loss: 0.014\n",
"Time elapsed: 164.03 min\n",
"Epoch: 091/100 | Train: 30.880% | Loss: 0.014\n",
"Time elapsed: 165.84 min\n",
"Epoch: 092/100 | Train: 31.509% | Loss: 0.014\n",
"Time elapsed: 167.65 min\n",
"Epoch: 093/100 | Train: 30.434% | Loss: 0.014\n",
"Time elapsed: 169.46 min\n",
"Epoch: 094/100 | Train: 30.712% | Loss: 0.014\n",
"Time elapsed: 171.28 min\n",
"Epoch: 095/100 | Train: 30.945% | Loss: 0.014\n",
"Time elapsed: 173.10 min\n",
"Epoch: 096/100 | Train: 31.547% | Loss: 0.014\n",
"Time elapsed: 174.89 min\n",
"Epoch: 097/100 | Train: 31.680% | Loss: 0.014\n",
"Time elapsed: 176.69 min\n",
"Epoch: 098/100 | Train: 31.710% | Loss: 0.014\n",
"Time elapsed: 178.49 min\n",
"Epoch: 099/100 | Train: 29.208% | Loss: 0.014\n",
"Time elapsed: 180.32 min\n",
"Epoch: 100/100 | Train: 30.259% | Loss: 0.014\n",
"Time elapsed: 182.14 min\n",
"Total Training Time: 182.14 min\n"
]
}
],
"source": [
"num_epochs = 100\n",
"import time \n",
"\n",
"start_time = time.time()\n",
"for epoch in range(num_epochs):\n",
" \n",
" model.train()\n",
" for batch_idx, (features, targets) in enumerate(train_dataloader):\n",
" \n",
" features = features.to(device)\n",
"# print(features.shape)\n",
"# break\n",
" targets = targets.to(device)\n",
" \n",
" ### FORWARD AND BACK PROP\n",
"# logits, probas = model(features)\n",
"# cost = F.cross_entropy(logits, targets)\n",
"\n",
" predicted = model(features)\n",
" loss = criterion(predicted,targets)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
"# cost.backward()\n",
" \n",
" ### UPDATE MODEL PARAMETERS\n",
" optimizer.step()\n",
" \n",
" ### LOGGING\n",
"# if not batch_idx % 50:\n",
"# print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n",
"# %(epoch+1, num_epochs, batch_idx, \n",
"# len(train_dataloader), cost))\n",
"\n",
" model.eval()\n",
" with torch.set_grad_enabled(False): # save memory during inference\n",
" print('Epoch: %03d/%03d | Train: %.3f%% | Loss: %.3f' % (\n",
" epoch+1, num_epochs, \n",
" compute_accuracy(model, train_dataloader),\n",
" compute_epoch_loss(model, train_dataloader)))\n",
"\n",
"\n",
" print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
" \n",
"print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
]
},
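{
"cell_type": "markdown",
"metadata": {},
"source": [
"`compute_accuracy` and `compute_epoch_loss` come from an earlier cell. A minimal sketch of what they might look like; the normalization in `compute_epoch_loss` is inferred from the logged values (about 0.014 while per-step losses sit near 1.75), not confirmed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of the helpers used above; both are called under set_grad_enabled(False).\n",
"def compute_accuracy(model, data_loader):\n",
"    # Percent of correct predictions over the whole loader.\n",
"    correct, total = 0, 0\n",
"    for features, targets in data_loader:\n",
"        features, targets = features.to(device), targets.to(device)\n",
"        preds = model(features).argmax(dim=1)\n",
"        correct += (preds == targets).sum().item()\n",
"        total += targets.size(0)\n",
"    return 100.0 * correct / total\n",
"\n",
"def compute_epoch_loss(model, data_loader):\n",
"    # Assumed normalization: sum of per-batch mean losses over the dataset size,\n",
"    # which matches the logged magnitude (about 1.75 * 206 / 26250).\n",
"    total_loss = 0.0\n",
"    for features, targets in data_loader:\n",
"        features, targets = features.to(device), targets.to(device)\n",
"        total_loss += criterion(model(features), targets).item()\n",
"    return total_loss / len(data_loader.dataset)"
]
},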
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
" (fc): Sequential(\n",
" (0): Linear(in_features=512, out_features=512, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=512, out_features=7, bias=True)\n",
" (3): LogSoftmax()\n",
" )\n",
")"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = model\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.load_state_dict(torch.load('./checkpoints/weight.pth'))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 001/100 | Train: 30.678% | Loss: 0.014\n",
"Time elapsed: 1.81 min\n",
"Epoch: 002/100 | Train: 30.590% | Loss: 0.014\n",
"Time elapsed: 3.67 min\n",
"Epoch: 003/100 | Train: 32.015% | Loss: 0.014\n",
"Time elapsed: 5.49 min\n",
"Epoch: 004/100 | Train: 31.326% | Loss: 0.014\n",
"Time elapsed: 7.30 min\n",
"Epoch: 005/100 | Train: 31.718% | Loss: 0.014\n",
"Time elapsed: 9.07 min\n",
"Epoch: 006/100 | Train: 31.265% | Loss: 0.014\n",
"Time elapsed: 10.83 min\n",
"Epoch: 007/100 | Train: 31.688% | Loss: 0.014\n",
"Time elapsed: 12.61 min\n",
"Epoch: 008/100 | Train: 31.451% | Loss: 0.014\n",
"Time elapsed: 14.39 min\n",
"Epoch: 009/100 | Train: 31.512% | Loss: 0.014\n",
"Time elapsed: 16.22 min\n",
"Epoch: 010/100 | Train: 31.764% | Loss: 0.014\n",
"Time elapsed: 18.05 min\n",
"Epoch: 011/100 | Train: 31.040% | Loss: 0.014\n",
"Time elapsed: 19.99 min\n",
"Epoch: 012/100 | Train: 29.840% | Loss: 0.014\n",
"Time elapsed: 21.95 min\n",
"Epoch: 013/100 | Train: 31.920% | Loss: 0.014\n",
"Time elapsed: 23.88 min\n",
"Epoch: 014/100 | Train: 29.722% | Loss: 0.014\n",
"Time elapsed: 25.71 min\n",
"Epoch: 015/100 | Train: 31.931% | Loss: 0.014\n",
"Time elapsed: 27.67 min\n",
"Epoch: 016/100 | Train: 31.653% | Loss: 0.014\n",
"Time elapsed: 29.49 min\n",
"Epoch: 017/100 | Train: 30.358% | Loss: 0.014\n",
"Time elapsed: 31.43 min\n",
"Epoch: 018/100 | Train: 31.322% | Loss: 0.014\n",
"Time elapsed: 33.31 min\n",
"Epoch: 019/100 | Train: 31.604% | Loss: 0.014\n",
"Time elapsed: 35.20 min\n",
"Epoch: 020/100 | Train: 31.608% | Loss: 0.014\n",
"Time elapsed: 37.25 min\n",
"Epoch: 021/100 | Train: 32.156% | Loss: 0.014\n",
"Time elapsed: 39.25 min\n",
"Epoch: 022/100 | Train: 31.893% | Loss: 0.014\n",
"Time elapsed: 41.29 min\n",
"Epoch: 023/100 | Train: 31.840% | Loss: 0.014\n",
"Time elapsed: 43.31 min\n",
"Epoch: 024/100 | Train: 32.370% | Loss: 0.014\n",
"Time elapsed: 45.35 min\n",
"Epoch: 025/100 | Train: 31.855% | Loss: 0.014\n",
"Time elapsed: 47.39 min\n",
"Epoch: 026/100 | Train: 31.783% | Loss: 0.014\n",
"Time elapsed: 49.39 min\n",
"Epoch: 027/100 | Train: 31.840% | Loss: 0.014\n",
"Time elapsed: 51.40 min\n",
"Epoch: 028/100 | Train: 31.566% | Loss: 0.014\n",
"Time elapsed: 53.46 min\n",
"Epoch: 029/100 | Train: 31.669% | Loss: 0.014\n",
"Time elapsed: 55.52 min\n",
"Epoch: 030/100 | Train: 31.162% | Loss: 0.014\n",
"Time elapsed: 57.37 min\n",
"Epoch: 031/100 | Train: 31.802% | Loss: 0.014\n",
"Time elapsed: 59.27 min\n",
"Epoch: 032/100 | Train: 31.421% | Loss: 0.014\n",
"Time elapsed: 61.28 min\n",
"Epoch: 033/100 | Train: 31.703% | Loss: 0.014\n",
"Time elapsed: 63.29 min\n",
"Epoch: 034/100 | Train: 31.067% | Loss: 0.014\n",
"Time elapsed: 65.37 min\n",
"Epoch: 035/100 | Train: 31.181% | Loss: 0.014\n",
"Time elapsed: 67.46 min\n",
"Epoch: 036/100 | Train: 31.653% | Loss: 0.014\n",
"Time elapsed: 69.43 min\n",
"Epoch: 037/100 | Train: 30.693% | Loss: 0.014\n",
"Time elapsed: 71.40 min\n",
"Epoch: 038/100 | Train: 31.307% | Loss: 0.014\n",
"Time elapsed: 73.42 min\n",
"Epoch: 039/100 | Train: 29.703% | Loss: 0.014\n",
"Time elapsed: 75.49 min\n",
"Epoch: 040/100 | Train: 32.419% | Loss: 0.014\n",
"Time elapsed: 77.51 min\n",
"Epoch: 041/100 | Train: 30.350% | Loss: 0.014\n",
"Time elapsed: 79.58 min\n",
"Epoch: 042/100 | Train: 32.019% | Loss: 0.014\n",
"Time elapsed: 81.64 min\n",
"Epoch: 043/100 | Train: 31.973% | Loss: 0.014\n",
"Time elapsed: 83.67 min\n",
"Epoch: 044/100 | Train: 32.541% | Loss: 0.014\n",
"Time elapsed: 85.66 min\n",
"Epoch: 045/100 | Train: 31.341% | Loss: 0.014\n",
"Time elapsed: 87.62 min\n",
"Epoch: 046/100 | Train: 31.756% | Loss: 0.014\n",
"Time elapsed: 89.55 min\n",
"Epoch: 047/100 | Train: 29.806% | Loss: 0.014\n",
"Time elapsed: 91.37 min\n",
"Epoch: 048/100 | Train: 32.099% | Loss: 0.014\n",
"Time elapsed: 93.20 min\n",
"Epoch: 049/100 | Train: 31.238% | Loss: 0.014\n",
"Time elapsed: 95.00 min\n",
"Epoch: 050/100 | Train: 31.208% | Loss: 0.014\n",
"Time elapsed: 96.84 min\n",
"Epoch: 051/100 | Train: 32.248% | Loss: 0.014\n",
"Time elapsed: 98.75 min\n",
"Epoch: 052/100 | Train: 30.526% | Loss: 0.014\n",
"Time elapsed: 100.69 min\n",
"Epoch: 053/100 | Train: 31.855% | Loss: 0.014\n",
"Time elapsed: 102.54 min\n",
"Epoch: 054/100 | Train: 32.301% | Loss: 0.014\n",
"Time elapsed: 104.42 min\n",
"Epoch: 055/100 | Train: 31.425% | Loss: 0.014\n",
"Time elapsed: 106.43 min\n",
"Epoch: 056/100 | Train: 32.000% | Loss: 0.014\n",
"Time elapsed: 108.49 min\n",
"Epoch: 057/100 | Train: 31.246% | Loss: 0.014\n",
"Time elapsed: 110.58 min\n",
"Epoch: 058/100 | Train: 31.177% | Loss: 0.014\n",
"Time elapsed: 112.64 min\n",
"Epoch: 059/100 | Train: 32.152% | Loss: 0.014\n",
"Time elapsed: 114.69 min\n",
"Epoch: 060/100 | Train: 30.103% | Loss: 0.014\n",
"Time elapsed: 116.84 min\n",
"Epoch: 061/100 | Train: 31.897% | Loss: 0.014\n",
"Time elapsed: 118.97 min\n",
"Epoch: 062/100 | Train: 30.796% | Loss: 0.014\n",
"Time elapsed: 120.96 min\n",
"Epoch: 063/100 | Train: 31.653% | Loss: 0.014\n",
"Time elapsed: 122.73 min\n",
"Epoch: 064/100 | Train: 31.501% | Loss: 0.014\n",
"Time elapsed: 124.55 min\n",
"Epoch: 065/100 | Train: 31.722% | Loss: 0.014\n",
"Time elapsed: 126.34 min\n",
"Epoch: 066/100 | Train: 31.448% | Loss: 0.014\n",
"Time elapsed: 128.16 min\n",
"Epoch: 067/100 | Train: 31.349% | Loss: 0.014\n",
"Time elapsed: 129.94 min\n",
"Epoch: 068/100 | Train: 30.453% | Loss: 0.014\n",
"Time elapsed: 131.75 min\n",
"Epoch: 069/100 | Train: 31.238% | Loss: 0.014\n",
"Time elapsed: 133.54 min\n",
"Epoch: 070/100 | Train: 32.095% | Loss: 0.014\n",
"Time elapsed: 135.32 min\n",
"Epoch: 071/100 | Train: 30.949% | Loss: 0.014\n",
"Time elapsed: 137.11 min\n",
"Epoch: 072/100 | Train: 32.488% | Loss: 0.014\n",
"Time elapsed: 138.88 min\n",
"Epoch: 073/100 | Train: 31.718% | Loss: 0.014\n",
"Time elapsed: 140.64 min\n",
"Epoch: 074/100 | Train: 31.806% | Loss: 0.014\n",
"Time elapsed: 142.44 min\n",
"Epoch: 075/100 | Train: 31.478% | Loss: 0.014\n",
"Time elapsed: 144.22 min\n",
"Epoch: 076/100 | Train: 31.192% | Loss: 0.014\n",
"Time elapsed: 146.02 min\n",
"Epoch: 077/100 | Train: 32.099% | Loss: 0.014\n",
"Time elapsed: 147.80 min\n",
"Epoch: 078/100 | Train: 32.549% | Loss: 0.014\n",
"Time elapsed: 149.58 min\n",
"Epoch: 079/100 | Train: 30.777% | Loss: 0.014\n",
"Time elapsed: 151.41 min\n",
"Epoch: 080/100 | Train: 32.149% | Loss: 0.014\n",
"Time elapsed: 153.16 min\n",
"Epoch: 081/100 | Train: 32.331% | Loss: 0.013\n",
"Time elapsed: 154.93 min\n",
"Epoch: 082/100 | Train: 32.392% | Loss: 0.014\n",
"Time elapsed: 156.72 min\n",
"Epoch: 083/100 | Train: 32.571% | Loss: 0.014\n",
"Time elapsed: 158.50 min\n",
"Epoch: 084/100 | Train: 31.794% | Loss: 0.014\n",
"Time elapsed: 160.31 min\n",
"Epoch: 085/100 | Train: 32.213% | Loss: 0.014\n",
"Time elapsed: 162.08 min\n",
"Epoch: 086/100 | Train: 31.112% | Loss: 0.014\n",
"Time elapsed: 163.87 min\n",
"Epoch: 087/100 | Train: 30.842% | Loss: 0.014\n",
"Time elapsed: 165.65 min\n",
"Epoch: 088/100 | Train: 30.850% | Loss: 0.014\n",
"Time elapsed: 167.42 min\n",
"Epoch: 089/100 | Train: 31.977% | Loss: 0.014\n",
"Time elapsed: 169.22 min\n",
"Epoch: 090/100 | Train: 31.337% | Loss: 0.014\n",
"Time elapsed: 171.01 min\n",
"Epoch: 091/100 | Train: 32.019% | Loss: 0.014\n",
"Time elapsed: 172.81 min\n",
"Epoch: 092/100 | Train: 31.768% | Loss: 0.014\n",
"Time elapsed: 174.60 min\n",
"Epoch: 093/100 | Train: 32.415% | Loss: 0.013\n",
"Time elapsed: 176.37 min\n",
"Epoch: 094/100 | Train: 31.730% | Loss: 0.014\n",
"Time elapsed: 178.17 min\n",
"Epoch: 095/100 | Train: 32.038% | Loss: 0.014\n",
"Time elapsed: 179.99 min\n",
"Epoch: 096/100 | Train: 32.465% | Loss: 0.013\n",
"Time elapsed: 181.79 min\n",
"Epoch: 097/100 | Train: 31.242% | Loss: 0.014\n",
"Time elapsed: 183.56 min\n",
"Epoch: 098/100 | Train: 31.634% | Loss: 0.014\n",
"Time elapsed: 185.35 min\n",
"Epoch: 099/100 | Train: 32.301% | Loss: 0.013\n",
"Time elapsed: 187.13 min\n",
"Epoch: 100/100 | Train: 31.787% | Loss: 0.014\n",
"Time elapsed: 188.91 min\n",
"Total Training Time: 188.91 min\n"
]
}
],
"source": [
"num_epochs = 100\n",
"import time \n",
"\n",
"start_time = time.time()\n",
"for epoch in range(num_epochs):\n",
" \n",
" model.train()\n",
" for batch_idx, (features, targets) in enumerate(train_dataloader):\n",
" \n",
" features = features.to(device)\n",
"# print(features.shape)\n",
"# break\n",
" targets = targets.to(device)\n",
" \n",
" ### FORWARD AND BACK PROP\n",
"# logits, probas = model(features)\n",
"# cost = F.cross_entropy(logits, targets)\n",
"\n",
" predicted = model(features)\n",
" loss = criterion(predicted,targets)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
"# cost.backward()\n",
" \n",
" ### UPDATE MODEL PARAMETERS\n",
" optimizer.step()\n",
" \n",
" ### LOGGING\n",
"# if not batch_idx % 50:\n",
"# print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n",
"# %(epoch+1, num_epochs, batch_idx, \n",
"# len(train_dataloader), cost))\n",
"\n",
" model.eval()\n",
" with torch.set_grad_enabled(False): # save memory during inference\n",
" print('Epoch: %03d/%03d | Train: %.3f%% | Loss: %.3f' % (\n",
" epoch+1, num_epochs, \n",
" compute_accuracy(model, train_dataloader),\n",
" compute_epoch_loss(model, train_dataloader)))\n",
"\n",
"\n",
" print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
" \n",
"print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), './checkpoints/weight18-300.pth')"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-27-421d3408d711>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpredicted\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'Training loss'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtest_losses\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'Validation loss'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlegend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mframeon\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshow\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mTypeError\u001b[0m: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first."
]
}
],
"source": [
"b = predicted.numpy()\n",
"plt.plot(b, label='Training loss')\n",
"plt.plot(test_losses, label='Validation loss')\n",
"plt.legend(frameon=False)\n",
"plt.show()"
]
},
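{
"cell_type": "markdown",
"metadata": {},
"source": [
"For an actual loss curve, a history has to be collected during training; `predicted` only holds the class indices of the last batch. A sketch assuming a per-epoch `train_losses` list like the one built in the scheduler sketch above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: plot a per-epoch loss history (assumes train_losses was collected\n",
"# during training, as in the scheduler sketch earlier).\n",
"plt.plot(train_losses, label='Training loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss')\n",
"plt.legend(frameon=False)\n",
"plt.show()"
]
},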
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}