@Cysu
Last active February 24, 2017 08:44
PyTorch Tutorial
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Simple CNN with PyTorch\n",
"\n",
"In this notebook example, we will walk through how to train a simple CNN to classify images.\n",
"\n",
"We will rely on the following modules from torch and torchvision."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import torch\n",
"from torch import nn\n",
"from torch.autograd import Variable\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets\n",
"from torchvision import transforms"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## 1. Data Loader\n",
"\n",
"The first step is to create a data loader.\n",
"\n",
"A data loader can be treated like a list (technically, it is an iterable). Each iteration yields one minibatch as an (images, labels) pair of tensors."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n"
]
}
],
"source": [
"# Choose a dataset -- CIFAR10 for example\n",
"dataset = datasets.CIFAR10(root='data', train=True, download=True)\n",
"\n",
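"# Set how the input images will be transformed: convert them to [0, 1] tensors,\n",
"# then normalize each channel with the (approximate) CIFAR10 mean and standard deviation\n",
"dataset.transform = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean=[0.491, 0.482, 0.447],\n",
"                         std=[0.247, 0.244, 0.262])\n",
"])\n",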
"\n",
"# Create a data loader\n",
"train_loader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## 2. Model\n",
"\n",
"The second step is to define our model.\n",
"\n",
"We will use a simple CNN: conv(3x3) -> bn -> relu -> pool(4x4) -> fc, i.e. a convolution, batch normalization, ReLU, max pooling, and a fully connected layer.\n",
"\n",
"In PyTorch, a model is defined as a subclass of nn.Module that implements two methods:\n",
"\n",
" - `__init__`: the constructor. Create the layers here. Note that we don't define the connections between layers in this method.\n",
"\n",
" - `forward(x)`: the forward function. It receives an input `x` and returns the output. This is where the layers are actually connected, dynamically, each time the model is called.\n",
"\n",
"Compared with Caffe, we no longer need to implement the backward function. The computational graph is built implicitly from the operations executed in `forward`, and the gradients are computed automatically."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"class SimpleCNN(nn.Module):\n",
"    def __init__(self):\n",
"        super(SimpleCNN, self).__init__()  # Call the parent class's constructor\n",
"\n",
"        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # 3x32x32 -> 8x32x32\n",
"        self.bn = nn.BatchNorm2d(8)\n",
"        self.relu = nn.ReLU()\n",
"        self.pool = nn.MaxPool2d(kernel_size=4, stride=4)  # 8x32x32 -> 8x8x8\n",
"        self.fc = nn.Linear(8 * 8 * 8, 10)  # 8 channels x 8 x 8 inputs, 10 classes\n",
"\n",
"    def forward(self, x):\n",
"        x = self.conv(x)  # Calling an nn.Module runs its forward() and returns the result\n",
"        x = self.bn(x)\n",
"        x = self.relu(x)\n",
"        x = self.pool(x)\n",
"        x = x.view(x.size(0), -1)  # Reshape from (N, C, H, W) to (N, C*H*W)\n",
"        x = self.fc(x)\n",
"        return x\n",
"\n",
"model = SimpleCNN()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## 3. Loss and Optimizer\n",
"\n",
"The third step is to define the loss function and the optimization algorithm."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01,\n",
"                            momentum=0.9, weight_decay=5e-4)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## 4. Start training\n",
"\n",
"The last step is to run the training loop. Each epoch makes one pass over the whole training set."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1 batch 0/196 loss 2.375\n",
"epoch 1 batch 10/196 loss 2.189\n",
"epoch 1 batch 20/196 loss 1.975\n",
"epoch 1 batch 30/196 loss 1.873\n",
"epoch 1 batch 40/196 loss 1.822\n",
"epoch 1 batch 50/196 loss 1.787\n",
"epoch 1 batch 60/196 loss 1.706\n",
"epoch 1 batch 70/196 loss 1.640\n",
"epoch 1 batch 80/196 loss 1.672\n",
"epoch 1 batch 90/196 loss 1.581\n",
"epoch 1 batch 100/196 loss 1.553\n",
"epoch 1 batch 110/196 loss 1.588\n",
"epoch 1 batch 120/196 loss 1.545\n",
"epoch 1 batch 130/196 loss 1.408\n",
"epoch 1 batch 140/196 loss 1.588\n",
"epoch 1 batch 150/196 loss 1.477\n",
"epoch 1 batch 160/196 loss 1.491\n",
"epoch 1 batch 170/196 loss 1.458\n",
"epoch 1 batch 180/196 loss 1.483\n",
"epoch 1 batch 190/196 loss 1.439\n",
"epoch 2 batch 0/196 loss 1.490\n",
"epoch 2 batch 10/196 loss 1.498\n",
"epoch 2 batch 20/196 loss 1.385\n",
"epoch 2 batch 30/196 loss 1.343\n",
"epoch 2 batch 40/196 loss 1.346\n",
"epoch 2 batch 50/196 loss 1.347\n",
"epoch 2 batch 60/196 loss 1.452\n",
"epoch 2 batch 70/196 loss 1.461\n",
"epoch 2 batch 80/196 loss 1.310\n",
"epoch 2 batch 90/196 loss 1.331\n",
"epoch 2 batch 100/196 loss 1.495\n",
"epoch 2 batch 110/196 loss 1.390\n",
"epoch 2 batch 120/196 loss 1.391\n",
"epoch 2 batch 130/196 loss 1.425\n",
"epoch 2 batch 140/196 loss 1.255\n",
"epoch 2 batch 150/196 loss 1.313\n",
"epoch 2 batch 160/196 loss 1.324\n",
"epoch 2 batch 170/196 loss 1.253\n",
"epoch 2 batch 180/196 loss 1.353\n",
"epoch 2 batch 190/196 loss 1.409\n",
"epoch 3 batch 0/196 loss 1.329\n",
"epoch 3 batch 10/196 loss 1.331\n",
"epoch 3 batch 20/196 loss 1.352\n",
"epoch 3 batch 30/196 loss 1.249\n",
"epoch 3 batch 40/196 loss 1.249\n",
"epoch 3 batch 50/196 loss 1.240\n",
"epoch 3 batch 60/196 loss 1.314\n",
"epoch 3 batch 70/196 loss 1.194\n",
"epoch 3 batch 80/196 loss 1.216\n",
"epoch 3 batch 90/196 loss 1.260\n",
"epoch 3 batch 100/196 loss 1.251\n",
"epoch 3 batch 110/196 loss 1.348\n",
"epoch 3 batch 120/196 loss 1.290\n",
"epoch 3 batch 130/196 loss 1.208\n",
"epoch 3 batch 140/196 loss 1.331\n",
"epoch 3 batch 150/196 loss 1.349\n",
"epoch 3 batch 160/196 loss 1.330\n",
"epoch 3 batch 170/196 loss 1.259\n",
"epoch 3 batch 180/196 loss 1.340\n",
"epoch 3 batch 190/196 loss 1.336\n",
"epoch 4 batch 0/196 loss 1.213\n",
"epoch 4 batch 10/196 loss 1.212\n",
"epoch 4 batch 20/196 loss 1.298\n",
"epoch 4 batch 30/196 loss 1.189\n",
"epoch 4 batch 40/196 loss 1.290\n",
"epoch 4 batch 50/196 loss 1.143\n",
"epoch 4 batch 60/196 loss 1.256\n",
"epoch 4 batch 70/196 loss 1.256\n",
"epoch 4 batch 80/196 loss 1.170\n",
"epoch 4 batch 90/196 loss 1.311\n",
"epoch 4 batch 100/196 loss 1.273\n",
"epoch 4 batch 110/196 loss 1.363\n",
"epoch 4 batch 120/196 loss 1.213\n",
"epoch 4 batch 130/196 loss 1.266\n",
"epoch 4 batch 140/196 loss 1.286\n",
"epoch 4 batch 150/196 loss 1.217\n",
"epoch 4 batch 160/196 loss 1.193\n",
"epoch 4 batch 170/196 loss 1.112\n",
"epoch 4 batch 180/196 loss 1.228\n",
"epoch 4 batch 190/196 loss 1.241\n",
"epoch 5 batch 0/196 loss 1.271\n",
"epoch 5 batch 10/196 loss 1.183\n",
"epoch 5 batch 20/196 loss 1.132\n",
"epoch 5 batch 30/196 loss 1.026\n",
"epoch 5 batch 40/196 loss 1.200\n",
"epoch 5 batch 50/196 loss 1.168\n",
"epoch 5 batch 60/196 loss 1.281\n",
"epoch 5 batch 70/196 loss 1.193\n",
"epoch 5 batch 80/196 loss 1.234\n",
"epoch 5 batch 90/196 loss 1.276\n",
"epoch 5 batch 100/196 loss 1.134\n",
"epoch 5 batch 110/196 loss 1.184\n",
"epoch 5 batch 120/196 loss 1.187\n",
"epoch 5 batch 130/196 loss 1.267\n",
"epoch 5 batch 140/196 loss 1.373\n",
"epoch 5 batch 150/196 loss 1.151\n",
"epoch 5 batch 160/196 loss 1.205\n",
"epoch 5 batch 170/196 loss 1.179\n",
"epoch 5 batch 180/196 loss 1.210\n",
"epoch 5 batch 190/196 loss 1.424\n"
]
}
],
"source": [
"def train(epoch):\n",
"    model.train()  # Put the model in training mode (affects layers such as BatchNorm)\n",
"    for batch_index, (inputs, targets) in enumerate(train_loader):\n",
"        # Since PyTorch 0.4, tensors track gradients directly; no Variable wrapper is needed\n",
"\n",
"        # Forward\n",
"        outputs = model(inputs)\n",
"        loss = criterion(outputs, targets)\n",
"        if batch_index % 10 == 0:\n",
"            print('epoch {} batch {}/{} loss {:.3f}'.format(\n",
"                epoch, batch_index, len(train_loader), loss.item()))\n",
"\n",
"        # Backward\n",
"        optimizer.zero_grad()  # Reset parameter gradients, since backward() accumulates them\n",
"        loss.backward()        # Compute parameter gradients (added to any existing .grad)\n",
"        optimizer.step()       # Update the parameters\n",
"\n",
"for epoch in range(1, 6):\n",
"    train(epoch)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## 5. What's next?\n",
"\n",
"We have sketched a simple framework for training CNNs. A few more features remain to be implemented:\n",
"\n",
" - Use the GPU and cuDNN\n",
" - Run validation after each epoch\n",
" - Adjust the learning rate during training\n",
" - Compute top-k accuracy\n",
" - Average the loss over each epoch\n",
" - Add more data augmentation\n",
" - Save to and resume from checkpoints\n",
"\n",
"Please check the official [examples](https://github.com/pytorch/examples) on MNIST and ImageNet for more details."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
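Section 1 of the notebook above normalizes each channel with the given mean and std. As a rough sketch of that arithmetic, assuming 8-bit RGB pixels (which `ToTensor()` scales from [0, 255] to [0, 1]), the snippet below applies the same per-channel formula `(value - mean) / std` to a single pixel in plain Python. `normalize_pixel` is a hypothetical helper for illustration, not part of torchvision.

```python
# Plain-Python sketch of ToTensor() followed by Normalize(mean, std) on one pixel.
# Assumption: 8-bit RGB input, which ToTensor() scales from [0, 255] to [0.0, 1.0].
MEAN = [0.491, 0.482, 0.447]  # per-channel means used in the notebook
STD = [0.247, 0.244, 0.262]   # per-channel standard deviations used in the notebook


def normalize_pixel(rgb):
    """Hypothetical helper: scale an (R, G, B) pixel to [0, 1], then standardize each channel."""
    scaled = [value / 255.0 for value in rgb]
    return [(v - m) / s for v, m, s in zip(scaled, MEAN, STD)]


print([round(v, 3) for v in normalize_pixel((255, 255, 255))])  # white pixel
print([round(v, 3) for v in normalize_pixel((0, 0, 0))])        # black pixel
```

A pixel equal to the channel means maps to roughly zero, while white and black pixels land near +2 and -2, so the network sees inputs roughly centered around zero.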
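Section 1 also describes the data loader as a list of minibatches. A simplified, hypothetical stand-in for the shuffling and batching it performs (ignoring image loading and collation into tensors) shows where the `196` in the training log comes from: CIFAR10 has 50,000 training images, and with `batch_size=256` the final batch only holds the remaining 80, which `DataLoader` keeps by default.

```python
import math
import random


def minibatches(num_samples, batch_size, shuffle=True, seed=0):
    """Yield lists of sample indices, one list per minibatch.

    Hypothetical simplification of DataLoader's sampling: the real DataLoader
    also loads each sample and collates the batch into tensors.
    """
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # new order each epoch in practice
    for start in range(0, num_samples, batch_size):
        # The last batch is smaller when num_samples is not divisible by batch_size
        yield indices[start:start + batch_size]


batches = list(minibatches(50000, 256))
print(len(batches), math.ceil(50000 / 256))  # 196 196
print(len(batches[-1]))                       # 80
```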
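In Section 2, the fully connected layer expects `64*8` input features. Assuming 32x32 CIFAR10 inputs, the standard output-size formula for convolution and pooling layers (dilation 1, floor rounding, which is the default) confirms that number. `conv_output_size` is a hypothetical helper written only for this check; batch norm and ReLU do not change the shape.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a Conv2d/MaxPool2d layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel_size) // stride + 1


height = width = 32  # CIFAR10 images are 32x32
channels = 8         # out_channels of self.conv

# conv: kernel 3, padding 1, stride 1 keeps the spatial size at 32x32
height = conv_output_size(height, kernel_size=3, padding=1)
width = conv_output_size(width, kernel_size=3, padding=1)

# pool: kernel 4, stride 4 shrinks each side by a factor of 4
height = conv_output_size(height, kernel_size=4, stride=4)
width = conv_output_size(width, kernel_size=4, stride=4)

fc_in_features = channels * height * width
print(height, width, fc_in_features)  # 8 8 512
```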
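Section 3 uses `nn.CrossEntropyLoss`, which combines a log-softmax with the negative log-likelihood loss and, by default, averages over the minibatch. A single-sample version in plain Python, written for illustration only, also suggests why the first logged loss (2.375) is close to `log(10)`: if an untrained 10-class model predicts roughly uniform probabilities, the loss is about 2.303.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target]).

    Subtracting the max logit (log-sum-exp trick) keeps exp() from overflowing.
    """
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[target]


# Equal logits over 10 classes give a uniform distribution, so the loss is log(10)
print(round(cross_entropy([0.0] * 10, target=3), 3))  # 2.303
```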
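The optimizer in Section 3 is SGD with momentum 0.9 and weight decay 5e-4. The sketch below spells out one update on plain Python floats, following the update rule described in the `torch.optim.SGD` documentation with the defaults `dampening=0` and `nesterov=False`. `sgd_step` is a hypothetical helper; the real optimizer updates parameter tensors in place.

```python
def sgd_step(params, grads, buffers, lr=0.01, momentum=0.9, weight_decay=5e-4):
    """Hypothetical helper: one SGD update on lists of floats, modified in place.

    Mirrors the torch.optim.SGD update with dampening=0 and nesterov=False.
    Buffers start at 0.0, so the first step uses the raw gradient, as PyTorch does.
    """
    for i, (p, g) in enumerate(zip(params, grads)):
        g = g + weight_decay * p                # add the L2 weight-decay term
        buffers[i] = momentum * buffers[i] + g  # update the momentum buffer
        params[i] = p - lr * buffers[i]         # take the step


params = [1.0, -2.0]
buffers = [0.0, 0.0]
for _ in range(3):
    sgd_step(params, [0.5, 0.5], buffers)  # pretend the gradient stays constant
print(params, buffers)
```

Because the buffer accumulates past gradients, a roughly constant gradient produces steps that grow toward `lr * g / (1 - momentum)`, about 10 times a plain SGD step when the momentum is 0.9.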
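Among the items listed in Section 5, top-k accuracy is easy to sketch. Assuming the scores of one minibatch are available as nested Python lists (in PyTorch, `Tensor.topk` would typically do the ranking on the output tensor instead), a hypothetical `topk_accuracy` helper counts how often the target class is among the k highest scores.

```python
def topk_accuracy(scores, targets, k=1):
    """Hypothetical helper: fraction of samples whose target is among the top-k scores.

    scores: one list of per-class scores per sample; targets: the true class indices.
    """
    hits = 0
    for sample_scores, target in zip(scores, targets):
        # Class indices sorted from highest to lowest score
        ranked = sorted(range(len(sample_scores)),
                        key=lambda c: sample_scores[c], reverse=True)
        if target in ranked[:k]:
            hits += 1
    return hits / len(targets)


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
targets = [1, 1, 0]
print(topk_accuracy(scores, targets, k=1), topk_accuracy(scores, targets, k=2))
```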
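Section 5 also suggests averaging the loss over each epoch. Because the last CIFAR10 batch is smaller (80 images instead of 256), a running average weighted by batch size is slightly more accurate than averaging the per-batch losses directly. `AverageMeter` below is a hypothetical helper; inside `train()` one might call `meter.update(loss.item(), inputs.size(0))` after each batch.

```python
class AverageMeter:
    """Hypothetical helper: running average of per-batch values, weighted by batch size."""

    def __init__(self):
        self.total = 0.0  # sum of value * batch size
        self.count = 0    # number of samples seen so far

    def update(self, value, n=1):
        # value: a per-batch mean such as loss.item(); n: samples in that batch
        self.total += value * n
        self.count += n

    @property
    def average(self):
        return self.total / self.count if self.count else 0.0


meter = AverageMeter()
meter.update(1.2, n=256)  # a full batch
meter.update(1.0, n=80)   # the smaller final CIFAR10 batch
print(round(meter.average, 4))
```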