Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save sean-adler/467321a23805736f8dc06a19157b0567 to your computer and use it in GitHub Desktop.
Save sean-adler/467321a23805736f8dc06a19157b0567 to your computer and use it in GitHub Desktop.
PyTorch ignite parameter scheduling
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create MNIST data loaders and a simple CNN"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import typing as t\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torch import nn\n",
"from torch.optim import SGD\n",
"from torch.utils.data import DataLoader\n",
"from torchvision.transforms import Compose, ToTensor, Normalize\n",
"from torchvision.datasets import MNIST\n",
"\n",
"from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n",
"from ignite.metrics import CategoricalAccuracy, Loss\n",
"from ignite.handlers.param_scheduler import CosineAnnealingScheduler, LinearScheduler"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def get_data_loaders(train_batch_size: int,\n",
"                     val_batch_size: int,\n",
"                     root: str = \".\") -> t.Tuple[DataLoader, DataLoader]:\n",
"    \"\"\"\n",
"    Build the MNIST training and validation data loaders.\n",
"\n",
"    Args:\n",
"        train_batch_size: batch size of the (shuffled) training loader.\n",
"        val_batch_size: batch size of the (unshuffled) validation loader.\n",
"        root: directory MNIST is downloaded to / read from. Defaults to the\n",
"            current working directory, matching the previous hard-coded value.\n",
"\n",
"    Returns:\n",
"        `(train_loader, val_loader)`.\n",
"    \"\"\"\n",
"    # Normalize(mean, std); 0.1307 / 0.3081 are presumably the MNIST\n",
"    # training-set pixel mean / std.\n",
"    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])\n",
"\n",
"    train_loader = DataLoader(MNIST(download=True,\n",
"                                    root=root,\n",
"                                    transform=data_transform,\n",
"                                    train=True),\n",
"                              batch_size=train_batch_size,\n",
"                              shuffle=True)\n",
"\n",
"    # download=False: relies on the call above having populated `root`\n",
"    # (torchvision's MNIST download is expected to fetch both splits --\n",
"    # confirm for your torchvision version if this ever fails).\n",
"    val_loader = DataLoader(MNIST(download=False,\n",
"                                  root=root,\n",
"                                  transform=data_transform,\n",
"                                  train=False),\n",
"                            batch_size=val_batch_size,\n",
"                            shuffle=False)\n",
"\n",
"    return (train_loader, val_loader)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class CNN(nn.Module):\n",
"    \"\"\"\n",
"    Small two-convolution CNN for 10-class MNIST classification.\n",
"\n",
"    `forward` returns log-probabilities over the 10 classes, which is the\n",
"    input `F.nll_loss` (used by the trainer and evaluator) expects.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(CNN, self).__init__()\n",
"        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
"        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
"        self.conv2_drop = nn.Dropout2d()\n",
"        self.fc1 = nn.Linear(320, 50)\n",
"        self.fc2 = nn.Linear(50, 10)\n",
"\n",
"    def forward(self, x):\n",
"        # assumes x is (batch, 1, 28, 28): two conv(k=5) + max_pool(2) stages\n",
"        # leave 20 x 4 x 4 = 320 features per sample, matching fc1's input.\n",
"        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
"        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
"        x = x.view(-1, 320)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.dropout(x, training=self.training)\n",
"        x = self.fc2(x)\n",
"        # Normalize over the class dimension explicitly; calling log_softmax\n",
"        # without `dim` is deprecated and emits a UserWarning on every run.\n",
"        return F.log_softmax(x, dim=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set up training loop"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def train(epochs: int,\n",
"          train_batch_size: int,\n",
"          val_batch_size: int,\n",
"          lr: float = 1e-2,\n",
"          momentum: float = 0.5,\n",
"          log_interval: int = 50,\n",
"          random_seed: int = 42,\n",
"          handlers: t.Sequence[t.Tuple] = ()\n",
"          ) -> t.Tuple[nn.Module, t.Any]:\n",
"    \"\"\"\n",
"    Instantiates and trains a CNN on MNIST.\n",
"\n",
"    Args:\n",
"        epochs: number of passes over the training set.\n",
"        train_batch_size: training batch size.\n",
"        val_batch_size: validation batch size.\n",
"        lr: initial SGD learning rate.\n",
"        momentum: initial SGD momentum.\n",
"        log_interval: print the training loss every `log_interval` batches\n",
"            within an epoch.\n",
"        random_seed: seed applied to both torch and numpy.\n",
"        handlers: parameter-scheduler specs, each a tuple\n",
"            `(scheduler_cls, param_name, start_value, end_value, cycle_mult)`.\n",
"            Each scheduler uses one epoch (`len(train_loader)` batches) as its\n",
"            cycle size, is stepped after every iteration, and records its\n",
"            values (`save_history=True`).\n",
"\n",
"    Returns:\n",
"        `(model, trainer.state)`. The state's `param_history` holds the\n",
"        recorded scheduler values that `plot_metric` reads.\n",
"    \"\"\"\n",
"    torch.manual_seed(random_seed)\n",
"    np.random.seed(random_seed)\n",
"\n",
"    model = CNN()\n",
"    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)\n",
"    device = 'cpu'\n",
"\n",
"    if torch.cuda.is_available():\n",
"        model = model.cuda()\n",
"        device = 'cuda'\n",
"\n",
"    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)\n",
"    trainer = create_supervised_trainer(\n",
"        model,\n",
"        optimizer,\n",
"        F.nll_loss,\n",
"        device=device)\n",
"    evaluator = create_supervised_evaluator(\n",
"        model,\n",
"        metrics={'accuracy': CategoricalAccuracy(), 'nll': Loss(F.nll_loss)},\n",
"        device=device)\n",
"\n",
"    @trainer.on(Events.ITERATION_COMPLETED)\n",
"    def log_training_loss(engine):\n",
"        # 1-based batch index within the current epoch.\n",
"        i = (engine.state.iteration - 1) % len(train_loader) + 1\n",
"        if i % log_interval == 0:\n",
"            print(f\"[{engine.state.epoch}] {i}/{len(train_loader)} loss: {engine.state.output:.2f}\")\n",
"\n",
"    # Attach scheduler(s)\n",
"    for handler_args in handlers:\n",
"        (scheduler_cls, param_name, start_value, end_value, cycle_mult) = handler_args\n",
"        handler = scheduler_cls(\n",
"            optimizer, param_name, start_value, end_value, len(train_loader),\n",
"            cycle_mult=cycle_mult, save_history=True)\n",
"        trainer.add_event_handler(Events.ITERATION_COMPLETED, handler)\n",
"\n",
"    @trainer.on(Events.EPOCH_COMPLETED)\n",
"    def log_validation_results(engine):\n",
"        evaluator.run(val_loader)\n",
"        metrics = evaluator.state.metrics\n",
"        avg_accuracy = metrics['accuracy']\n",
"        avg_nll = metrics['nll']\n",
"        print(\"Validation Accuracy: {:.2f} Loss: {:.2f}\\n\".format(avg_accuracy, avg_nll))\n",
"\n",
"    trainer.run(train_loader, max_epochs=epochs)\n",
"\n",
"    return (model, trainer.state)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train the model under different parameter schedules"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"%matplotlib inline\n",
"# Figure size for `plot_metric`. It is passed explicitly to each figure\n",
"# because the rcParams value set in this cell did not take effect: the\n",
"# recorded plot outputs report 432x288 px, not a 13x7-inch figure.\n",
"FIGSIZE = (13, 7)\n",
"plt.rcParams['figure.figsize'] = FIGSIZE\n",
"plt.rcParams['image.interpolation'] = 'nearest'\n",
"plt.rcParams['image.cmap'] = 'gray'\n",
"\n",
"\n",
"def plot_metric(state, field, label, figsize=FIGSIZE):\n",
"    \"\"\"\n",
"    Plot the per-batch history of one scheduled optimizer parameter.\n",
"\n",
"    Args:\n",
"        state: engine state returned by `train`; its `param_history` mapping\n",
"            is indexed by parameter name.\n",
"        field: parameter name to plot, e.g. 'lr' or 'momentum'.\n",
"        label: y-axis label, also used in the title.\n",
"        figsize: figure size in inches.\n",
"    \"\"\"\n",
"    fig, ax = plt.subplots(figsize=figsize)\n",
"    ax.plot(state.param_history[field])\n",
"    ax.set_xlabel('batch')\n",
"    ax.set_ylabel(label)\n",
"    ax.set_title(f\"{label} schedule\")\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/Cellar/python3/3.6.2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/ipykernel_launcher.py:17: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1] 250/938 loss: 1.59\n",
"[1] 500/938 loss: 0.88\n",
"[1] 750/938 loss: 0.57\n",
"Validation Accuracy: 0.94 Loss: 0.23\n",
"\n",
"[2] 250/938 loss: 0.55\n",
"[2] 500/938 loss: 0.27\n",
"[2] 750/938 loss: 0.75\n",
"Validation Accuracy: 0.96 Loss: 0.14\n",
"\n",
"[3] 250/938 loss: 0.24\n",
"[3] 500/938 loss: 0.18\n",
"[3] 750/938 loss: 0.18\n",
"Validation Accuracy: 0.97 Loss: 0.10\n",
"\n",
"[4] 250/938 loss: 0.16\n",
"[4] 500/938 loss: 0.39\n",
"[4] 750/938 loss: 0.29\n",
"Validation Accuracy: 0.97 Loss: 0.09\n",
"\n",
"[5] 250/938 loss: 0.21\n",
"[5] 500/938 loss: 0.27\n",
"[5] 750/938 loss: 0.25\n",
"Validation Accuracy: 0.97 Loss: 0.08\n",
"\n",
"[6] 250/938 loss: 0.35\n",
"[6] 500/938 loss: 0.36\n",
"[6] 750/938 loss: 0.30\n",
"Validation Accuracy: 0.98 Loss: 0.08\n",
"\n",
"[7] 250/938 loss: 0.10\n",
"[7] 500/938 loss: 0.21\n",
"[7] 750/938 loss: 0.18\n",
"Validation Accuracy: 0.98 Loss: 0.07\n",
"\n"
]
}
],
"source": [
"# No parameter scheduling: plain SGD with a fixed lr and momentum.\n",
"lr, momentum = 1e-3, 0.95\n",
"\n",
"model_1, state_1 = train(epochs=7, train_batch_size=64, val_batch_size=1000,\n",
"                         lr=lr, momentum=momentum, log_interval=250)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/Cellar/python3/3.6.2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/ipykernel_launcher.py:17: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1] 250/938 loss: 0.94\n",
"[1] 500/938 loss: 0.57\n",
"[1] 750/938 loss: 0.21\n",
"Validation Accuracy: 0.96 Loss: 0.12\n",
"\n",
"[2] 250/938 loss: 0.51\n",
"[2] 500/938 loss: 0.35\n",
"[2] 750/938 loss: 0.48\n",
"Validation Accuracy: 0.97 Loss: 0.08\n",
"\n",
"[3] 250/938 loss: 0.28\n",
"[3] 500/938 loss: 0.25\n",
"[3] 750/938 loss: 0.13\n",
"Validation Accuracy: 0.98 Loss: 0.07\n",
"\n",
"[4] 250/938 loss: 0.23\n",
"[4] 500/938 loss: 0.47\n",
"[4] 750/938 loss: 0.18\n",
"Validation Accuracy: 0.98 Loss: 0.06\n",
"\n",
"[5] 250/938 loss: 0.13\n",
"[5] 500/938 loss: 0.28\n",
"[5] 750/938 loss: 0.13\n",
"Validation Accuracy: 0.98 Loss: 0.05\n",
"\n",
"[6] 250/938 loss: 0.30\n",
"[6] 500/938 loss: 0.25\n",
"[6] 750/938 loss: 0.19\n",
"Validation Accuracy: 0.99 Loss: 0.05\n",
"\n",
"[7] 250/938 loss: 0.14\n",
"[7] 500/938 loss: 0.19\n",
"[7] 750/938 loss: 0.20\n",
"Validation Accuracy: 0.99 Loss: 0.04\n",
"\n"
]
}
],
"source": [
"# Linearly cycle LR up and down.\n",
"lr, momentum = 1e-3, 0\n",
"\n",
"# (scheduler_cls, param_name, start_value, end_value, cycle_mult)\n",
"handlers = [(LinearScheduler, 'lr', lr, lr * 100, 1)]\n",
"\n",
"model_2, state_2 = train(epochs=7, train_batch_size=64, val_batch_size=1000,\n",
"                         lr=lr, momentum=momentum, log_interval=250,\n",
"                         handlers=handlers)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_metric(state_2, field='lr', label='learning rate')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/Cellar/python3/3.6.2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/ipykernel_launcher.py:17: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1] 250/938 loss: 0.96\n",
"[1] 500/938 loss: 0.55\n",
"[1] 750/938 loss: 0.22\n",
"Validation Accuracy: 0.96 Loss: 0.12\n",
"\n",
"[2] 250/938 loss: 0.49\n",
"[2] 500/938 loss: 0.25\n",
"[2] 750/938 loss: 0.59\n",
"Validation Accuracy: 0.97 Loss: 0.10\n",
"\n",
"[3] 250/938 loss: 0.28\n",
"[3] 500/938 loss: 0.21\n",
"[3] 750/938 loss: 0.17\n",
"Validation Accuracy: 0.98 Loss: 0.07\n",
"\n",
"[4] 250/938 loss: 0.22\n",
"[4] 500/938 loss: 0.42\n",
"[4] 750/938 loss: 0.13\n",
"Validation Accuracy: 0.98 Loss: 0.07\n",
"\n",
"[5] 250/938 loss: 0.14\n",
"[5] 500/938 loss: 0.36\n",
"[5] 750/938 loss: 0.14\n",
"Validation Accuracy: 0.98 Loss: 0.06\n",
"\n",
"[6] 250/938 loss: 0.23\n",
"[6] 500/938 loss: 0.28\n",
"[6] 750/938 loss: 0.15\n",
"Validation Accuracy: 0.98 Loss: 0.05\n",
"\n",
"[7] 250/938 loss: 0.19\n",
"[7] 500/938 loss: 0.15\n",
"[7] 750/938 loss: 0.26\n",
"Validation Accuracy: 0.99 Loss: 0.04\n",
"\n"
]
}
],
"source": [
"# Linearly cycle LR up and down, with an increasing cycle length.\n",
"lr, momentum = 1e-3, 0\n",
"\n",
"# (scheduler_cls, param_name, start_value, end_value, cycle_mult)\n",
"handlers = [(LinearScheduler, 'lr', lr, lr * 100, 2)]\n",
"\n",
"model_3, state_3 = train(epochs=7, train_batch_size=64, val_batch_size=1000,\n",
"                         lr=lr, momentum=momentum, log_interval=250,\n",
"                         handlers=handlers)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_metric(state_3, field='lr', label='learning rate')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/Cellar/python3/3.6.2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/ipykernel_launcher.py:17: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1] 250/938 loss: 0.73\n",
"[1] 500/938 loss: 0.47\n",
"[1] 750/938 loss: 0.34\n",
"Validation Accuracy: 0.97 Loss: 0.10\n",
"\n",
"[2] 250/938 loss: 0.43\n",
"[2] 500/938 loss: 0.47\n",
"[2] 750/938 loss: 0.68\n",
"Validation Accuracy: 0.96 Loss: 0.13\n",
"\n",
"[3] 250/938 loss: 0.26\n",
"[3] 500/938 loss: 0.19\n",
"[3] 750/938 loss: 0.12\n",
"Validation Accuracy: 0.98 Loss: 0.06\n",
"\n",
"[4] 250/938 loss: 0.24\n",
"[4] 500/938 loss: 0.31\n",
"[4] 750/938 loss: 0.24\n",
"Validation Accuracy: 0.98 Loss: 0.08\n",
"\n",
"[5] 250/938 loss: 0.32\n",
"[5] 500/938 loss: 0.32\n",
"[5] 750/938 loss: 0.27\n",
"Validation Accuracy: 0.96 Loss: 0.12\n",
"\n",
"[6] 250/938 loss: 0.70\n",
"[6] 500/938 loss: 0.38\n",
"[6] 750/938 loss: 0.44\n",
"Validation Accuracy: 0.97 Loss: 0.09\n",
"\n",
"[7] 250/938 loss: 0.07\n",
"[7] 500/938 loss: 0.26\n",
"[7] 750/938 loss: 0.44\n",
"Validation Accuracy: 0.98 Loss: 0.06\n",
"\n"
]
}
],
"source": [
"# Linearly cycle LR and momentum up and down, with an increasing cycle length.\n",
"# Momentum is scheduled from 0.95 toward 0.75 with the same cycle settings as\n",
"# the LR, so it falls while the LR rises.\n",
"lr, momentum = 1e-3, 0.95\n",
"\n",
"# (scheduler_cls, param_name, start_value, end_value, cycle_mult)\n",
"handlers = [\n",
"    (LinearScheduler, 'lr', lr, lr * 100, 2),\n",
"    (LinearScheduler, 'momentum', momentum, momentum - 0.2, 2),\n",
"]\n",
"\n",
"model_4, state_4 = train(epochs=7, train_batch_size=64, val_batch_size=1000,\n",
"                         lr=lr, momentum=momentum, log_interval=250,\n",
"                         handlers=handlers)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for field, label in [('lr', 'learning rate'), ('momentum', 'momentum')]:\n",
"    plot_metric(state_4, field, label)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/Cellar/python3/3.6.2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/ipykernel_launcher.py:17: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1] 250/938 loss: 0.67\n",
"[1] 500/938 loss: 0.48\n",
"[1] 750/938 loss: 0.22\n",
"Validation Accuracy: 0.96 Loss: 0.11\n",
"\n",
"[2] 250/938 loss: 0.46\n",
"[2] 500/938 loss: 0.26\n",
"[2] 750/938 loss: 0.43\n",
"Validation Accuracy: 0.97 Loss: 0.08\n",
"\n",
"[3] 250/938 loss: 0.13\n",
"[3] 500/938 loss: 0.21\n",
"[3] 750/938 loss: 0.07\n",
"Validation Accuracy: 0.98 Loss: 0.06\n",
"\n",
"[4] 250/938 loss: 0.22\n",
"[4] 500/938 loss: 0.37\n",
"[4] 750/938 loss: 0.19\n",
"Validation Accuracy: 0.98 Loss: 0.06\n",
"\n",
"[5] 250/938 loss: 0.13\n",
"[5] 500/938 loss: 0.29\n",
"[5] 750/938 loss: 0.14\n",
"Validation Accuracy: 0.98 Loss: 0.05\n",
"\n",
"[6] 250/938 loss: 0.18\n",
"[6] 500/938 loss: 0.24\n",
"[6] 750/938 loss: 0.05\n",
"Validation Accuracy: 0.99 Loss: 0.05\n",
"\n",
"[7] 250/938 loss: 0.20\n",
"[7] 500/938 loss: 0.12\n",
"[7] 750/938 loss: 0.24\n",
"Validation Accuracy: 0.99 Loss: 0.04\n",
"\n"
]
}
],
"source": [
"# Cosine-shaped LR cycles with restarts. ignite's CosineAnnealingScheduler\n",
"# anneals start_value -> end_value within each cycle, so here the LR rises\n",
"# from lr to lr * 100 and then resets to lr; cycle_mult=2 makes each cycle\n",
"# longer than the last. NOTE: this is the mirror image of SGDR-style warm\n",
"# restarts (decay from a high LR, restart high); for that, swap the start and\n",
"# end values and start SGD at lr * 100.\n",
"lr = 1e-3\n",
"momentum = 0\n",
"\n",
"handlers = [\n",
"    # (scheduler_cls, param_name, start_value, end_value, cycle_mult)\n",
"    (CosineAnnealingScheduler, 'lr', lr, lr * 100, 2)\n",
"]\n",
"\n",
"model_5, state_5 = train(7, 64, 1000, lr=lr, momentum=momentum, log_interval=250, handlers=handlers)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_metric(state_5, field='lr', label='learning rate')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment