Skip to content

Instantly share code, notes, and snippets.

@NTT123
Created July 21, 2018 05:50
Show Gist options
  • Save NTT123/37b65b2139375ca544267e00d2578648 to your computer and use it in GitHub Desktop.
Save NTT123/37b65b2139375ca544267e00d2578648 to your computer and use it in GitHub Desktop.
ResNet.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "ResNet.ipynb",
"version": "0.3.2",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"[View in Colaboratory](https://colab.research.google.com/gist/NTT123/37b65b2139375ca544267e00d2578648/resnet.ipynb)"
]
},
{
"metadata": {
"id": "o9_uKhBn-bqf",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
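"## Deep Residual Learning for Image Recognition\n",
"Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun\n",
"\n",
"This notebook trains a ResNet on CIFAR-10 built from *full pre-activation* residual blocks (BN, ReLU, then convolution), the variant from He et al., *Identity Mappings in Deep Residual Networks* (2016)."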
]
},
{
"metadata": {
"id": "JL9i9zB2k8Hb",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# <img width=\"700\" src=\"\" alt=\"\" />"
]
},
{
"metadata": {
"id": "-uNfp8m3EVsa",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"#@title Load CIFAR-10 dataset\n",
"# http://pytorch.org/\n",
"from os import path\n",
"from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag\n",
"platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())\n",
"\n",
"# CUDA 8.0 wheel when /opt/bin/nvidia-smi exists (GPU runtime), CPU wheel otherwise.\n",
"# NOTE(review): wheel.pep425tags was removed in wheel 0.35 and torch 0.4.0 wheels\n",
"# only exist for old CPython versions, so this install step targets the 2018 image.\n",
"accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'\n",
"\n",
"!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.0-{platform}-linux_x86_64.whl torchvision\n",
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"\n",
"# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 maps them to [-1, 1].\n",
"transform = transforms.Compose(\n",
"    [transforms.ToTensor(),\n",
"     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
"\n",
"testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
"                                       download=True, transform=transform)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=32,\n",
"                                         shuffle=False, num_workers=2)\n",
"\n",
"classes = ('plane', 'car', 'bird', 'cat',\n",
"           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from pylab import rcParams\n",
"rcParams['font.size'] = 14\n",
"# matplotlib >= 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and\n",
"# later releases dropped the old names, so use whichever one is installed.\n",
"plt.style.use('seaborn-white' if 'seaborn-white' in plt.style.available\n",
"              else 'seaborn-v0_8-white')\n",
"rcParams['figure.figsize'] = 8, 4\n",
"\n",
"# Helper for displaying image tensors produced by the normalized loaders.\n",
"\n",
"\n",
"def imshow(img):\n",
"    \"\"\"Draw an image tensor on the current matplotlib axes with the axes hidden.\n",
"\n",
"    Expects a (C, H, W) tensor in the [-1, 1] range produced by the loaders'\n",
"    Normalize((0.5, ...), (0.5, ...)) transform.\n",
"    \"\"\"\n",
"    pixels = (img / 2 + 0.5).numpy()  # unnormalize back to [0, 1]\n",
"    plt.imshow(pixels.transpose(1, 2, 0))  # CHW -> HWC for matplotlib\n",
"    plt.axis(\"off\")\n",
"\n",
"\n",
"# Preview the first test batch. testloader is not shuffled, so this is the same\n",
"# batch on every run.\n",
"dataiter = iter(testloader)\n",
"# next(...) works on every PyTorch version; the iterator's .next() method was\n",
"# removed in later releases.\n",
"images, labels = next(dataiter)\n",
"\n",
"# make_grid lays the batch out 8 images per row; print the labels in matching\n",
"# rows so every image shown gets a label, not just the first 8.\n",
"ncols = 8\n",
"imshow(torchvision.utils.make_grid(images, nrow=ncols))\n",
"for start_idx in range(0, len(labels), ncols):\n",
"    print(' '.join('%5s' % classes[int(j)] for j in labels[start_idx:start_idx + ncols]))\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Sl_bpqCzEe2g",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"from torch import nn"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "MnkbW4AhGRZB",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Use the GPU when one is available; fall back to the CPU so the notebook also\n",
"# runs on the CPU-only runtime that the setup cell already supports.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "pBfMX_uQGn5k",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"## Full pre-activation residual block\n",
"\n",
"class ResBlock(nn.Module):\n",
"    \"\"\"Full pre-activation residual unit: BN -> ReLU -> conv3x3 -> BN -> ReLU -> conv3x3.\n",
"\n",
"    Args:\n",
"        filters: number of output channels.\n",
"        downsampling: if False, input and output both have ``filters`` channels\n",
"            and the shortcut is the identity. If True, the block expects\n",
"            ``filters // 2`` input channels, its first 3x3 conv uses stride 2,\n",
"            and the shortcut is a 2x2 stride-2 conv followed by BN so it\n",
"            matches the residual branch (input height and width must be even).\n",
"\n",
"    The sum of the two branches is returned without a trailing activation.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, filters, downsampling=False):\n",
"        super().__init__()\n",
"        self.downsampling = downsampling\n",
"        in_channels = filters // 2 if downsampling else filters\n",
"        stride = 2 if downsampling else 1\n",
"\n",
"        # Projection shortcut, only needed when the shape changes. It is created\n",
"        # before self.F to keep parameter registration and init order unchanged;\n",
"        # the attribute name `pad` is kept for state_dict compatibility.\n",
"        if downsampling:\n",
"            self.pad = nn.Sequential(\n",
"                nn.Conv2d(in_channels=in_channels, out_channels=filters,\n",
"                          kernel_size=2, stride=2, padding=0, bias=False),\n",
"                nn.BatchNorm2d(filters),\n",
"            )\n",
"\n",
"        residual_layers = [\n",
"            nn.BatchNorm2d(in_channels),\n",
"            nn.ReLU(),\n",
"            nn.Conv2d(in_channels=in_channels, out_channels=filters,\n",
"                      kernel_size=3, stride=stride, padding=1, bias=False),\n",
"            nn.BatchNorm2d(filters),\n",
"            nn.ReLU(),\n",
"            nn.Conv2d(in_channels=filters, out_channels=filters,\n",
"                      kernel_size=3, padding=1, bias=False),\n",
"        ]\n",
"        self.F = nn.Sequential(*residual_layers)\n",
"\n",
"    def forward(self, x):\n",
"        residual = self.F(x)\n",
"        shortcut = self.pad(x) if self.downsampling else x\n",
"        return residual + shortcut"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "JCcxERE7H4Ry",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "r0vHLgCLGGqw",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"class ResNet(nn.Module):\n",
"    \"\"\"Pre-activation ResNet for 3-channel images with 10 output classes.\n",
"\n",
"    Layout: 3x3 conv (3 -> 16 channels) + ReLU, then three stages of\n",
"    ``numResBlock`` ResBlocks with 16, 32 and 64 filters (the first block of the\n",
"    32- and 64-filter stages downsamples by 2), then ReLU, global average\n",
"    pooling and a 64 -> 10 linear layer. ``forward`` returns raw logits.\n",
"\n",
"    Args:\n",
"        numResBlock: residual blocks per stage, must be >= 1. Counting conv1,\n",
"            the two 3x3 convs of every block and fc, the net has\n",
"            6 * numResBlock + 2 weighted layers.\n",
"\n",
"    Input height and width should be divisible by 4 so both downsampling\n",
"    blocks see even sizes. Submodule names (conv1, relu1, res16_1, ...,\n",
"    reluend, avgpool, fc) and their creation order are kept stable so saved\n",
"    state_dicts and seeded initialisation stay compatible.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, numResBlock=2):\n",
"        super().__init__()\n",
"        if numResBlock < 1:\n",
"            # With no blocks the 16-channel stem would feed the 64-input fc layer.\n",
"            raise ValueError(\"numResBlock must be >= 1, got {}\".format(numResBlock))\n",
"        self.F = nn.Sequential()\n",
"\n",
"        self.F.add_module(\"conv1\", nn.Conv2d(in_channels=3, out_channels=16,\n",
"                                             kernel_size=3, padding=1))\n",
"        self.F.add_module(\"relu1\", nn.ReLU())\n",
"        for filters in (16, 32, 64):\n",
"            for i in range(numResBlock):\n",
"                # Every stage after the first opens with a block that doubles\n",
"                # the channel count and halves the spatial resolution.\n",
"                downsampling = i == 0 and filters != 16\n",
"                self.F.add_module(\"res{}_{}\".format(filters, i + 1),\n",
"                                  ResBlock(filters, downsampling=downsampling))\n",
"\n",
"        self.F.add_module(\"reluend\", nn.ReLU())\n",
"        # Global average pooling. Equivalent to AvgPool2d(kernel_size=8) on 32x32\n",
"        # inputs (8x8 maps after two downsamplings) without hard-coding that size.\n",
"        self.F.add_module(\"avgpool\", nn.AdaptiveAvgPool2d(1))\n",
"        self.fc = nn.Linear(64, 10)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map a (batch, 3, H, W) tensor to (batch, 10) class logits.\"\"\"\n",
"        features = self.F(x).view(x.size(0), -1)\n",
"        return self.fc(features)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "nnTasgqjGN3B",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# 4 residual blocks per stage -> 6 * 4 + 2 = 26 weighted layers.\n",
"net = ResNet(4)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "BhdRSo7MGY7p",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"net = net.to(device)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "IV6wxlchWGtw",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# The training loop runs epochs start..nepoch inclusive.\n",
"nepoch, start = 1000, 1"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ZIOoHiWfWoT4",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Cross-entropy on the raw logits returned by ResNet.forward (no softmax layer).\n",
"lossfn = nn.CrossEntropyLoss()"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "NnlVRcfijZMW",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"#@title\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import HTML\n",
"from matplotlib import animation\n",
"from IPython import display\n",
"\n",
"def test(net, num_images=8):\n",
"    \"\"\"Show test images with their true and predicted class names.\n",
"\n",
"    Takes the next batch from a fresh ``testloader`` iterator, keeps the first\n",
"    ``num_images`` images, predicts their classes with ``net`` in eval mode,\n",
"    redraws the current matplotlib figure with the image grid and prints the\n",
"    target / predicted names. The network's previous train/eval mode is\n",
"    restored afterwards, even if an error occurs.\n",
"    \"\"\"\n",
"    was_training = net.training\n",
"    net.eval()\n",
"    try:\n",
"        images, labels = next(iter(testloader))\n",
"        # Keep images and labels in sync so every drawn image gets a label.\n",
"        images, labels = images[:num_images], labels[:num_images]\n",
"        # Inference only: no autograd graph is needed.\n",
"        with torch.no_grad():\n",
"            _, indices = torch.max(net(images.to(device)), dim=1)\n",
"        indices = indices.cpu()\n",
"        # This runs repeatedly from the training loop; clear the figure first so\n",
"        # images do not pile up on the same axes call after call.\n",
"        plt.clf()\n",
"        imshow(torchvision.utils.make_grid(images))\n",
"        display.clear_output(wait=True)\n",
"        display.display(plt.gcf())\n",
"        print(\"target :\", ' '.join('%5s' % classes[int(j)] for j in labels))\n",
"        print(\"predict:\", ' '.join('%5s' % classes[int(j)] for j in indices))\n",
"    finally:\n",
"        net.train(was_training)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "lsAW1oERD-LP",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "c6b778b1-976c-4dcb-e72b-6da2a4830e72"
},
"cell_type": "code",
"source": [
"# Newer torchvision releases keep the raw CIFAR images in `.data`; 0.2.x used `.test_data`.\n",
"(testset.data if hasattr(testset, 'data') else testset.test_data).shape"
],
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(10000, 32, 32, 3)"
]
},
"metadata": {
"tags": []
},
"execution_count": 11
}
]
},
{
"metadata": {
"id": "qrw1byndH89m",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"\n",
"def trainerr(net, batch_size=1000):\n",
"    \"\"\"Print the misclassification rate (in percent) of ``net`` on ``trainset``.\n",
"\n",
"    The whole training set is evaluated in eval mode without building autograd\n",
"    graphs; the network's previous train/eval mode is restored afterwards,\n",
"    even if evaluation raises.\n",
"\n",
"    NOTE(review): ``trainset`` applies the random crop/flip augmentation, so this\n",
"    measures error on augmented images rather than on the raw training set.\n",
"    \"\"\"\n",
"    was_training = net.training\n",
"    net.eval()\n",
"    loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
"                                         shuffle=False, num_workers=0)\n",
"    correct = 0.0\n",
"    try:\n",
"        with torch.no_grad():\n",
"            for images, labels in loader:\n",
"                _, indices = torch.max(net(images.to(device)), dim=1)\n",
"                # float() keeps the count in floating point whatever dtype == returns.\n",
"                correct += (labels == indices.cpu()).float().sum().item()\n",
"    finally:\n",
"        net.train(was_training)\n",
"    # Divide by the real dataset size instead of a hard-coded 50000.\n",
"    print(\"train mis: \", (1 - correct / len(trainset)) * 100)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "tpo7OS_DlRbj",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"\n",
"def testerr(net, batch_size=1000):\n",
"    \"\"\"Print the misclassification rate (in percent) of ``net`` on ``testset``.\n",
"\n",
"    The whole test set is evaluated in eval mode without building autograd\n",
"    graphs; the network's previous train/eval mode is restored afterwards,\n",
"    even if evaluation raises.\n",
"    \"\"\"\n",
"    was_training = net.training\n",
"    net.eval()\n",
"    loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
"                                         shuffle=False, num_workers=0)\n",
"    correct = 0.0\n",
"    try:\n",
"        with torch.no_grad():\n",
"            for images, labels in loader:\n",
"                _, indices = torch.max(net(images.to(device)), dim=1)\n",
"                # float() keeps the count in floating point whatever dtype == returns.\n",
"                correct += (labels == indices.cpu()).float().sum().item()\n",
"    finally:\n",
"        net.train(was_training)\n",
"    # Divide by the real dataset size instead of a hard-coded 10000.\n",
"    print(\"test mis: \", (1 - correct / len(testset)) * 100)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "TA2FyZvnDvy1",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"##testerr(net)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "eQb1Sej6Rr0a",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "1399865b-4c5f-4af0-d7d5-732390d7e6a6"
},
"cell_type": "code",
"source": [
"# Training-time augmentation: random 32x32 crops of the image padded by 4 pixels\n",
"# and random horizontal flips, followed by the same [-1, 1] normalization as the\n",
"# test set.\n",
"trans = transforms.Compose([\n",
"    transforms.RandomCrop(32, padding=4),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
"])\n",
"\n",
"trainset = torchvision.datasets.CIFAR10(\n",
"    root='./data', train=True, download=True, transform=trans)\n",
"\n",
"# NOTE(review): the training loop builds its own batch-256 loader and rebinds\n",
"# `trainloader`, so this batch-32 loader is not the one training iterates over.\n",
"trainloader = torch.utils.data.DataLoader(\n",
"    trainset, batch_size=32, shuffle=True, num_workers=2)\n"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"text": [
"Files already downloaded and verified\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "-LofpOq-WoAp",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# SGD with momentum 0.9 and L2 weight decay 1e-4, starting at lr 0.1.\n",
"optimizer = torch.optim.SGD(\n",
"    net.parameters(),\n",
"    lr=1e-1,\n",
"    momentum=0.9,\n",
"    weight_decay=1e-4,\n",
")\n",
"# Multiply the learning rate by 0.1 every 50 scheduler steps.\n",
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=50)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "mDsKX920WJjR",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Exponential moving average of the training loss, carried across epochs.\n",
"los = None\n",
"# test() previews one batch from this shuffled 8-image loader after every epoch.\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=8,\n",
"                                         shuffle=True, num_workers=2)\n",
"# Built once, outside the epoch loop: iterating a shuffle=True DataLoader draws a\n",
"# fresh permutation each time, so rebuilding it every epoch bought nothing.\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=256,\n",
"                                          shuffle=True, num_workers=1)\n",
"\n",
"for epoch in range(start, nepoch + 1):\n",
"    # PyTorch 0.4 convention (the version installed by the setup cell): step the\n",
"    # scheduler at the start of each epoch. NOTE(review): PyTorch >= 1.1 expects\n",
"    # scheduler.step() after optimizer.step(); move it to the end of the epoch there.\n",
"    scheduler.step()\n",
"    # Read the rate actually in use from the optimizer; scheduler.get_lr() is not\n",
"    # meant to be called outside step() in newer PyTorch releases.\n",
"    print(scheduler.last_epoch, \" lr \", [group['lr'] for group in optimizer.param_groups])\n",
"    for X, targetY in trainloader:\n",
"        y = net(X.to(device))\n",
"        loss = lossfn(y, targetY.to(device))\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        if los is None:\n",
"            los = loss.item()\n",
"        else:\n",
"            los = los * 0.99 + (1 - 0.99) * loss.item()\n",
"    test(net)\n",
"    print(\"epoch: {} loss: {}\".format(epoch, los))\n",
"    if epoch % 3 == 0:\n",
"        testerr(net)\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "sl91jgagnkcZ",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "NSOIOY9bDJfV",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "c_o2VSg0DKFB",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
""
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment