@wdhorton
Last active July 31, 2019 02:13
SWA testing: checks for the Stochastic Weight Averaging (use_swa) support in fastai
{
"cells": [
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
"env: CUDA_VISIBLE_DEVICES=1\n"
]
}
],
"source": [
"%env CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
"%env CUDA_VISIBLE_DEVICES=1"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"from fastai.imports import *\n",
"from fastai.sgdr import Callback\n",
"\n",
"from fastai.core import SimpleNet\n",
"from fastai.conv_learner import *"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"PATH = \"data/cifar10/\"\n",
"\n",
"classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))\n",
"\n",
"def get_data(sz,bs):\n",
" tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz//8)\n",
" return ImageClassifierData.from_paths(PATH, trn_name='train_', val_name='test_', tfms=tfms, bs=bs)\n",
"\n",
"bs=128\n",
"data = get_data(32,bs)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"net = SimpleNet([32*32*3, 40, 10])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"learn = ConvLearner.from_model_data(net, data)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"lr = 2e-2"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5cd7c2529f1742b9881c57c89fbd8030",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy swa_loss swa_accuracy \n",
" 0 1.774172 1.649104 0.413074 1.649104 0.413074 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[1.6491035, 0.4130735759493671, 1.6491035, 0.4130735759493671]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.fit(lr, 1, use_swa=True)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SimpleNet(\n",
" (layers): ModuleList(\n",
" (0): Linear(in_features=3072, out_features=40, bias=True)\n",
" (1): Linear(in_features=40, out_features=10, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.swa_model"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable containing:\n",
" 1 1 1 ... 1 1 1\n",
" 1 1 1 ... 1 1 1\n",
" 1 1 1 ... 1 1 1\n",
" ... ⋱ ... \n",
" 1 1 1 ... 1 1 1\n",
" 1 1 1 ... 1 1 1\n",
" 1 1 1 ... 1 1 1\n",
"[torch.cuda.ByteTensor of size 40x3072 (GPU 0)]\n",
"\n",
"Variable containing:\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
"[torch.cuda.ByteTensor of size 40 (GPU 0)]\n",
"\n",
"Variable containing:\n",
"\n",
"Columns 0 to 12 \n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
"\n",
"Columns 13 to 25 \n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
"\n",
"Columns 26 to 38 \n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
"\n",
"Columns 39 to 39 \n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
"[torch.cuda.ByteTensor of size 10x40 (GPU 0)]\n",
"\n",
"Variable containing:\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
" 1\n",
"[torch.cuda.ByteTensor of size 10 (GPU 0)]\n",
"\n"
]
}
],
"source": [
"# verifies that it's equal to the first model's parameters after 1 epoch\n",
"for p1, p2 in zip(learn.model.parameters(), learn.swa_model.parameters()):\n",
" print(p1 == p2)"
]
},
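{
"cell_type": "markdown",
"metadata": {},
"source": [
"The equality check above, and the mean-of-snapshots checks below, both follow from the running average SWA keeps: after each included epoch the SWA weights are updated as w_swa <- (w_swa * n + w) / (n + 1), which after k snapshots equals the plain mean of those k snapshots (so after a single epoch it is simply a copy of the trained weights). Below is a minimal NumPy sketch of that update rule, as an illustration of the expected behaviour rather than fastai's actual callback code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of the SWA running average (illustration only, not fastai's implementation).\n",
"# After k snapshots the running average equals the plain mean of those snapshots,\n",
"# which is what the cells below verify against per-epoch copies of the model parameters.\n",
"import numpy as np\n",
"\n",
"snapshots = [np.random.randn(4, 3) for _ in range(3)]  # stand-ins for per-epoch weights\n",
"\n",
"swa_w, n = np.zeros((4, 3)), 0\n",
"for w in snapshots:\n",
"    swa_w = (swa_w * n + w) / (n + 1)  # incremental average, as in Izmailov et al.\n",
"    n += 1\n",
"\n",
"assert np.allclose(swa_w, np.mean(np.stack(snapshots), axis=0))"
]
},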
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"params = []\n",
"\n",
"class SaveModelParams(Callback):\n",
" def __init__(self, model):\n",
" self.model = model\n",
" \n",
" def on_epoch_end(self, metrics):\n",
" params.append([p.data.cpu().numpy() for p in self.model.parameters()])"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9f28bfe6af4d4783a26a3f8df00ae1be",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy swa_loss swa_accuracy \n",
" 0 1.773514 1.691156 0.404371 1.691156 0.404371 \n",
" 1 1.737232 1.603997 0.432259 \n",
" 2 1.68089 1.644307 0.417227 1.513425 0.463212 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[1.6443068, 0.41722705696202533, 1.513425, 0.4632120253164557]"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net2 = SimpleNet([32*32*3, 40, 10])\n",
"learn2 = ConvLearner.from_model_data(net2, data)\n",
"lr = 2e-2\n",
"learn2.fit(lr, 3, use_swa=True, callbacks=[SaveModelParams(learn2.model)])"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n"
]
}
],
"source": [
"print(len(params))"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"swa_model_params = [p.data.cpu().numpy() for p in learn2.swa_model.parameters()]"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ True True True ... True True True]\n",
" [ True True True ... True True True]\n",
" [ True True True ... True True True]\n",
" ...\n",
" [ True True True ... True True True]\n",
" [ True True True ... True True True]\n",
" [ True True True ... True True True]]\n",
"[ True True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True True\n",
" True True True True]\n",
"[[ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]]\n",
"[ True True True True True True True True True True]\n"
]
}
],
"source": [
"for p_model1, p_model2, p_model3, p_swa_model in zip(*params, swa_model_params):\n",
" # check for equality up to a certain tolerance\n",
" print(np.isclose(p_swa_model, np.mean(np.stack([p_model1, p_model2, p_model3]), axis=0)))"
]
},
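{
"cell_type": "markdown",
"metadata": {},
"source": [
"The element-wise printout above can be collapsed to a single boolean per parameter tensor; the cell below is the same check expressed with np.allclose (it reuses params and swa_model_params from above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Same comparison as above, reduced to one True/False per parameter tensor.\n",
"for p_model1, p_model2, p_model3, p_swa_model in zip(*params, swa_model_params):\n",
"    expected = np.mean(np.stack([p_model1, p_model2, p_model3]), axis=0)\n",
"    print(np.allclose(p_swa_model, expected))"
]
},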
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"learn.save('test')\n",
"learn.load('test')"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1.70664 -1.05492 -3.87068 ... -3.25529 -2.24639 -1.74912]\n",
" [-1.96794 -4.14807 -2.40735 ... -1.167 -4.95536 -3.72812]\n",
" [-0.61625 -4.61556 -1.19078 ... -6.56123 -2.47553 -4.84968]\n",
" ...\n",
" [-3.75604 -2.18386 -3.12696 ... -2.74888 -4.26681 -0.53816]\n",
" [-2.22326 -1.58195 -4.6265 ... -2.21887 -4.00163 -1.10925]\n",
" [-2.19951 -0.71748 -5.376 ... -5.23997 -1.85072 -1.54509]] [[-1.28266 -1.90767 -3.08345 ... -2.65925 -1.99473 -1.40246]\n",
" [-1.29219 -5.05158 -2.04013 ... -1.32583 -4.00925 -3.2551 ]\n",
" [-0.4586 -5.46795 -1.64899 ... -5.80076 -2.18757 -5.69979]\n",
" ...\n",
" [-3.23602 -1.76301 -3.36458 ... -2.50504 -3.53825 -0.70707]\n",
" [-1.65576 -2.66782 -3.46807 ... -1.68105 -3.06506 -1.33282]\n",
" [-1.79727 -1.17396 -4.66499 ... -4.85271 -1.85708 -1.09175]]\n"
]
}
],
"source": [
"preds = learn2.predict()\n",
"preds_swa = learn2.predict(use_swa=True)\n",
"print(preds, preds_swa)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"params = []\n",
"\n",
"class SaveModelParams(Callback):\n",
" def __init__(self, model):\n",
" self.model = model\n",
" \n",
" def on_epoch_end(self, metrics):\n",
" params.append([p.data.cpu().numpy() for p in self.model.parameters()])"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f2d4cdecf5b14522b34de3036af7fbe4",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=6), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy swa_loss swa_accuracy \n",
" 0 1.771523 1.649534 0.413074 \n",
" 1 1.722752 1.678871 0.400218 \n",
" 2 1.698014 1.648443 0.416337 1.648443 0.416337 \n",
" 3 1.691304 1.582978 0.437302 \n",
" 4 1.706535 1.555867 0.439082 \n",
" 5 1.656273 1.628426 0.435423 1.449969 0.484474 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[1.6284263, 0.4354232594936709, 1.4499689, 0.4844738924050633]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net = SimpleNet([32*32*3, 40, 10])\n",
"learn = ConvLearner.from_model_data(net, data)\n",
"lr = 2e-2\n",
"learn.fit(lr, 6, use_swa=True, swa_start=3, callbacks=[SaveModelParams(learn.model)])"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"swa_model_params = [p.data.cpu().numpy() for p in learn.swa_model.parameters()]"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"6\n"
]
}
],
"source": [
"print(len(params))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ True True True ... True True True]\n",
" [ True True True ... True True True]\n",
" [ True True True ... True True True]\n",
" ...\n",
" [ True True True ... True True True]\n",
" [ True True True ... True True True]\n",
" [ True True True ... True True True]]\n",
"[ True True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True True\n",
" True True True True]\n",
"[[ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]\n",
" [ True True True True True True True True True True True True True True True True True\n",
" True True True True True True True True True True True True True True True True True\n",
" True True True True True True]]\n",
"[ True True True True True True True True True True]\n"
]
}
],
"source": [
"for *p_models, p_swa_model in zip(*params[2:], swa_model_params):\n",
" # check for equality up to a certain tolerance\n",
" print(np.isclose(p_swa_model, np.mean(np.stack(p_models), axis=0)))"
]
},
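{
"cell_type": "markdown",
"metadata": {},
"source": [
"With swa_start=3 the averaging only begins at the third epoch (the output above shows the first swa metrics at epoch index 2), so only the last four of the six saved snapshots enter the mean, which is why the check above slices params[2:]. Below is a small sanity check of that slicing, assuming swa_start counts epochs from 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# With swa_start = 3 and 6 epochs, snapshots from epoch index 2 onwards are averaged,\n",
"# matching the params[2:] slice in the cell above. (Assumes swa_start is 1-based.)\n",
"swa_start, n_epochs = 3, 6\n",
"included = list(range(swa_start - 1, n_epochs))\n",
"print(included)                       # [2, 3, 4, 5]\n",
"print(len(params[swa_start - 1:]))    # 4 snapshots enter the average"
]
},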
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SimpleNet(\n",
" (layers): ModuleList(\n",
" (0): Linear(in_features=3072, out_features=40, bias=True)\n",
" (1): Linear(in_features=40, out_features=10, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net = SimpleNet([32*32*3, 40, 10])\n",
"net"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"36\n"
]
}
],
"source": [
"from fastai.swa import collect_bn_modules\n",
"\n",
"net_bn = []\n",
"net.apply(lambda m: collect_bn_modules(m, net_bn))\n",
"print(len(net_bn))\n",
"\n",
"resnet_bn = []\n",
"resnet34().apply(lambda m: collect_bn_modules(m, resnet_bn))\n",
"print(len(resnet_bn))"
]
},
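{
"cell_type": "markdown",
"metadata": {},
"source": [
"fix_batchnorm exists because averaged weights invalidate the BatchNorm running statistics, so the SWA paper recomputes them with a single forward pass over the training data. Below is a rough sketch of that recomputation for a plain PyTorch model and dataloader; it is an assumption-laden illustration, not fastai's actual fix_batchnorm source, and it restores each module's momentum afterwards, which is the property tested further down."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough sketch of the BN-statistics refresh described in the SWA paper (illustration only;\n",
"# see fastai.swa.fix_batchnorm for the real implementation). Assumes a recent PyTorch where\n",
"# BatchNorm supports reset_running_stats() and momentum=None (cumulative moving average).\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"def recompute_bn_stats(model, train_dl):\n",
"    bn_modules = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]\n",
"    if not bn_modules:\n",
"        return\n",
"    momenta = [m.momentum for m in bn_modules]\n",
"    for m in bn_modules:\n",
"        m.reset_running_stats()    # zero the running mean/var\n",
"        m.momentum = None          # cumulative moving average over the pass below\n",
"    model.train()\n",
"    with torch.no_grad():\n",
"        for x, _ in train_dl:      # one pass over the training data refreshes the stats\n",
"            model(x.cuda())\n",
"    for m, momentum in zip(bn_modules, momenta):\n",
"        m.momentum = momentum      # restore, which is what test_momentum_preserved checks"
]
},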
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import absolute_import\n",
"\n",
"'''Resnet for cifar dataset.\n",
"Ported form\n",
"https://github.com/facebook/fb.resnet.torch\n",
"and\n",
"https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py\n",
"(c) YANG, Wei\n",
"'''\n",
"import torch.nn as nn\n",
"import math\n",
"\n",
"\n",
"__all__ = ['preresnet']\n",
"\n",
"def conv3x3(in_planes, out_planes, stride=1):\n",
" \"3x3 convolution with padding\"\n",
" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
" padding=1, bias=False)\n",
"\n",
"\n",
"class BasicBlock(nn.Module):\n",
" expansion = 1\n",
"\n",
" def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
" super(BasicBlock, self).__init__()\n",
" self.bn1 = nn.BatchNorm2d(inplanes)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.conv1 = conv3x3(inplanes, planes, stride)\n",
" self.bn2 = nn.BatchNorm2d(planes)\n",
" self.conv2 = conv3x3(planes, planes)\n",
" self.downsample = downsample\n",
" self.stride = stride\n",
"\n",
" def forward(self, x):\n",
" residual = x\n",
"\n",
" out = self.bn1(x)\n",
" out = self.relu(out)\n",
" out = self.conv1(out)\n",
"\n",
" out = self.bn2(out)\n",
" out = self.relu(out)\n",
" out = self.conv2(out)\n",
"\n",
" if self.downsample is not None:\n",
" residual = self.downsample(x)\n",
"\n",
" out += residual\n",
"\n",
" return out\n",
"\n",
"\n",
"class Bottleneck(nn.Module):\n",
" expansion = 4\n",
"\n",
" def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
" super(Bottleneck, self).__init__()\n",
" self.bn1 = nn.BatchNorm2d(inplanes)\n",
" self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n",
" self.bn2 = nn.BatchNorm2d(planes)\n",
" self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n",
" padding=1, bias=False)\n",
" self.bn3 = nn.BatchNorm2d(planes)\n",
" self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.downsample = downsample\n",
" self.stride = stride\n",
"\n",
" def forward(self, x):\n",
" residual = x\n",
"\n",
" out = self.bn1(x)\n",
" out = self.relu(out)\n",
" out = self.conv1(out)\n",
"\n",
" out = self.bn2(out)\n",
" out = self.relu(out)\n",
" out = self.conv2(out)\n",
"\n",
" out = self.bn3(out)\n",
" out = self.relu(out)\n",
" out = self.conv3(out)\n",
"\n",
" if self.downsample is not None:\n",
" residual = self.downsample(x)\n",
"\n",
" out += residual\n",
"\n",
" return out\n",
"\n",
"\n",
"class PreResNet(nn.Module):\n",
"\n",
" def __init__(self, depth, num_classes=1000):\n",
" super(PreResNet, self).__init__()\n",
" # Model type specifies number of layers for CIFAR-10 model\n",
" assert (depth - 2) % 6 == 0, 'depth should be 6n+2'\n",
" n = (depth - 2) // 6\n",
"\n",
" block = Bottleneck if depth >=44 else BasicBlock\n",
"\n",
" self.inplanes = 16\n",
" self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,\n",
" bias=False)\n",
" self.layer1 = self._make_layer(block, 16, n)\n",
" self.layer2 = self._make_layer(block, 32, n, stride=2)\n",
" self.layer3 = self._make_layer(block, 64, n, stride=2)\n",
" self.bn = nn.BatchNorm2d(64 * block.expansion)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.avgpool = nn.AvgPool2d(8)\n",
" self.fc = nn.Linear(64 * block.expansion, num_classes)\n",
"\n",
" for m in self.modules():\n",
" if isinstance(m, nn.Conv2d):\n",
" n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
" m.weight.data.normal_(0, math.sqrt(2. / n))\n",
" elif isinstance(m, nn.BatchNorm2d):\n",
" m.weight.data.fill_(1)\n",
" m.bias.data.zero_()\n",
"\n",
" def _make_layer(self, block, planes, blocks, stride=1):\n",
" downsample = None\n",
" if stride != 1 or self.inplanes != planes * block.expansion:\n",
" downsample = nn.Sequential(\n",
" nn.Conv2d(self.inplanes, planes * block.expansion,\n",
" kernel_size=1, stride=stride, bias=False),\n",
" )\n",
"\n",
" layers = []\n",
" layers.append(block(self.inplanes, planes, stride, downsample))\n",
" self.inplanes = planes * block.expansion\n",
" for i in range(1, blocks):\n",
" layers.append(block(self.inplanes, planes))\n",
"\n",
" return nn.Sequential(*layers)\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
"\n",
" x = self.layer1(x) # 32x32\n",
" x = self.layer2(x) # 16x16\n",
" x = self.layer3(x) # 8x8\n",
" x = self.bn(x)\n",
" x = self.relu(x)\n",
"\n",
" x = self.avgpool(x)\n",
" x = x.view(x.size(0), -1)\n",
" x = self.fc(x)\n",
"\n",
" return x\n",
"\n",
"\n",
"def preresnet(**kwargs):\n",
" \"\"\"\n",
" Constructs a ResNet model.\n",
" \"\"\"\n",
" return PreResNet(**kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"from fastai.swa import fix_batchnorm, collect_bn_modules\n",
"\n",
"def test_momentum_preserved(model):\n",
" bn_modules = []\n",
" model.apply(lambda module: collect_bn_modules(module, bn_modules))\n",
" momenta_before = [m.momentum for m in bn_modules]\n",
" fix_batchnorm(preresnet110, data.trn_dl)\n",
" \n",
" for module, momentum_before in zip(bn_modules, momenta_before):\n",
" assert module.momentum == momentum_before"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"model = preresnet110 = preresnet(depth=110, num_classes=10).cuda()\n",
"test_momentum_preserved(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}