Skip to content

Instantly share code, notes, and snippets.

@aakashns
Created July 15, 2018 17:31
Show Gist options
  • Save aakashns/90c13a903ff510c5baa72293fea72952 to your computer and use it in GitHub Desktop.
Save aakashns/90c13a903ff510c5baa72293fea72952 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the Data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torchvision.transforms as tt\n",
"from torchvision.datasets import ImageFolder\n",
"from torch.utils.data import DataLoader\n",
"from fastai.dataset import ModelData\n",
"\n",
"def get_data(bs, num_workers, path=\"data/cifar10/\"):\n",
"    \"\"\"Build a FastAI `ModelData` for an ImageFolder-layout dataset.\n",
"\n",
"    Args:\n",
"        bs: batch size for every DataLoader.\n",
"        num_workers: number of loader worker processes.\n",
"        path: dataset root holding `train/` and `test/` class folders\n",
"            (defaults to the original CIFAR-10 location).\n",
"\n",
"    Returns:\n",
"        `ModelData` with a shuffled, augmented training loader, a plain\n",
"        validation loader built from `test/`, an extra `aug_dl` serving the\n",
"        same validation images with the training augmentations (presumably\n",
"        read by fastai's `learn.TTA()`), and `sz = 32`.\n",
"    \"\"\"\n",
"    # os.path.join works whether or not `path` ends with a slash\n",
"    trn_dir = os.path.join(path, 'train')\n",
"    val_dir = os.path.join(path, 'test')\n",
"    # Per-channel (mean, std), unpacked into tt.Normalize below;\n",
"    # presumably CIFAR-10 statistics -- recompute for other datasets.\n",
"    stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
"\n",
"    # Data transforms (normalization & data augmentation)\n",
"    tfms = [tt.ToTensor(), tt.Normalize(*stats)]\n",
"    aug_tfms = tt.Compose([tt.RandomCrop(32, padding=4),\n",
"                           tt.RandomHorizontalFlip()] + tfms)\n",
"\n",
"    # PyTorch datasets: the validation images are wrapped twice, once plain\n",
"    # and once with the training augmentations (for test-time augmentation)\n",
"    trn_ds = ImageFolder(trn_dir, aug_tfms)\n",
"    val_ds = ImageFolder(val_dir, tt.Compose(tfms))\n",
"    aug_ds = ImageFolder(val_dir, aug_tfms)\n",
"\n",
"    # PyTorch data loaders; only the training loader is shuffled\n",
"    def make_loader(ds, shuffle):\n",
"        return DataLoader(ds, batch_size=bs, shuffle=shuffle,\n",
"                          num_workers=num_workers, pin_memory=True)\n",
"\n",
"    trn_dl = make_loader(trn_ds, shuffle=True)\n",
"    val_dl = make_loader(val_ds, shuffle=False)\n",
"    aug_dl = make_loader(aug_ds, shuffle=False)\n",
"\n",
"    # FastAI model data, with the extra TTA loader and image size attached\n",
"    data = ModelData(path, trn_dl, val_dl)\n",
"    data.aug_dl = aug_dl\n",
"    data.sz = 32\n",
"    return data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"from fastai.conv_learner import ConvLearner\n",
"from fastai.core import num_cpus\n",
"from fastai.metrics import accuracy, accuracy_np\n",
"\n",
"def get_learner(arch, bs):\n",
"    \"\"\"Create a FastAI learner for the given model instance.\n",
"\n",
"    Args:\n",
"        arch: an `nn.Module` instance (not a constructor); it is moved to\n",
"            the GPU with `.cuda()`.\n",
"        bs: batch size passed to `get_data`.\n",
"\n",
"    Returns:\n",
"        A `ConvLearner` using cross-entropy loss and the accuracy metric.\n",
"    \"\"\"\n",
"    data = get_data(bs, num_cpus())\n",
"    learn = ConvLearner.from_model_data(arch.cuda(), data)\n",
"    # CrossEntropyLoss applies log-softmax itself: the model must emit raw logits\n",
"    learn.crit = nn.CrossEntropyLoss()\n",
"    learn.metrics = [accuracy]\n",
"    return learn\n",
"\n",
"def get_TTA_accuracy(learn, orig_weight=0.6):\n",
"    \"\"\"Calculate accuracy with Test Time Augmentation (TTA).\n",
"\n",
"    Blends the un-augmented prediction (weight `orig_weight`) with the\n",
"    mean prediction over the augmented copies (weight `1 - orig_weight`).\n",
"    \"\"\"\n",
"    preds, targs = learn.TTA()\n",
"    # preds[0] is treated as the un-augmented prediction, preds[1:] as the\n",
"    # augmented copies. Average (not sum) the copies so the two weights form\n",
"    # a convex combination regardless of how many augmentations were run.\n",
"    preds = orig_weight * preds[0] + (1 - orig_weight) * preds[1:].mean(0)\n",
"    return accuracy_np(preds, targs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the network"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"def conv_2d(ni, nf, stride=1, ks=3):\n",
"    \"\"\"ks x ks convolution (3x3 by default) with ks//2 padding and no bias.\n",
"\n",
"    With stride 1 and an odd `ks` the spatial size is preserved.\n",
"    \"\"\"\n",
"    return nn.Conv2d(in_channels=ni, out_channels=nf,\n",
"                     kernel_size=ks, stride=stride,\n",
"                     padding=ks//2, bias=False)\n",
"\n",
"def bn_relu_conv(ni, nf):\n",
"    \"\"\"BatchNorm → ReLU → Conv2D (3x3, stride 1).\"\"\"\n",
"    return nn.Sequential(nn.BatchNorm2d(ni),\n",
"                         nn.ReLU(inplace=True),\n",
"                         conv_2d(ni, nf))\n",
"\n",
"class BasicBlock(nn.Module):\n",
"    \"\"\"Pre-activation residual block with a 0.2-scaled residual branch.\n",
"\n",
"    Computes conv2(conv1(a)) * 0.2 + shortcut(a), where a = relu(bn(x)).\n",
"    \"\"\"\n",
"    def __init__(self, ni, nf, stride=1):\n",
"        super().__init__()\n",
"        self.bn = nn.BatchNorm2d(ni)\n",
"        self.conv1 = conv_2d(ni, nf, stride)\n",
"        self.conv2 = bn_relu_conv(nf, nf)\n",
"        # A 1x1 projection is needed whenever the main branch changes the\n",
"        # output shape: a different channel count OR spatial downsampling.\n",
"        # Otherwise use an empty nn.Sequential as the identity; unlike a\n",
"        # lambda it is picklable, so torch.save(model) works.\n",
"        if ni != nf or stride != 1:\n",
"            self.shortcut = conv_2d(ni, nf, stride, 1)\n",
"        else:\n",
"            self.shortcut = nn.Sequential()\n",
"\n",
"    def forward(self, x):\n",
"        x = F.relu(self.bn(x), inplace=True)\n",
"        r = self.shortcut(x)\n",
"        x = self.conv1(x)\n",
"        x = self.conv2(x) * 0.2\n",
"        return x.add_(r)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def make_group(N, ni, nf, stride):\n",
"    \"\"\"Group of N residual blocks; only the first changes channels/stride.\"\"\"\n",
"    start = BasicBlock(ni, nf, stride)\n",
"    rest = [BasicBlock(nf, nf) for _ in range(1, N)]\n",
"    return [start] + rest\n",
"\n",
"class Flatten(nn.Module):\n",
"    \"\"\"Flatten every dimension except the first one.\"\"\"\n",
"    def forward(self, x):\n",
"        return x.view(x.size(0), -1)\n",
"\n",
"class WideResNet(nn.Module):\n",
"    \"\"\"Wide ResNet: stem conv, `n_groups` groups of `N` BasicBlocks, head.\n",
"\n",
"    Group i has n_start * 2**i * k channels; every group after the first\n",
"    starts with a stride-2 block (downsampling).\n",
"    \"\"\"\n",
"    def __init__(self, n_groups, N, n_classes, k=1, n_start=16):\n",
"        super().__init__()\n",
"        # Increase channels to n_start using conv layer\n",
"        layers = [conv_2d(3, n_start)]\n",
"        n_channels = [n_start]\n",
"\n",
"        # Add groups of BasicBlock (increase channels & downsample)\n",
"        for i in range(n_groups):\n",
"            n_channels.append(n_start*(2**i)*k)\n",
"            stride = 2 if i>0 else 1\n",
"            layers += make_group(N, n_channels[i],\n",
"                                 n_channels[i+1], stride)\n",
"\n",
"        # Pool, flatten & add linear layer for classification. Use the last\n",
"        # group's width so any n_groups works (a fixed index such as\n",
"        # n_channels[3] only fits n_groups == 3).\n",
"        n_out = n_channels[-1]\n",
"        layers += [nn.BatchNorm2d(n_out),\n",
"                   nn.ReLU(inplace=True),\n",
"                   nn.AdaptiveAvgPool2d(1),\n",
"                   Flatten(),\n",
"                   nn.Linear(n_out, n_classes)]\n",
"\n",
"        self.features = nn.Sequential(*layers)\n",
"\n",
"    def forward(self, x):\n",
"        return self.features(x)\n",
"\n",
"def wrn_22():\n",
"    \"\"\"WideResNet with 3 groups of 3 blocks, widening factor 6, 10 classes.\"\"\"\n",
"    return WideResNet(n_groups=3, N=3, n_classes=10, k=6)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training & Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a6f9770be1824b62bd7a8c3d895db1cd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 1.307771 1.355958 0.5027 \n",
" 1 0.973831 1.268146 0.5804 \n",
" 2 0.815618 0.937404 0.6821 \n",
" 3 0.726471 0.901928 0.7004 \n",
" 4 0.654479 0.777541 0.7319 \n",
" 5 0.630079 0.783178 0.7379 \n",
" 6 0.614516 0.817597 0.7293 \n",
" 7 0.606512 0.749424 0.7461 \n",
" 8 0.587174 1.035898 0.6526 \n",
" 9 0.575562 1.696366 0.5554 \n",
" 10 0.566359 0.798111 0.7341 \n",
" 11 0.545117 0.70227 0.7569 \n",
" 12 0.499315 0.611959 0.7922 \n",
" 13 0.469588 0.717421 0.767 \n",
" 14 0.437617 0.695363 0.7639 \n",
" 15 0.401804 0.489137 0.8375 \n",
" 16 0.316073 0.347868 0.8784 \n",
" 17 0.246093 0.283443 0.9038 \n",
" 18 0.198445 0.247639 0.9156 \n",
" 19 0.149643 0.219992 0.9242 \n",
"\n",
"CPU times: user 15min 20s, sys: 7min 21s, total: 22min 42s\n",
"Wall time: 22min 27s\n"
]
}
],
"source": [
"%%time\n",
"learn = get_learner(arch=wrn_22(), bs=128)\n",
"# Gradient clipping threshold (presumably a max grad-norm in fastai 0.7)\n",
"learn.clip = 0.1\n",
"# lr 1.5, one cycle of cycle_len=20 epochs, weight decay 1e-4; use_clr_beta\n",
"# presumably sets fastai's 1cycle schedule (div, pct, max_mom, min_mom) -- verify\n",
"learn.fit(1.5, 1, wds=1e-4, cycle_len=20,\n",
"          use_clr_beta=(12, 15, 0.95, 0.85))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" \r"
]
},
{
"data": {
"text/plain": [
"0.9287"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Validation accuracy with test-time augmentation\n",
"get_TTA_accuracy(learn=learn)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x288 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Training loss curve, then the learning-rate schedule\n",
"learn.sched.plot_loss()\n",
"learn.sched.plot_lr();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@sidujjain
Copy link

sidujjain commented Jan 23, 2019

Great stuff!!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment