Last active July 31, 2019 02:13
SWA testing
"cells": [
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%reload_ext autoreload\n",
"%autoreload 2"
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import absolute_import\n",
"'''Resnet for cifar dataset.\n",
"Ported from\n",
"https://github.com/facebook/fb.resnet.torch\n",
"and\n",
"https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py\n",
"(c) YANG, Wei\n",
"'''\n",
"import torch.nn as nn\n",
"import math\n",
"\n",
"__all__ = ['preresnet']\n",
"def conv3x3(in_planes, out_planes, stride=1):\n",
"    \"\"\"Return a bias-free 3x3 Conv2d with padding 1 (spatial size kept when stride=1).\"\"\"\n",
"    return nn.Conv2d(\n",
"        in_planes,\n",
"        out_planes,\n",
"        kernel_size=3,\n",
"        stride=stride,\n",
"        padding=1,\n",
"        bias=False,\n",
"    )\n",
"\n",
"class BasicBlock(nn.Module):\n",
"    \"\"\"Pre-activation basic residual block: two (BN -> ReLU -> 3x3 conv) stages plus a shortcut.\n",
"\n",
"    The optional `downsample` module is applied to the block input to build the\n",
"    shortcut; otherwise the input itself is added to the output.\n",
"    \"\"\"\n",
"    expansion = 1\n",
"\n",
"    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
"        super(BasicBlock, self).__init__()\n",
"        # First stage: the (optional) stride is applied by conv1.\n",
"        self.bn1 = nn.BatchNorm2d(inplanes)\n",
"        self.relu = nn.ReLU(inplace=True)\n",
"        self.conv1 = conv3x3(inplanes, planes, stride)\n",
"        # Second stage keeps the channel count at `planes`.\n",
"        self.bn2 = nn.BatchNorm2d(planes)\n",
"        self.conv2 = conv3x3(planes, planes)\n",
"        self.downsample = downsample\n",
"        self.stride = stride\n",
"\n",
"    def forward(self, x):\n",
"        out = self.conv1(self.relu(self.bn1(x)))\n",
"        out = self.conv2(self.relu(self.bn2(out)))\n",
"        shortcut = x if self.downsample is None else self.downsample(x)\n",
"        out += shortcut\n",
"        return out\n",
"\n",
"class Bottleneck(nn.Module):\n",
"    \"\"\"Pre-activation bottleneck block with a residual shortcut.\n",
"\n",
"    Three stages, each BN -> ReLU -> conv: 1x1 (inplanes -> planes), 3x3 (carries\n",
"    the stride), 1x1 (planes -> planes * expansion). `downsample`, when given, maps\n",
"    the block input onto the shortcut; otherwise the input is added unchanged.\n",
"    \"\"\"\n",
"    expansion = 4\n",
"\n",
"    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
"        super(Bottleneck, self).__init__()\n",
"        self.bn1 = nn.BatchNorm2d(inplanes)\n",
"        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n",
"        self.bn2 = nn.BatchNorm2d(planes)\n",
"        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n",
"                               padding=1, bias=False)\n",
"        self.bn3 = nn.BatchNorm2d(planes)\n",
"        # Output width is tied to the class-level expansion factor.\n",
"        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)\n",
"        self.relu = nn.ReLU(inplace=True)\n",
"        self.downsample = downsample\n",
"        self.stride = stride\n",
"\n",
"    def forward(self, x):\n",
"        out = x\n",
"        stages = ((self.bn1, self.conv1), (self.bn2, self.conv2), (self.bn3, self.conv3))\n",
"        for norm, conv in stages:\n",
"            out = conv(self.relu(norm(out)))\n",
"        residual = self.downsample(x) if self.downsample is not None else x\n",
"        out += residual\n",
"        return out\n",
"\n",
"class PreResNet(nn.Module):\n",
"    \"\"\"Pre-activation ResNet for 32x32 (CIFAR) inputs.\n",
"\n",
"    Args:\n",
"        depth: number of weighted layers. For depth < 44 BasicBlock (2 convs per\n",
"            block) is used and depth must be 6n + 2; for depth >= 44 Bottleneck\n",
"            (3 convs per block) is used and depth must be 9n + 2.\n",
"        num_classes: output size of the final linear layer.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, depth, num_classes=1000):\n",
"        super(PreResNet, self).__init__()\n",
"        # depth = (convs per block) * 3 stages * n blocks + stem conv + fc.\n",
"        # Bottleneck blocks hold 3 convs, so using 6n+2 for them (as before) built a\n",
"        # deeper net than requested (depth=110 produced 164 layers).\n",
"        if depth >= 44:\n",
"            assert (depth - 2) % 9 == 0, 'depth should be 9n+2'\n",
"            n = (depth - 2) // 9\n",
"            block = Bottleneck\n",
"        else:\n",
"            assert (depth - 2) % 6 == 0, 'depth should be 6n+2'\n",
"            n = (depth - 2) // 6\n",
"            block = BasicBlock\n",
"\n",
"        self.inplanes = 16\n",
"        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,\n",
"                               bias=False)\n",
"        self.layer1 = self._make_layer(block, 16, n)\n",
"        self.layer2 = self._make_layer(block, 32, n, stride=2)\n",
"        self.layer3 = self._make_layer(block, 64, n, stride=2)\n",
"        # Blocks end on a conv, so one more BN -> ReLU is applied before pooling.\n",
"        self.bn = nn.BatchNorm2d(64 * block.expansion)\n",
"        self.relu = nn.ReLU(inplace=True)\n",
"        # 8x8 pooling matches the 8x8 map left after two stride-2 stages on 32x32 input.\n",
"        self.avgpool = nn.AvgPool2d(8)\n",
"        self.fc = nn.Linear(64 * block.expansion, num_classes)\n",
"\n",
"        for m in self.modules():\n",
"            if isinstance(m, nn.Conv2d):\n",
"                # He (fan-out) init; named fan_out so it does not shadow n above.\n",
"                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
"                m.weight.data.normal_(0, math.sqrt(2. / fan_out))\n",
"            elif isinstance(m, nn.BatchNorm2d):\n",
"                m.weight.data.fill_(1)\n",
"                m.bias.data.zero_()\n",
"\n",
"    def _make_layer(self, block, planes, blocks, stride=1):\n",
"        \"\"\"Stack `blocks` blocks; only the first may change stride/width (via a 1x1 projection).\"\"\"\n",
"        downsample = None\n",
"        if stride != 1 or self.inplanes != planes * block.expansion:\n",
"            downsample = nn.Sequential(\n",
"                nn.Conv2d(self.inplanes, planes * block.expansion,\n",
"                          kernel_size=1, stride=stride, bias=False),\n",
"            )\n",
"\n",
"        layers = [block(self.inplanes, planes, stride, downsample)]\n",
"        self.inplanes = planes * block.expansion\n",
"        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))\n",
"        return nn.Sequential(*layers)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.conv1(x)\n",
"        x = self.layer1(x)  # 32x32\n",
"        x = self.layer2(x)  # 16x16\n",
"        x = self.layer3(x)  # 8x8\n",
"        x = self.bn(x)\n",
"        x = self.relu(x)\n",
"        x = self.avgpool(x)\n",
"        x = x.view(x.size(0), -1)\n",
"        return self.fc(x)\n",
"\n",
"def preresnet(**kwargs):\n",
"    \"\"\"Build a PreResNet, forwarding all keyword arguments (e.g. depth, num_classes) unchanged.\"\"\"\n",
"    model = PreResNet(**kwargs)\n",
"    return model"
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from fastai.conv_learner import *\n",
"from fastai.model import fit\n",
"from fastai.core import SGD_Momentum\n",
"PATH = \"data/cifar10/\"\n",
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# CIFAR-10 label names, presumably in label-index order (not referenced elsewhere in this notebook).\n",
"classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"# (mean, std) arrays with one value per image channel, passed to tfms_from_stats in get_data;\n",
"# presumably CIFAR-10 training-set statistics used for normalisation - TODO confirm.\n",
"stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))"
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def get_data(sz,bs):\n",
"    \"\"\"Build fastai ImageClassifierData from PATH at image size `sz` and batch size `bs`.\n",
"\n",
"    Transforms come from tfms_from_stats(stats, ...) with RandomFlip augmentation and\n",
"    pad=sz//8; the 'train_' and 'test_' folders are used as the train / validation splits.\n",
"    \"\"\"\n",
"    augmentations = [RandomFlip()]\n",
"    transforms = tfms_from_stats(stats, sz, aug_tfms=augmentations, pad=sz//8)\n",
"    return ImageClassifierData.from_paths(\n",
"        PATH, trn_name='train_', val_name='test_', tfms=transforms, bs=bs)"
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# 32x32 CIFAR-10 data loaders; `bs` must be defined before this cell runs.\n",
"data = get_data(32,bs)"
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from fastai.sgdr import Callback\n",
"class SgdLrUpdater(Callback):\n",
"    \"\"\"Per-epoch learning-rate schedule for the plain-SGD baseline.\n",
"\n",
"    `epoch` is reset to 0 in on_train_begin and incremented in on_epoch_end:\n",
"      * epoch < budget // 2:     init_lr\n",
"      * epoch > 0.9 * budget:    0.01 * init_lr\n",
"      * otherwise:               linear decay from init_lr towards 0.01 * init_lr\n",
"    The new rate is pushed with layer_opt.set_lrs([new_lr]) after every epoch.\n",
"    \"\"\"\n",
"    def __init__(self, layer_opt, init_lr, budget):\n",
"        self.layer_opt = layer_opt\n",
"        self.init_lr = init_lr\n",
"        self.budget = budget\n",
"\n",
"    def on_train_begin(self):\n",
"        self.epoch = 0\n",
"\n",
"    def on_epoch_end(self, metrics):\n",
"        self.epoch += 1\n",
"        self.update_lr()\n",
"\n",
"    def update_lr(self):\n",
"        new_lr = self.calc_lr()\n",
"        self.layer_opt.set_lrs([new_lr])\n",
"\n",
"    def calc_lr(self):\n",
"        if self.epoch < self.budget//2:\n",
"            return self.init_lr\n",
"        elif self.epoch > 0.9 * self.budget:\n",
"            return 0.01 * self.init_lr\n",
"        else:\n",
"            # max(1, ...) avoids ZeroDivisionError when int(0.4 * budget) == 0\n",
"            # (budget == 2); results for budget >= 3 are unchanged.\n",
"            decay_epochs = max(1, int(0.4 * self.budget))\n",
"            return self.init_lr - (self.init_lr * 0.99 / decay_epochs * (self.epoch - self.budget//2))"
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"from fastai.sgdr import LoggingCallback"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5be0b56e68c5499091096a160172daa3",
"version_major": 2,
"version_minor": 0
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=150), HTML(value='')))"
"metadata": {},
"output_type": "display_data"
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 1.321136 1.399471 0.486847 \n",
" 1 1.042043 1.044005 0.632516 \n",
" 2 0.871521 0.895404 0.688588 \n",
" 3 0.743716 0.824015 0.723398 \n",
" 4 0.679371 0.934189 0.70085 \n",
" 5 0.606691 0.70843 0.767405 \n",
" 6 0.574801 0.637944 0.782437 \n",
" 7 0.550256 0.573916 0.802116 \n",
" 8 0.525207 0.655205 0.77769 \n",
" 9 0.518154 0.669004 0.772053 \n",
" 10 0.512644 0.684118 0.777393 \n",
" 11 0.488844 0.550262 0.810324 \n",
" 2%|▏ | 7/391 [00:08<08:13, 1.28s/it, loss=0.483]"
"source": [
"wd = 3e-4\n",
"budget = 150\n",
"# Train 3 preresnet110 models with normal SGD with momentum (SgdLrUpdater schedule).\n",
"# Expects `lr` and `data` from earlier cells; run i logs to {PATH}logs/sgd_{i}.txt.\n",
"for i in range(3):\n",
"    preresnet110 = preresnet(depth=110, num_classes=10)\n",
"    learn = ConvLearner.from_model_data(preresnet110, data)\n",
"    layer_opt = learn.get_layer_opt([lr], [wd])\n",
"    learn.crit = F.cross_entropy\n",
"    learn.fit_gen(\n",
"        learn.model,\n",
"        data,\n",
"        layer_opt,\n",
"        budget,\n",
"        callbacks=[SgdLrUpdater(layer_opt, lr, budget), LoggingCallback(f'{PATH}logs/sgd_{i}.txt')]\n",
"    )\n",
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from fastai.sgdr import Callback\n",
"class SwaLrUpdater(Callback):\n",
"    \"\"\"Per-epoch learning-rate schedule for SWA training.\n",
"\n",
"    The schedule is shaped by swa_start (budget is stored but not used):\n",
"      * epoch < swa_start // 2:     init_lr\n",
"      * epoch > 0.9 * swa_start:    swa_lr (held for the rest of training)\n",
"      * otherwise:                  linear decay from init_lr towards swa_lr\n",
"    `epoch` is reset in on_train_begin, incremented in on_epoch_end, and the new rate\n",
"    is pushed with layer_opt.set_lrs([new_lr]).\n",
"    \"\"\"\n",
"    def __init__(self, layer_opt, init_lr, budget, swa_start, swa_lr):\n",
"        self.layer_opt = layer_opt\n",
"        self.init_lr = init_lr\n",
"        self.budget = budget\n",
"        self.swa_start = swa_start\n",
"        self.swa_lr = swa_lr\n",
"\n",
"    def on_train_begin(self):\n",
"        self.epoch = 0\n",
"\n",
"    def on_epoch_end(self, metrics):\n",
"        self.epoch += 1\n",
"        self.update_lr()\n",
"\n",
"    def update_lr(self):\n",
"        new_lr = self.calc_lr()\n",
"        self.layer_opt.set_lrs([new_lr])\n",
"\n",
"    def calc_lr(self):\n",
"        if self.epoch < self.swa_start//2:\n",
"            return self.init_lr\n",
"        elif self.epoch > 0.9 * self.swa_start:\n",
"            return self.swa_lr\n",
"        else:\n",
"            # max(1, ...) avoids ZeroDivisionError when int(0.4 * swa_start) == 0\n",
"            # (swa_start == 2); results for swa_start >= 3 are unchanged.\n",
"            decay_epochs = max(1, int(0.4 * self.swa_start))\n",
"            return self.init_lr - ((self.init_lr - self.swa_lr) / decay_epochs * (self.epoch - self.swa_start//2))"
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# SWA hyper-parameters.\n",
"# swa_lr: rate SwaLrUpdater holds once epoch > 0.9 * swa_start.\n",
"swa_lr = 0.01\n",
"wd = 3e-4\n",
"# 126 = 0.84 * 150; passed to fit_gen as swa_start and also sets the length of the\n",
"# SwaLrUpdater schedule (presumably the first averaged epoch - confirm against fit_gen).\n",
"swa_start = 126"
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"budget = 150\n",
"# Train 3 preresnet110 models with the SWA training schedule.\n",
"# Expects lr, data, wd, swa_start and swa_lr from earlier cells; run i logs to {PATH}logs/swa_{i}.txt.\n",
"for i in range(3):\n",
"    preresnet110 = preresnet(depth=110, num_classes=10)\n",
"    learn = ConvLearner.from_model_data(preresnet110, data)\n",
"    layer_opt = learn.get_layer_opt([lr], [wd])\n",
"    learn.crit = F.cross_entropy\n",
"    learn.fit_gen(\n",
"        learn.model,\n",
"        data,\n",
"        layer_opt,\n",
"        budget,\n",
"        use_swa=True,\n",
"        swa_start=swa_start,\n",
"        swa_eval_freq=1,\n",
"        callbacks=[SwaLrUpdater(layer_opt, lr, budget, swa_start, swa_lr), LoggingCallback(f'{PATH}logs/swa_{i}.txt')]\n",
"    )\n",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 1.25 budgets\n",
"budget = 187\n",
"# Train 3 preresnet110 models with the SWA training schedule and 1.25 budgets\n",
"# (swa_start stays at 126, so averaging covers more epochs); logs to {PATH}logs/swa_187_{i}.txt.\n",
"for i in range(3):\n",
"    preresnet110 = preresnet(depth=110, num_classes=10)\n",
"    learn = ConvLearner.from_model_data(preresnet110, data)\n",
"    layer_opt = learn.get_layer_opt([lr], [wd])\n",
"    learn.crit = F.cross_entropy\n",
"    learn.fit_gen(\n",
"        learn.model,\n",
"        data,\n",
"        layer_opt,\n",
"        budget,\n",
"        use_swa=True,\n",
"        swa_start=swa_start,\n",
"        swa_eval_freq=1,\n",
"        callbacks=[SwaLrUpdater(layer_opt, lr, budget, swa_start, swa_lr), LoggingCallback(f'{PATH}logs/swa_187_{i}.txt')]\n",
"    )\n",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 1.5 budgets\n",
"budget = 225\n",
"# Train 3 preresnet110 models with the SWA training schedule and 1.5 budgets\n",
"# (swa_start stays at 126); logs to {PATH}logs/swa_225_{i}.txt, matching the swa_187_{i} naming.\n",
"for i in range(3):\n",
"    preresnet110 = preresnet(depth=110, num_classes=10)\n",
"    learn = ConvLearner.from_model_data(preresnet110, data)\n",
"    layer_opt = learn.get_layer_opt([lr], [wd])\n",
"    learn.crit = F.cross_entropy\n",
"    learn.fit_gen(\n",
"        learn.model,\n",
"        data,\n",
"        layer_opt,\n",
"        budget,\n",
"        use_swa=True,\n",
"        swa_start=swa_start,\n",
"        swa_eval_freq=1,\n",
"        callbacks=[SwaLrUpdater(layer_opt, lr, budget, swa_start, swa_lr), LoggingCallback(f'{PATH}logs/swa_225_{i}.txt')]\n",
"    )\n",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
"nbformat": 4,
"nbformat_minor": 2
