Created
August 24, 2018 19:58
-
-
Save naiveHobo/225d9fea117fb083ae9e7f8e5336881b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"from random import shuffle\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"from torch.autograd import Variable\n", | |
"import torch.utils.data as data\n", | |
"import torch.optim as optim\n", | |
"from torch.optim import lr_scheduler\n", | |
"import torchvision.models as models\n", | |
"from torchvision import transforms\n", | |
"from PIL import Image\n", | |
"from tqdm import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Using GPU\n" | |
] | |
} | |
], | |
"source": [ | |
"# Choose the compute device once and derive every device-dependent alias from it.\n", | |
"use_gpu = torch.cuda.is_available()\n", | |
"device = torch.device('cuda' if use_gpu else 'cpu')\n", | |
"if use_gpu:\n", | |
"    print('Using GPU')\n", | |
"\n", | |
"# Legacy tensor-type aliases for code that converts with ``tensor.type(...)``;\n", | |
"# they name the CUDA types when a GPU is available and the CPU types otherwise.\n", | |
"_tensor_types = torch.cuda if use_gpu else torch\n", | |
"FloatTensor = _tensor_types.FloatTensor\n", | |
"LongTensor = _tensor_types.LongTensor\n", | |
"ByteTensor = _tensor_types.ByteTensor\n", | |
"Tensor = FloatTensor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class ImageDataset(data.Dataset):\n", | |
"    \"\"\"Image-classification dataset backed by a list of image file paths.\n", | |
"\n", | |
"    Args:\n", | |
"        image_paths: sequence of image file paths, one sample per path.\n", | |
"        class_dict: mapping from class name to integer label. A sample's label is\n", | |
"            the value whose key equals the image's parent-directory name; if no\n", | |
"            key does, the first key (in dict order) occurring anywhere in the path.\n", | |
"        transform: optional callable applied to the loaded PIL image.\n", | |
"\n", | |
"    Each item is ``[image, label]`` with ``label`` a 0-dim ``torch.tensor``.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, image_paths, class_dict, transform=None):\n", | |
"        self.image_paths = image_paths\n", | |
"        self.transform = transform\n", | |
"        self.class_dict = class_dict\n", | |
"\n", | |
"    def __len__(self):\n", | |
"        return len(self.image_paths)\n", | |
"\n", | |
"    def _label_for(self, img_path):\n", | |
"        \"\"\"Return the integer label for ``img_path``; raise ValueError if none matches.\"\"\"\n", | |
"        parent = os.path.basename(os.path.dirname(img_path))\n", | |
"        if parent in self.class_dict:\n", | |
"            return self.class_dict[parent]\n", | |
"        # Fallback for paths that do not follow the root/class/file layout.\n", | |
"        for key, value in self.class_dict.items():\n", | |
"            if key in img_path:\n", | |
"                return value\n", | |
"        raise ValueError('no class in class_dict matches image path {!r}'.format(img_path))\n", | |
"\n", | |
"    def __getitem__(self, index):\n", | |
"        img_path = self.image_paths[index]\n", | |
"        # Look the label up in this instance's mapping, not a notebook-level global.\n", | |
"        label = torch.tensor(self._label_for(img_path))\n", | |
"        # Load eagerly so the file handle is closed, and force RGB so every sample\n", | |
"        # has the same three channels regardless of the file's colour mode.\n", | |
"        with Image.open(img_path) as img:\n", | |
"            image = img.convert('RGB')\n", | |
"        if self.transform:\n", | |
"            image = self.transform(image)\n", | |
"        return [image, label]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_resnet_model(num_classes, freeze_until='layer3'):\n", | |
"    \"\"\"Build an ImageNet-pretrained ResNet-18 with a new ``num_classes``-way ``fc`` head.\n", | |
"\n", | |
"    Parameters are visited in ``named_parameters()`` order and frozen\n", | |
"    (``requires_grad = False``) until the first one whose name contains\n", | |
"    ``freeze_until``; that parameter and every later one stay trainable.\n", | |
"    \"\"\"\n", | |
"    # pretrained=True loads the ImageNet-trained weights.\n", | |
"    resnet = models.resnet18(pretrained=True)\n", | |
"    resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)\n", | |
"    for name, param in resnet.named_parameters():\n", | |
"        if freeze_until in name:\n", | |
"            break\n", | |
"        param.requires_grad = False\n", | |
"    return resnet" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import copy\n", | |
"\n", | |
"\n", | |
"def _run_phase(model, loader, num_samples, criterion, optimizer, device, training):\n", | |
"    \"\"\"Run one pass over ``loader`` and return ``(mean loss per sample, accuracy)``.\n", | |
"\n", | |
"    When ``training`` is true the model is put in train mode and updated with\n", | |
"    ``optimizer``; otherwise it is put in eval mode and no autograd graph is built.\n", | |
"    \"\"\"\n", | |
"    model.train(training)\n", | |
"    running_loss = 0.0\n", | |
"    running_corrects = 0\n", | |
"\n", | |
"    with torch.set_grad_enabled(training):\n", | |
"        for inputs, labels in tqdm(loader):\n", | |
"            inputs = inputs.to(device, dtype=torch.float)\n", | |
"            labels = labels.to(device, dtype=torch.long)\n", | |
"\n", | |
"            if training:\n", | |
"                optimizer.zero_grad()\n", | |
"\n", | |
"            outputs = model(inputs)\n", | |
"            if isinstance(outputs, tuple):\n", | |
"                # Tuple outputs: the first element is treated as the class logits.\n", | |
"                outputs = outputs[0]\n", | |
"            loss = criterion(outputs, labels)\n", | |
"\n", | |
"            if training:\n", | |
"                loss.backward()\n", | |
"                optimizer.step()\n", | |
"\n", | |
"            preds = outputs.argmax(dim=1)\n", | |
"            # Assumes criterion averages over the batch (nn.CrossEntropyLoss's default),\n", | |
"            # so scale back to a per-batch sum before dividing by the sample count.\n", | |
"            running_loss += loss.item() * inputs.size(0)\n", | |
"            running_corrects += (preds == labels).sum().item()\n", | |
"\n", | |
"    return running_loss / num_samples, float(running_corrects) / num_samples\n", | |
"\n", | |
"\n", | |
"def train_model(model, dataloaders, dataset_sizes, criterion, optimizer, scheduler, num_epochs=25):\n", | |
"    \"\"\"Train ``model`` and return it loaded with its best validation weights.\n", | |
"\n", | |
"    Each epoch runs a 'train' phase (gradient updates) and then a 'val' phase\n", | |
"    (evaluation only), printing the mean per-sample loss and accuracy of each.\n", | |
"\n", | |
"    Args:\n", | |
"        model: network to train; batches are moved to the device of its parameters.\n", | |
"        dataloaders: dict with 'train' and 'val' iterables of (inputs, labels) batches.\n", | |
"        dataset_sizes: dict with the number of samples in each phase.\n", | |
"        criterion: loss callable taking (outputs, labels).\n", | |
"        optimizer: optimizer over the trainable parameters of ``model``.\n", | |
"        scheduler: LR scheduler, stepped once per epoch after the training phase.\n", | |
"        num_epochs: number of epochs to run.\n", | |
"\n", | |
"    Returns:\n", | |
"        ``model`` with the weights that reached the highest validation accuracy.\n", | |
"    \"\"\"\n", | |
"    device = next(model.parameters()).device\n", | |
"    # state_dict() returns references to the live tensors; deep-copy so later\n", | |
"    # optimizer steps cannot overwrite the saved snapshot.\n", | |
"    best_model_wts = copy.deepcopy(model.state_dict())\n", | |
"    best_acc = 0.0\n", | |
"\n", | |
"    for epoch in range(num_epochs):\n", | |
"        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n", | |
"        print('-' * 10)\n", | |
"\n", | |
"        for phase in ['train', 'val']:\n", | |
"            training = phase == 'train'\n", | |
"            epoch_loss, epoch_acc = _run_phase(\n", | |
"                model, dataloaders[phase], dataset_sizes[phase],\n", | |
"                criterion, optimizer, device, training)\n", | |
"            if training:\n", | |
"                # Since PyTorch 1.1 the scheduler is stepped after the epoch's\n", | |
"                # optimizer.step() calls, not before them.\n", | |
"                scheduler.step()\n", | |
"\n", | |
"            print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n", | |
"                phase, epoch_loss, epoch_acc))\n", | |
"\n", | |
"            if phase == 'val' and epoch_acc > best_acc:\n", | |
"                best_acc = epoch_acc\n", | |
"                best_model_wts = copy.deepcopy(model.state_dict())\n", | |
"\n", | |
"        print()\n", | |
"\n", | |
"    print('Best validation accuracy: {:4f}'.format(best_acc))\n", | |
"\n", | |
"    model.load_state_dict(best_model_wts)\n", | |
"    return model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Dataset location and number of target classes.\n", | |
"data_dir = './natural_images/'\n", | |
"num_classes = 8\n", | |
"\n", | |
"# Per-channel mean/std for transforms.Normalize (the usual ImageNet values).\n", | |
"mean = [0.485, 0.456, 0.406]\n", | |
"std = [0.229, 0.224, 0.225]\n", | |
"\n", | |
"# Resize target, then the crop size actually fed to the network.\n", | |
"scale = 360\n", | |
"input_shape = 224\n", | |
"\n", | |
"# Optimisation schedule.\n", | |
"batch_size = 128\n", | |
"epochs = 20" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Class names are the sub-directories of data_dir. Sorting makes the\n", | |
"# name -> index mapping independent of os.listdir() order; stray files are skipped.\n", | |
"classes = sorted(\n", | |
"    entry for entry in os.listdir(data_dir)\n", | |
"    if os.path.isdir(os.path.join(data_dir, entry)))\n", | |
"class_dict = {label: i for i, label in enumerate(classes)}\n", | |
"\n", | |
"img_files = []\n", | |
"for directory in classes:\n", | |
"    class_dir = os.path.join(data_dir, directory)\n", | |
"    img_files.extend(os.path.join(class_dir, filename)\n", | |
"                     for filename in sorted(os.listdir(class_dir)))\n", | |
"\n", | |
"shuffle(img_files)\n", | |
"\n", | |
"# 60% train / 20% val / 20% test; val_length is the end index of the val slice.\n", | |
"train_length = int(0.6 * len(img_files))\n", | |
"val_length = int(0.8 * len(img_files))\n", | |
"img_files = {'train': img_files[:train_length],\n", | |
"             'val': img_files[train_length:val_length],\n", | |
"             'test': img_files[val_length:]}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Training images get random crops and flips for augmentation; validation images\n", | |
"# a deterministic centre crop. Both end with tensor conversion + normalisation.\n", | |
"data_transforms = {\n", | |
"    'train': transforms.Compose([\n", | |
"        transforms.Resize(scale),\n", | |
"        transforms.RandomResizedCrop(input_shape),\n", | |
"        transforms.RandomHorizontalFlip(),\n", | |
"        transforms.RandomVerticalFlip(),\n", | |
"        transforms.ToTensor(),\n", | |
"        transforms.Normalize(mean, std),\n", | |
"    ]),\n", | |
"    'val': transforms.Compose([\n", | |
"        transforms.Resize(scale),\n", | |
"        transforms.CenterCrop(input_shape),\n", | |
"        transforms.ToTensor(),\n", | |
"        transforms.Normalize(mean, std),\n", | |
"    ]),\n", | |
"}\n", | |
"\n", | |
"image_datasets = {\n", | |
"    phase: ImageDataset(img_files[phase], class_dict, data_transforms[phase])\n", | |
"    for phase in ('train', 'val')\n", | |
"}\n", | |
"\n", | |
"dataloaders = {\n", | |
"    phase: torch.utils.data.DataLoader(dataset, batch_size=batch_size,\n", | |
"                                       shuffle=True, num_workers=4)\n", | |
"    for phase, dataset in image_datasets.items()\n", | |
"}\n", | |
"\n", | |
"dataset_sizes = {phase: float(len(dataset)) for phase, dataset in image_datasets.items()}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 0/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py:38: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n", | |
"100%|██████████| 33/33 [00:16<00:00, 1.95it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0098 Acc: 0.6564\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:03<00:00, 2.83it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0032 Acc: 0.9225\n", | |
"\n", | |
"Epoch 1/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.94it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0024 Acc: 0.9394\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.69it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0014 Acc: 0.9659\n", | |
"\n", | |
"Epoch 2/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.93it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0015 Acc: 0.9524\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:03<00:00, 2.80it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0010 Acc: 0.9725\n", | |
"\n", | |
"Epoch 3/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.89it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0012 Acc: 0.9633\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:03<00:00, 2.78it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0009 Acc: 0.9761\n", | |
"\n", | |
"Epoch 4/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.86it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0010 Acc: 0.9667\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.73it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0008 Acc: 0.9790\n", | |
"\n", | |
"Epoch 5/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.92it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0010 Acc: 0.9642\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.74it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0007 Acc: 0.9812\n", | |
"\n", | |
"Epoch 6/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.88it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0008 Acc: 0.9681\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.25it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0007 Acc: 0.9790\n", | |
"\n", | |
"Epoch 7/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.91it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0009 Acc: 0.9705\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:03<00:00, 2.81it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9812\n", | |
"\n", | |
"Epoch 8/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.87it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0008 Acc: 0.9710\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.35it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9812\n", | |
"\n", | |
"Epoch 9/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.88it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0008 Acc: 0.9686\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:03<00:00, 2.75it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9804\n", | |
"\n", | |
"Epoch 10/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.85it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0008 Acc: 0.9751\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.60it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9790\n", | |
"\n", | |
"Epoch 11/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.87it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0008 Acc: 0.9737\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.75it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9790\n", | |
"\n", | |
"Epoch 12/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.85it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0008 Acc: 0.9717\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.62it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9790\n", | |
"\n", | |
"Epoch 13/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.88it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0008 Acc: 0.9667\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.39it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9797\n", | |
"\n", | |
"Epoch 14/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.87it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0008 Acc: 0.9732\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.74it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9812\n", | |
"\n", | |
"Epoch 15/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.87it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0008 Acc: 0.9729\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:03<00:00, 2.80it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9797\n", | |
"\n", | |
"Epoch 16/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.88it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0008 Acc: 0.9715\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.72it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9790\n", | |
"\n", | |
"Epoch 17/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.86it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0008 Acc: 0.9725\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.71it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9797\n", | |
"\n", | |
"Epoch 18/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.89it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0007 Acc: 0.9720\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.50it/s]\n", | |
" 0%| | 0/33 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9790\n", | |
"\n", | |
"Epoch 19/19\n", | |
"----------\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 33/33 [00:17<00:00, 1.85it/s]\n", | |
" 0%| | 0/11 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train Loss: 0.0007 Acc: 0.9722\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.66it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"val Loss: 0.0006 Acc: 0.9797\n", | |
"\n", | |
"Best validation accuarcy: 0.981159\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# Fine-tune only the parameters get_resnet_model left trainable.\n", | |
"model = get_resnet_model(num_classes).to(device)\n", | |
"trainable_params = [p for p in model.parameters() if p.requires_grad]\n", | |
"\n", | |
"criterion = nn.CrossEntropyLoss()\n", | |
"optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)\n", | |
"# Multiply the learning rate by 0.1 every 7 scheduler steps.\n", | |
"exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)\n", | |
"\n", | |
"model = train_model(model, dataloaders, dataset_sizes, criterion, optimizer,\n", | |
"                    exp_lr_scheduler, num_epochs=epochs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Evaluate with exactly the deterministic pipeline used for validation\n", | |
"# (resize, centre crop, tensor, normalise) so the two cannot drift apart.\n", | |
"test_transform = data_transforms['val']\n", | |
"\n", | |
"test_set = ImageDataset(img_files['test'], class_dict, test_transform)\n", | |
"\n", | |
"# Order does not affect accuracy, so keep evaluation deterministic.\n", | |
"test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,\n", | |
"                                          shuffle=False, num_workers=4)\n", | |
"\n", | |
"test_size = float(len(test_set))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 11/11 [00:04<00:00, 2.59it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Test Accuracy: 0.9855072463768116\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# Put layers such as BatchNorm/Dropout in inference mode.\n", | |
"model.eval()\n", | |
"\n", | |
"running_corrects = 0\n", | |
"\n", | |
"# No autograd graph is needed for evaluation. Batches are unpacked directly:\n", | |
"# rebinding the notebook-level name 'data' would shadow the torch.utils.data import.\n", | |
"with torch.no_grad():\n", | |
"    for inputs, labels in tqdm(test_loader):\n", | |
"        inputs = inputs.to(device, dtype=torch.float)\n", | |
"        labels = labels.to(device, dtype=torch.long)\n", | |
"\n", | |
"        outputs = model(inputs)\n", | |
"        if isinstance(outputs, tuple):\n", | |
"            # Tuple outputs: the first element is treated as the class logits.\n", | |
"            outputs = outputs[0]\n", | |
"        preds = outputs.argmax(dim=1)\n", | |
"\n", | |
"        running_corrects += (preds == labels).sum().item()\n", | |
"\n", | |
"accuracy = float(running_corrects) / test_size\n", | |
"print(\"Test Accuracy: {}\".format(accuracy))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment