Skip to content

Instantly share code, notes, and snippets.

@naiveHobo
Created August 24, 2018 19:58
Show Gist options
  • Save naiveHobo/225d9fea117fb083ae9e7f8e5336881b to your computer and use it in GitHub Desktop.
Save naiveHobo/225d9fea117fb083ae9e7f8e5336881b to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from random import shuffle\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.autograd import Variable\n",
"import torch.utils.data as data\n",
"import torch.optim as optim\n",
"from torch.optim import lr_scheduler\n",
"import torchvision.models as models\n",
"from torchvision import transforms\n",
"from PIL import Image\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using GPU\n"
]
}
],
"source": [
"# Probe CUDA once and derive every device-dependent setting from that flag.\n",
"use_gpu = torch.cuda.is_available()\n",
"device = torch.device('cuda' if use_gpu else 'cpu')\n",
"if use_gpu:\n",
"    print(\"Using GPU\")\n",
"\n",
"# Legacy per-device tensor-type aliases; `Tensor` is the default float type.\n",
"if use_gpu:\n",
"    FloatTensor, LongTensor, ByteTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor, torch.cuda.ByteTensor\n",
"else:\n",
"    FloatTensor, LongTensor, ByteTensor = torch.FloatTensor, torch.LongTensor, torch.ByteTensor\n",
"Tensor = FloatTensor"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class ImageDataset(data.Dataset):\n",
"    \"\"\"Map-style dataset over a list of image file paths.\n",
"\n",
"    Args:\n",
"        image_paths: sequence of paths to image files.\n",
"        class_dict: mapping from class name to integer label. An image's label\n",
"            is the entry for its parent directory name; if there is none, the\n",
"            first class name occurring anywhere in the path is used.\n",
"        transform: optional callable applied to the loaded RGB PIL image.\n",
"\n",
"    Items are ``[image, label]`` lists with ``label`` a 0-dim tensor.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, image_paths, class_dict, transform=None):\n",
"        self.image_paths = image_paths\n",
"        self.transform = transform\n",
"        self.class_dict = class_dict\n",
"\n",
"    def __len__(self):\n",
"        return len(self.image_paths)\n",
"\n",
"    def _label_for(self, img_path):\n",
"        \"\"\"Return the integer label of ``img_path``; raise ValueError if no class matches.\"\"\"\n",
"        # Use this instance's mapping (not a notebook global) and prefer an exact\n",
"        # match on the parent directory, which cannot collide like a substring can.\n",
"        parent = os.path.basename(os.path.dirname(img_path))\n",
"        if parent in self.class_dict:\n",
"            return self.class_dict[parent]\n",
"        for key, value in self.class_dict.items():\n",
"            if key in img_path:\n",
"                return value\n",
"        raise ValueError('No class in class_dict matches image path: {}'.format(img_path))\n",
"\n",
"    def __getitem__(self, index):\n",
"        img_path = self.image_paths[index]\n",
"        label = torch.tensor(self._label_for(img_path))\n",
"        # Image.open is lazy and keeps the file open; load the pixels inside a\n",
"        # `with` block so the handle is closed, and force 3 channels so that\n",
"        # grayscale/RGBA files work with 3-channel normalization.\n",
"        with Image.open(img_path) as img:\n",
"            image = img.convert('RGB')\n",
"        if self.transform:\n",
"            image = self.transform(image)\n",
"        return [image, label]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def get_resnet_model(num_classes):\n",
"    \"\"\"Return a pretrained ResNet-18 whose ``fc`` head is replaced by a ``num_classes``-way layer.\n",
"\n",
"    Every parameter yielded by ``named_parameters()`` before the first ``layer3``\n",
"    parameter is frozen (``requires_grad = False``); ``layer3`` and everything\n",
"    after it stay trainable (in torchvision's ResNet the ``fc`` head comes after\n",
"    ``layer3``, so the new head is trainable too).\n",
"    \"\"\"\n",
"    resnet = models.resnet18(pretrained='imagenet')\n",
"    resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)\n",
"    for param_name, param in resnet.named_parameters():\n",
"        if 'layer3' in param_name:\n",
"            break  # first layer3 parameter reached: stop freezing\n",
"        param.requires_grad = False\n",
"    return resnet"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"\n",
"\n",
"def _run_phase(model, loader, criterion, optimizer, train):\n",
"    \"\"\"Run one pass over ``loader`` and return ``(summed_loss, num_correct)``.\n",
"\n",
"    ``summed_loss`` is the per-sample loss summed over the whole pass. When\n",
"    ``train`` is True the model is updated with ``optimizer`` after every\n",
"    batch; otherwise autograd is disabled.\n",
"    \"\"\"\n",
"    # Follow the model: batches go to whatever device its parameters live on.\n",
"    device = next(model.parameters()).device\n",
"    model.train(train)\n",
"    loss_sum = 0.0\n",
"    num_correct = 0\n",
"    for inputs, labels in tqdm(loader):\n",
"        inputs = inputs.to(device=device, dtype=torch.float32)\n",
"        labels = labels.to(device=device, dtype=torch.long)\n",
"\n",
"        optimizer.zero_grad()\n",
"        # Only build the autograd graph when we are going to back-propagate.\n",
"        with torch.set_grad_enabled(train):\n",
"            outputs = model(inputs)\n",
"            # Models with auxiliary outputs may return a tuple; keep the main logits.\n",
"            if isinstance(outputs, tuple):\n",
"                outputs = outputs[0]\n",
"            loss = criterion(outputs, labels)\n",
"            if train:\n",
"                loss.backward()\n",
"                optimizer.step()\n",
"\n",
"        _, preds = torch.max(outputs, 1)\n",
"        # The criterion is assumed to return a batch mean (the default for\n",
"        # nn.CrossEntropyLoss); weight it by the batch size so dividing by the\n",
"        # dataset size later gives a true per-sample mean.\n",
"        loss_sum += loss.item() * inputs.size(0)\n",
"        num_correct += torch.sum(preds == labels).item()\n",
"    return loss_sum, num_correct\n",
"\n",
"\n",
"def train_model(model, dataloaders, dataset_sizes, criterion, optimizer, scheduler, num_epochs=25):\n",
"    \"\"\"Fine-tune ``model`` and return it loaded with its best validation weights.\n",
"\n",
"    Each epoch runs a 'train' phase (parameter updates) followed by a 'val'\n",
"    phase (no gradients). ``scheduler`` is stepped once per epoch, after the\n",
"    training phase. The weights with the highest validation accuracy seen so\n",
"    far are snapshotted and restored at the end.\n",
"\n",
"    Args:\n",
"        model: network to train; batches are moved to its parameters' device.\n",
"        dataloaders: dict with 'train' and 'val' loaders yielding (inputs, labels).\n",
"        dataset_sizes: dict with the number of samples in each phase.\n",
"        criterion: loss function, assumed to return the batch-mean loss.\n",
"        optimizer: optimizer over the trainable parameters.\n",
"        scheduler: learning-rate scheduler, stepped once per epoch.\n",
"        num_epochs: number of epochs to run.\n",
"\n",
"    Returns:\n",
"        ``model`` with the best validation weights loaded.\n",
"    \"\"\"\n",
"    # state_dict() returns references to the live tensors; without a deep copy\n",
"    # the \"best\" snapshot would keep changing as training continues.\n",
"    best_model_wts = copy.deepcopy(model.state_dict())\n",
"    best_acc = 0.0\n",
"\n",
"    for epoch in range(num_epochs):\n",
"        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
"        print('-' * 10)\n",
"\n",
"        for phase in ['train', 'val']:\n",
"            is_train = phase == 'train'\n",
"            loss_sum, num_correct = _run_phase(\n",
"                model, dataloaders[phase], criterion, optimizer, is_train)\n",
"            if is_train:\n",
"                # Since PyTorch 1.1 the scheduler is stepped after optimizer.step().\n",
"                scheduler.step()\n",
"\n",
"            epoch_loss = loss_sum / dataset_sizes[phase]\n",
"            epoch_acc = float(num_correct) / dataset_sizes[phase]\n",
"\n",
"            print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n",
"                phase, epoch_loss, epoch_acc))\n",
"\n",
"            if phase == 'val' and epoch_acc > best_acc:\n",
"                best_acc = epoch_acc\n",
"                best_model_wts = copy.deepcopy(model.state_dict())\n",
"\n",
"        print()\n",
"\n",
"    print('Best validation accuarcy: {:4f}'.format(best_acc))\n",
"\n",
"    model.load_state_dict(best_model_wts)\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Dataset location and expected number of class sub-directories.\n",
"data_dir = './natural_images/'\n",
"num_classes = 8\n",
"\n",
"# Per-channel normalization statistics (the standard ImageNet values expected\n",
"# by torchvision's pretrained models).\n",
"mean = [0.485, 0.456, 0.406]\n",
"std = [0.229, 0.224, 0.225]\n",
"\n",
"# Resize target for the shorter image side, and the square crop size fed to the network.\n",
"scale = 360\n",
"input_shape = 224\n",
"\n",
"# Optimization settings.\n",
"batch_size = 128\n",
"epochs = 20"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from random import Random\n",
"\n",
"# Class names are the sub-directories of data_dir. Sorting makes the\n",
"# name -> index mapping stable (os.listdir returns entries in arbitrary\n",
"# order), and the isdir filter skips stray files such as .DS_Store.\n",
"classes = sorted(name for name in os.listdir(data_dir)\n",
"                 if os.path.isdir(os.path.join(data_dir, name)))\n",
"if len(classes) != num_classes:\n",
"    raise ValueError('Found {} class directories in {} but num_classes is {}'.format(\n",
"        len(classes), data_dir, num_classes))\n",
"class_dict = {label: i for i, label in enumerate(classes)}\n",
"\n",
"img_files = []\n",
"for directory in classes:\n",
"    class_dir = os.path.join(data_dir, directory)\n",
"    for filename in sorted(os.listdir(class_dir)):\n",
"        img_files.append(os.path.join(class_dir, filename))\n",
"\n",
"# Seeded shuffle: the 60/20/20 train/val/test split is identical on every run,\n",
"# so a model trained in an earlier run is not evaluated on its own training images.\n",
"split_seed = 0\n",
"Random(split_seed).shuffle(img_files)\n",
"\n",
"train_length = int(0.6 * len(img_files))\n",
"val_length = int(0.8 * len(img_files))  # end index (exclusive) of the val slice\n",
"img_files = {'train': img_files[:train_length],\n",
"             'val': img_files[train_length:val_length],\n",
"             'test': img_files[val_length:]}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Training images get random crops and flips for augmentation; validation\n",
"# images get a deterministic resize + center crop. Both are normalized with\n",
"# the configured `mean`/`std`.\n",
"train_transform = transforms.Compose([\n",
"    transforms.Resize(scale),\n",
"    transforms.RandomResizedCrop(input_shape),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.RandomVerticalFlip(),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean, std),\n",
"])\n",
"val_transform = transforms.Compose([\n",
"    transforms.Resize(scale),\n",
"    transforms.CenterCrop(input_shape),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean, std),\n",
"])\n",
"data_transforms = {'train': train_transform, 'val': val_transform}\n",
"\n",
"# One dataset, loader and (float) size per phase, keyed by phase name.\n",
"image_datasets, dataloaders, dataset_sizes = {}, {}, {}\n",
"for phase in ('train', 'val'):\n",
"    phase_set = ImageDataset(img_files[phase], class_dict, data_transforms[phase])\n",
"    image_datasets[phase] = phase_set\n",
"    dataloaders[phase] = torch.utils.data.DataLoader(\n",
"        phase_set, batch_size=batch_size, shuffle=True, num_workers=4)\n",
"    dataset_sizes[phase] = float(len(phase_set))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py:38: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n",
"100%|██████████| 33/33 [00:16<00:00, 1.95it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0098 Acc: 0.6564\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:03<00:00, 2.83it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0032 Acc: 0.9225\n",
"\n",
"Epoch 1/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.94it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0024 Acc: 0.9394\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.69it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0014 Acc: 0.9659\n",
"\n",
"Epoch 2/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.93it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0015 Acc: 0.9524\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:03<00:00, 2.80it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0010 Acc: 0.9725\n",
"\n",
"Epoch 3/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.89it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0012 Acc: 0.9633\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:03<00:00, 2.78it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0009 Acc: 0.9761\n",
"\n",
"Epoch 4/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.86it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0010 Acc: 0.9667\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.73it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0008 Acc: 0.9790\n",
"\n",
"Epoch 5/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.92it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0010 Acc: 0.9642\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.74it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0007 Acc: 0.9812\n",
"\n",
"Epoch 6/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.88it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0008 Acc: 0.9681\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.25it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0007 Acc: 0.9790\n",
"\n",
"Epoch 7/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.91it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0009 Acc: 0.9705\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:03<00:00, 2.81it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9812\n",
"\n",
"Epoch 8/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.87it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0008 Acc: 0.9710\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.35it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9812\n",
"\n",
"Epoch 9/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.88it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0008 Acc: 0.9686\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:03<00:00, 2.75it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9804\n",
"\n",
"Epoch 10/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.85it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0008 Acc: 0.9751\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.60it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9790\n",
"\n",
"Epoch 11/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.87it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0008 Acc: 0.9737\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.75it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9790\n",
"\n",
"Epoch 12/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.85it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0008 Acc: 0.9717\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.62it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9790\n",
"\n",
"Epoch 13/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.88it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0008 Acc: 0.9667\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.39it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9797\n",
"\n",
"Epoch 14/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.87it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0008 Acc: 0.9732\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.74it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9812\n",
"\n",
"Epoch 15/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.87it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0008 Acc: 0.9729\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:03<00:00, 2.80it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9797\n",
"\n",
"Epoch 16/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.88it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0008 Acc: 0.9715\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.72it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9790\n",
"\n",
"Epoch 17/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.86it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0008 Acc: 0.9725\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.71it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9797\n",
"\n",
"Epoch 18/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.89it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0007 Acc: 0.9720\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.50it/s]\n",
" 0%| | 0/33 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9790\n",
"\n",
"Epoch 19/19\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:17<00:00, 1.85it/s]\n",
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train Loss: 0.0007 Acc: 0.9722\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.66it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val Loss: 0.0006 Acc: 0.9797\n",
"\n",
"Best validation accuarcy: 0.981159\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"model = get_resnet_model(num_classes).to(device)\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"# Optimize only the parameters left trainable by get_resnet_model.\n",
"trainable_params = [p for p in model.parameters() if p.requires_grad]\n",
"optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)\n",
"# Multiply the learning rate by 0.1 every 7 scheduler steps.\n",
"exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)\n",
"\n",
"model = train_model(model, dataloaders, dataset_sizes, criterion, optimizer,\n",
"                    exp_lr_scheduler, num_epochs=epochs)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Held-out split uses deterministic preprocessing only (resize + center crop),\n",
"# followed by the configured normalization.\n",
"test_transform = transforms.Compose([\n",
"    transforms.Resize(scale),\n",
"    transforms.CenterCrop(input_shape),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean, std),\n",
"])\n",
"\n",
"test_set = ImageDataset(img_files['test'], class_dict, test_transform)\n",
"test_size = float(len(test_set))\n",
"\n",
"test_loader = torch.utils.data.DataLoader(\n",
"    test_set,\n",
"    batch_size=batch_size,\n",
"    shuffle=True,\n",
"    num_workers=4,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:04<00:00, 2.59it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Accuracy: 0.9855072463768116\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Evaluate the trained model on the held-out test split.\n",
"model.eval()\n",
"eval_device = next(model.parameters()).device\n",
"\n",
"running_corrects = 0\n",
"# Inference only: disable autograd so no graph or activations are kept.\n",
"with torch.no_grad():\n",
"    # Unpack batches directly; a loop variable named `data` would rebind the\n",
"    # `torch.utils.data` alias from the import cell.\n",
"    for inputs, labels in tqdm(test_loader):\n",
"        inputs = inputs.to(device=eval_device, dtype=torch.float32)\n",
"        labels = labels.to(device=eval_device, dtype=torch.long)\n",
"\n",
"        outputs = model(inputs)\n",
"        if isinstance(outputs, tuple):\n",
"            outputs = outputs[0]\n",
"        _, preds = torch.max(outputs, 1)\n",
"\n",
"        running_corrects += torch.sum(preds == labels).item()\n",
"\n",
"accuracy = float(running_corrects) / test_size\n",
"print(\"Test Accuracy: {}\".format(accuracy))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment