Skip to content

Instantly share code, notes, and snippets.

@wdhorton
Last active July 31, 2019 02:13
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 7 You must be signed in to fork a gist
  • Save wdhorton/447f9aecc209ee6fc6ddab3122f6b685 to your computer and use it in GitHub Desktop.
Save wdhorton/447f9aecc209ee6fc6ddab3122f6b685 to your computer and use it in GitHub Desktop.
SWA testing
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import absolute_import\n",
"\n",
"'''Resnet for cifar dataset.\n",
"Ported from\n",
"https://github.com/facebook/fb.resnet.torch\n",
"and\n",
"https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py\n",
"(c) YANG, Wei\n",
"'''\n",
"import torch.nn as nn\n",
"import math\n",
"\n",
"\n",
"__all__ = ['preresnet']\n",
"\n",
"def conv3x3(in_planes, out_planes, stride=1):\n",
" \"3x3 convolution with padding\"\n",
" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
" padding=1, bias=False)\n",
"\n",
"\n",
"class BasicBlock(nn.Module):\n",
" expansion = 1\n",
"\n",
" def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
" super(BasicBlock, self).__init__()\n",
" self.bn1 = nn.BatchNorm2d(inplanes)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.conv1 = conv3x3(inplanes, planes, stride)\n",
" self.bn2 = nn.BatchNorm2d(planes)\n",
" self.conv2 = conv3x3(planes, planes)\n",
" self.downsample = downsample\n",
" self.stride = stride\n",
"\n",
" def forward(self, x):\n",
" residual = x\n",
"\n",
" out = self.bn1(x)\n",
" out = self.relu(out)\n",
" out = self.conv1(out)\n",
"\n",
" out = self.bn2(out)\n",
" out = self.relu(out)\n",
" out = self.conv2(out)\n",
"\n",
" if self.downsample is not None:\n",
" residual = self.downsample(x)\n",
"\n",
" out += residual\n",
"\n",
" return out\n",
"\n",
"\n",
"class Bottleneck(nn.Module):\n",
" expansion = 4\n",
"\n",
" def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
" super(Bottleneck, self).__init__()\n",
" self.bn1 = nn.BatchNorm2d(inplanes)\n",
" self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n",
" self.bn2 = nn.BatchNorm2d(planes)\n",
" self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n",
" padding=1, bias=False)\n",
" self.bn3 = nn.BatchNorm2d(planes)\n",
" self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.downsample = downsample\n",
" self.stride = stride\n",
"\n",
" def forward(self, x):\n",
" residual = x\n",
"\n",
" out = self.bn1(x)\n",
" out = self.relu(out)\n",
" out = self.conv1(out)\n",
"\n",
" out = self.bn2(out)\n",
" out = self.relu(out)\n",
" out = self.conv2(out)\n",
"\n",
" out = self.bn3(out)\n",
" out = self.relu(out)\n",
" out = self.conv3(out)\n",
"\n",
" if self.downsample is not None:\n",
" residual = self.downsample(x)\n",
"\n",
" out += residual\n",
"\n",
" return out\n",
"\n",
"\n",
"class PreResNet(nn.Module):\n",
"\n",
" def __init__(self, depth, num_classes=1000):\n",
" super(PreResNet, self).__init__()\n",
" # Depth must be 6n+2; n is the number of residual blocks in each of the three stages (CIFAR model)\n",
" assert (depth - 2) % 6 == 0, 'depth should be 6n+2'\n",
" n = (depth - 2) // 6\n",
"\n",
" block = Bottleneck if depth >=44 else BasicBlock\n",
"\n",
" self.inplanes = 16\n",
" self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,\n",
" bias=False)\n",
" self.layer1 = self._make_layer(block, 16, n)\n",
" self.layer2 = self._make_layer(block, 32, n, stride=2)\n",
" self.layer3 = self._make_layer(block, 64, n, stride=2)\n",
" self.bn = nn.BatchNorm2d(64 * block.expansion)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.avgpool = nn.AvgPool2d(8)\n",
" self.fc = nn.Linear(64 * block.expansion, num_classes)\n",
"\n",
" for m in self.modules():\n",
" if isinstance(m, nn.Conv2d):\n",
" n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
" m.weight.data.normal_(0, math.sqrt(2. / n))\n",
" elif isinstance(m, nn.BatchNorm2d):\n",
" m.weight.data.fill_(1)\n",
" m.bias.data.zero_()\n",
"\n",
" def _make_layer(self, block, planes, blocks, stride=1):\n",
" downsample = None\n",
" if stride != 1 or self.inplanes != planes * block.expansion:\n",
" downsample = nn.Sequential(\n",
" nn.Conv2d(self.inplanes, planes * block.expansion,\n",
" kernel_size=1, stride=stride, bias=False),\n",
" )\n",
"\n",
" layers = []\n",
" layers.append(block(self.inplanes, planes, stride, downsample))\n",
" self.inplanes = planes * block.expansion\n",
" for i in range(1, blocks):\n",
" layers.append(block(self.inplanes, planes))\n",
"\n",
" return nn.Sequential(*layers)\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
"\n",
" x = self.layer1(x) # 32x32\n",
" x = self.layer2(x) # 16x16\n",
" x = self.layer3(x) # 8x8\n",
" x = self.bn(x)\n",
" x = self.relu(x)\n",
"\n",
" x = self.avgpool(x)\n",
" x = x.view(x.size(0), -1)\n",
" x = self.fc(x)\n",
"\n",
" return x\n",
"\n",
"\n",
"def preresnet(**kwargs):\n",
" \"\"\"\n",
" Constructs a pre-activation ResNet (PreResNet) model.\n",
" \"\"\"\n",
" return PreResNet(**kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from fastai.conv_learner import *\n",
"from fastai.model import fit\n",
"from fastai.core import SGD_Momentum\n",
"\n",
"PATH = \"data/cifar10/\"\n",
"os.makedirs(PATH,exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def get_data(sz,bs):\n",
" tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz//8)\n",
" return ImageClassifierData.from_paths(PATH, trn_name='train_', val_name='test_', tfms=tfms, bs=bs)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"bs=128"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"data = get_data(32,bs)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"lr=0.1"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from fastai.sgdr import Callback\n",
"\n",
"class SgdLrUpdater(Callback):\n",
" def __init__(self, layer_opt, init_lr, budget):\n",
" self.layer_opt=layer_opt\n",
" self.init_lr=init_lr\n",
" self.budget=budget\n",
" \n",
" def on_train_begin(self):\n",
" self.epoch = 0\n",
" \n",
" def on_epoch_end(self, metrics):\n",
" self.epoch += 1\n",
" self.update_lr()\n",
" \n",
" def update_lr(self): \n",
" new_lr = self.calc_lr()\n",
" self.layer_opt.set_lrs([new_lr])\n",
" \n",
" def calc_lr(self):\n",
" if self.epoch < self.budget//2:\n",
" return self.init_lr\n",
" elif self.epoch > 0.9 * self.budget:\n",
" return 0.01 * self.init_lr\n",
" else:\n",
" return self.init_lr - (self.init_lr * 0.99 / int(0.4 * self.budget) * (self.epoch - self.budget//2))"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"from fastai.sgdr import LoggingCallback"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5be0b56e68c5499091096a160172daa3",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=150), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 1.321136 1.399471 0.486847 \n",
" 1 1.042043 1.044005 0.632516 \n",
" 2 0.871521 0.895404 0.688588 \n",
" 3 0.743716 0.824015 0.723398 \n",
" 4 0.679371 0.934189 0.70085 \n",
" 5 0.606691 0.70843 0.767405 \n",
" 6 0.574801 0.637944 0.782437 \n",
" 7 0.550256 0.573916 0.802116 \n",
" 8 0.525207 0.655205 0.77769 \n",
" 9 0.518154 0.669004 0.772053 \n",
" 10 0.512644 0.684118 0.777393 \n",
" 11 0.488844 0.550262 0.810324 \n",
" 2%|▏ | 7/391 [00:08<08:13, 1.28s/it, loss=0.483]"
]
}
],
"source": [
"wd = 3e-4\n",
"budget = 150\n",
"\n",
"# train 3 preresnet110 models with normal SGD with momentum\n",
"for i in range(3):\n",
" preresnet110 = preresnet(depth=110, num_classes=10)\n",
" learn = ConvLearner.from_model_data(preresnet110, data)\n",
" layer_opt = learn.get_layer_opt([lr], [wd])\n",
" learn.crit = F.cross_entropy\n",
" learn.fit_gen(\n",
" learn.model, \n",
" learn.data, \n",
" layer_opt, \n",
" budget, \n",
" callbacks=[SgdLrUpdater(layer_opt, lr, budget), LoggingCallback(f'{PATH}logs/sgd_{i}.txt')]\n",
" )\n",
" learn.save(f'sgd_{i}')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from fastai.sgdr import Callback\n",
"\n",
"class SwaLrUpdater(Callback):\n",
" def __init__(self, layer_opt, init_lr, budget, swa_start, swa_lr):\n",
" self.layer_opt=layer_opt\n",
" self.init_lr=init_lr\n",
" self.budget=budget\n",
" self.swa_start=swa_start\n",
" self.swa_lr=swa_lr\n",
" \n",
" def on_train_begin(self):\n",
" self.epoch = 0\n",
" \n",
" def on_epoch_end(self, metrics):\n",
" self.epoch += 1\n",
" self.update_lr()\n",
" \n",
" def update_lr(self): \n",
" new_lr = self.calc_lr()\n",
" self.layer_opt.set_lrs([new_lr])\n",
" \n",
" def calc_lr(self):\n",
" if self.epoch < self.swa_start//2:\n",
" return self.init_lr\n",
" elif self.epoch > 0.9 * self.swa_start:\n",
" return self.swa_lr\n",
" else:\n",
" return self.init_lr - ((self.init_lr - self.swa_lr) / int(0.4 * self.swa_start) * (self.epoch - self.swa_start//2))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"lr=0.1\n",
"swa_lr = 0.01\n",
"wd = 3e-4\n",
"swa_start = 126"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"budget = 150\n",
"\n",
"# train 3 preresnet110 models with SWA training schedule\n",
"for i in range(3):\n",
" preresnet110 = preresnet(depth=110, num_classes=10)\n",
" learn = ConvLearner.from_model_data(preresnet110, data)\n",
" layer_opt = learn.get_layer_opt([lr], [wd])\n",
" learn.crit = F.cross_entropy\n",
" learn.fit_gen(\n",
" learn.model, \n",
" learn.data, \n",
" layer_opt, \n",
" budget,\n",
" use_swa=True,\n",
" swa_start=swa_start,\n",
" swa_eval_freq=1,\n",
" callbacks=[SwaLrUpdater(layer_opt, lr, budget, swa_start, swa_lr), LoggingCallback(f'{PATH}logs/swa_{i}.txt')]\n",
" )\n",
" learn.save(f'swa_{i}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 1.25 budgets\n",
"budget = 187\n",
"\n",
"# train 3 preresnet110 models with SWA training schedule and 1.25 budgets\n",
"for i in range(3):\n",
" preresnet110 = preresnet(depth=110, num_classes=10)\n",
" learn = ConvLearner.from_model_data(preresnet110, data)\n",
" layer_opt = learn.get_layer_opt([lr], [wd])\n",
" learn.crit = F.cross_entropy\n",
" learn.fit_gen(\n",
" learn.model, \n",
" learn.data, \n",
" layer_opt, \n",
" budget,\n",
" use_swa=True,\n",
" swa_start=swa_start,\n",
" swa_eval_freq=1,\n",
" callbacks=[SwaLrUpdater(layer_opt, lr, budget, swa_start, swa_lr), LoggingCallback(f'{PATH}logs/swa_187_{i}.txt')]\n",
" )\n",
" learn.save(f'swa_187_{i}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 1.5 budgets\n",
"budget = 225\n",
"\n",
"# train 3 preresnet110 models with SWA training schedule and 1.5 budgets\n",
"for i in range(3):\n",
" preresnet110 = preresnet(depth=110, num_classes=10)\n",
" learn = ConvLearner.from_model_data(preresnet110, data)\n",
" layer_opt = learn.get_layer_opt([lr], [wd])\n",
" learn.crit = F.cross_entropy\n",
" learn.fit_gen(\n",
" learn.model, \n",
" learn.data, \n",
" layer_opt, \n",
" budget,\n",
" use_swa=True,\n",
" swa_start=swa_start,\n",
" swa_eval_freq=1,\n",
" callbacks=[SwaLrUpdater(layer_opt, lr, budget, swa_start, swa_lr), LoggingCallback(f'{PATH}logs/swa_225{i}.txt')]\n",
" )\n",
" learn.save(f'swa_225{i}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment