@amqdn
Created March 7, 2019 19:56
Implementing WideResNet from scratch using fast.ai - MNIST
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "## WideResNet - MNIST"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "https://arxiv.org/pdf/1605.07146.pdf"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "NOTE: This is a study. I have tried to understand the material to the best of my ability, but there may be mistakes. \n\nI should start by saying that the concept of \"widening\" confused me for quite some time. Even after looking through code implementations, it wasn't clear to me what \"widening\" was or how it was achieved. After reviewing the paper and several implementations multiple times, I think I'm starting to understand. What follows is a discussion of how to widen a ResNet. Feel free to skip down to the implementation. \n\nThe first thing to know is that the basic structure of the network relies on our familiar ResBlocks:"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "![](https://cdn-images-1.medium.com/max/1200/1*ByrVJspW-TefwlH7OLxNkg.png)"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "https://arxiv.org/pdf/1512.03385.pdf"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "As depicted above, each ResBlock consists of two convolutional layers (each labeled \"weight layer\" above), one feeding into the other. The output of the second layer is then added to the original input $X$ through a \"skip\" or \"identity\" connection. \n\nIf our first convolutional layer is $g_{1}(x)$ and our second is $g_{2}(x)$, we can think of a ResBlock as $g_{2}(g_{1}(X)) + X$, setting aside the intermediate ReLU activation for this discussion. A given convolutional layer $g_{n}(x)$ outputs a set of feature maps, and the number of feature maps it outputs is referred to as its number of \"channels.\" What is a feature map? It is a grid of values describing the input. A 64x64 RGB image can be thought of as having 3 feature maps (one per color channel), each of size 64x64. \n\nWhen we pass that image into a typical model, the first convolutional layer $g_{1}(x)$ will usually halve the height and width of the feature map, from 64x64 to 32x32 (we'll see later that WRNs do not do this in the first layer). In order to retain information from our input whilst shrinking the feature map, we correspondingly increase the number of channels; in other words, we typically increase the number of feature maps the model outputs whenever we make those feature maps smaller. If we start with an image of (3 channels, 64px, 64px), then our first convolution might output (8 channels, 32px, 32px). We may want more than 8 channels to work with at the start, in which case we could have the convolution output (16 channels, 32px, 32px) instead. Choosing an appropriate number of starting channels is a topic for a different discussion. \n\nThe main thesis of the WideResNet paper is that increasing both the number of convolutional layers $g(x)$ and the number of channels improves the performance of a given ResNet model, but that beyond some point adding depth yields sharply diminishing returns compared with adding width. The more layers a model has, the deeper it is; the more channels there are within each layer, the wider the model. The authors argue that increasing the number of channels per layer yields better results and is more computationally efficient than simply increasing the number of layers, and their experiments appear to support this. \n\nGoing back to our previous example, if we wanted to make our model deeper, we would increase the number $n$ of layers $g(x)$. If instead we wanted to make our model wider, we would keep the number of layers $n$ fixed and increase the number of channels from 8 to 16, or from 8 to 24. We can express how much we widen each layer with the multiplier $k$. If $k = 1$, we have not widened our model. If $k = 2$, the output of our first layer would have 16 channels instead of 8. If $k = 3$, it would have 24, and so on."
},
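{
"metadata": {},
"cell_type": "markdown",
"source": "To make the shape and width arithmetic above concrete, here is a small worked example. The 3x3 kernel, stride of 2, padding of 1, and absence of bias terms are my own assumptions for illustration; nothing above fixes them. \n\nWith kernel size $f$, stride $s$, and padding $p$, a convolution maps a side length $H_{in}$ to \n\n$$H_{out} = \\left\\lfloor \\frac{H_{in} + 2p - f}{s} \\right\\rfloor + 1 = \\left\\lfloor \\frac{64 + 2 - 3}{2} \\right\\rfloor + 1 = 32,$$ \n\nwhich is the halving from 64x64 to 32x32. Widening multiplies the base channel count by $k$: \n\n$$C_{k} = k \\cdot 8: \\quad C_{1} = 8, \\quad C_{2} = 16, \\quad C_{3} = 24.$$ \n\nA 3x3 convolution from $C_{k}$ channels to $C_{k}$ channels then has $9 \\cdot C_{k}^{2}$ weights, that is 576, 2304, and 5184 for $k = 1, 2, 3$. So, under these assumptions, the parameter count of such a layer grows with $k^{2}$, whereas stacking one more layer of the same width only adds the same amount again."
},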
{
"metadata": {},
"cell_type": "markdown",
"source": "Now that we understand how to \"widen\" our model, we can proceed to build a WideResNet from scratch. Let's first prepare our data. "
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%reload_ext autoreload\n%autoreload 2\n%matplotlib inline",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from fastai.vision import *",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "path = untar_data(URLs.MNIST)",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "path.ls()",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "[PosixPath('/home/jupyter/.fastai/data/mnist_png/testing'),\n PosixPath('/home/jupyter/.fastai/data/mnist_png/training'),\n PosixPath('/home/jupyter/.fastai/data/mnist_png/models')]"
},
"output_type": "execute_result",
"metadata": {},
"execution_count": 4
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "il = ImageList.from_folder(path, convert_mode='L')",
"execution_count": 5,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "defaults.cmap='binary'",
"execution_count": 6,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "il[0].show()",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "<Figure size 216x216 with 1 Axes>"
},
"output_type": "display_data",
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# split the data into training and validation sets\nsd = il.split_by_folder(train='training', valid='testing')",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "sd",
"execution_count": 9,
"outputs": [
{
"data": {
"text/plain": "ItemLists;\n\nTrain: ImageList (60000 items)\nImage (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28)\nPath: /home/jupyter/.fastai/data/mnist_png;\n\nValid: ImageList (10000 items)\nImage (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28)\nPath: /home/jupyter/.fastai/data/mnist_png;\n\nTest: None"
},
"output_type": "execute_result",
"metadata": {},
"execution_count": 9
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "(path/'training').ls()",
"execution_count": 10,
"outputs": [
{
"data": {
"text/plain": "[PosixPath('/home/jupyter/.fastai/data/mnist_png/training/1'),\n PosixPath('/home/jupyter/.fastai/data/mnist_png/training/6'),\n PosixPath('/home/jupyter/.fastai/data/mnist_png/training/8'),\n PosixPath('/home/jupyter/.fastai/data/mnist_png/training/2'),\n PosixPath('/home/jupyter/.fastai/data/mnist_png/training/5'),\n PosixPath('/home/jupyter/.fastai/data/mnist_png/training/9'),\n PosixPath('/home/jupyter/.fastai/data/mnist_png/training/0'),\n PosixPath('/home/jupyter/.fastai/data/mnist_png/training/3'),\n PosixPath('/home/jupyter/.fastai/data/mnist_png/training/7'),\n PosixPath('/home/jupyter/.fastai/data/mnist_png/training/4')]"
},
"output_type": "execute_result",
"metadata": {},
"execution_count": 10
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# label the data according to folders, as above\nll = sd.label_from_folder()",
"execution_count": 11,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "ll",
"execution_count": 12,
"outputs": [
{
"data": {
"text/plain": "LabelLists;\n\nTrain: LabelList (60000 items)\nx: ImageList\nImage (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28)\ny: CategoryList\n1,1,1,1,1\nPath: /home/jupyter/.fastai/data/mnist_png;\n\nValid: LabelList (10000 items)\nx: ImageList\nImage (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28)\ny: CategoryList\n1,1,1,1,1\nPath: /home/jupyter/.fastai/data/mnist_png;\n\nTest: None"
},
"output_type": "execute_result",
"metadata": {},
"execution_count": 12
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "x,y = ll.train[0] # input and output (label)",
"execution_count": 13,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "x.show()\nprint(y,x.shape)",
"execution_count": 14,
"outputs": [
{
"name": "stdout",
"text": "1 torch.Size([1, 28, 28])\n",
"output_type": "stream"
},
{
"data": {
"text/plain": "<Figure size 216x216 with 1 Axes>"
},
"output_type": "display_data",
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "tfms = ([*rand_pad(padding=3, size=28, mode='zeros')], []) # only minimal tfms",
"execution_count": 15,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "ll = ll.transform(tfms) # append the tfms",
"execution_count": 16,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "bs = 128",
"execution_count": 17,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "data = ll.databunch(bs=bs).normalize() # normalize accomplishes the mean/std normalization mentioned in the paper",
"execution_count": 18,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Data augmentation is applied on the fly each time an item is accessed; thus, every time `data.train_ds[0][0]` is evaluated below, it produces a new version of the image with the tfms applied. `plot_multi` evaluates it once per subplot, showing the effect of our tfms. "
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def _plot(i,j,ax): data.train_ds[0][0].show(ax, cmap='gray') # i is row, j is col, ax is the matplotlib Axes to draw on\nplot_multi(_plot, 3, 3, figsize=(8,8))",
"execution_count": 19,
"outputs": [
{
"data": {
"text/plain": "<Figure size 576x576 with 9 Axes>"
},
"output_type": "display_data",
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "data.show_batch(rows=3, figsize=(5,5))",
"execution_count": 24,
"outputs": [
{
"data": {
"text/plain": "<Figure size 360x360 with 9 Axes>"
},
"output_type": "display_data",
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "---"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Implementing WideResNet"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "https://github.com/szagoruyko/wide-residual-networks/blob/master/models/wide-resnet.lua\n\nTwo WRNs from the paper seem to offer the best combination of performance and speed: WRN-28-10 and WRN-50-2-bottleneck, the former on CIFAR and the latter on ImageNet. For our purposes, let's start with WRN-28-10. \n\nThe authors point out that using 3x3 kernels throughout the network, with only two convolutional layers in each ResBlock, results in the best performance. Further, they indicate that placing a dropout layer between the two convolutional layers of a ResBlock results in better performance. We will build our network using this information. \n\nAbout the WRN-$n$-$k$ notation: $n$ is the total number of convolutions in our network, and $k$ is the widening factor discussed above. If we have a WRN-28-10, what exactly does $n = 28$ total layers mean? Let's count our convolutions: a ResBlock has 2 convolutions inside it. Remember from the paper that three \"conv groups\" follow the initial convolution, each containing $N$ ResBlocks. If $N = 4$, each conv group has 4 ResBlocks, or 4 x 2 = 8 convolutions. Across the three groups, that gives us 24 convolutions; where are the other 4? At this point, I can only surmise from the authors' code that they are counting the \"identity\" convolutions as 1 convolution per group (even though, technically, each ResBlock here has its own identity convolution...), so 24 + 3 gives us 27 convolutions. Finally, we add our initial convolution (\"conv1\") for a total depth of $n = 28$. That's the most sense I can make of it. "
},
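{
"metadata": {},
"cell_type": "markdown",
"source": "Written as a formula, and following my counting above (which is my reading of the authors' code, not something the paper states in these words), the depth is \n\n$$n = \\underbrace{3 \\cdot 2N}_{\\text{ResBlock convolutions}} + \\underbrace{3}_{\\text{one shortcut convolution per group}} + \\underbrace{1}_{\\text{conv1}} = 6N + 4,$$ \n\nso \n\n$$N = \\frac{n - 4}{6}, \\qquad n = 28 \\Rightarrow N = \\frac{28 - 4}{6} = 4.$$ \n\nBy the same formula, $n = 16$ gives $N = 2$ and $n = 40$ gives $N = 6$."
},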
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# WRN-28-10, 28 layers with widening factor 10\n# N is (depth - 4) / 6, so (28 - 4) / 6 = 4\n\ndef wresgroup(ni, nf, k=10, N=4, stride=1, dropout=0.):\n    # one conv group: N ResBlocks, each nf * k channels wide\n    layers = []\n    nfk = nf * k\n    for i in range(N):\n        if i == 0:\n            # only the first block changes the channel count and applies the stride\n            layers.append(WResBlock(ni, nfk, stride=stride, dropout=dropout))\n        else:\n            layers.append(WResBlock(nfk, nfk, dropout=dropout))\n    return layers\n\nclass WResBlock(nn.Module):\n    def __init__(self, ni, nfk, stride=1, dropout=0.):\n        super().__init__()\n        # BN -> ReLU -> conv, twice, with dropout between the two convs\n        layers = [\n            nn.BatchNorm2d(ni),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(ni, nfk, 3, stride=stride, padding=1, bias=False),\n            nn.BatchNorm2d(nfk),\n            nn.ReLU(inplace=True),\n            nn.Dropout(dropout),\n            nn.Conv2d(nfk, nfk, 3, stride=1, padding=1, bias=False)\n        ]\n        self.conv_block = nn.Sequential(*layers)\n        # 1x1 conv on the skip path so its output shape always matches conv_block\n        self.identity = nn.Sequential(nn.Conv2d(ni, nfk, 1, stride=stride, padding=0, bias=False))\n\n    def forward(self, x): return self.identity(x) + self.conv_block(x)",
"execution_count": 20,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "model = nn.Sequential(\n    nn.Conv2d(1, 8, 3, padding=1, bias=False), # conv1 -- we'll start with 8 features because the images are so small\n    *wresgroup(8, 8, dropout=0.3), # conv2\n    *wresgroup(80, 16, stride=2, dropout=0.3), # conv3 -- more features, half the grid size\n    *wresgroup(160, 32, stride=2, dropout=0.3), # conv4\n    nn.BatchNorm2d(320),\n    nn.ReLU(inplace=True),\n    nn.AvgPool2d(7), # the output of the last wresgroup is 7x7\n    Flatten(),\n    nn.Linear(320, 10)\n)",
"execution_count": 21,
"outputs": []
},
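{
"metadata": {},
"cell_type": "markdown",
"source": "As a sanity check on the definition above, the weight counts can be worked out by hand. A bias-free convolution with $C_{in}$ input channels, $C_{out}$ output channels, and an $f \\times f$ kernel has $C_{in} \\cdot C_{out} \\cdot f^{2}$ weights: \n\n$$\\text{conv1: } 1 \\cdot 8 \\cdot 9 = 72, \\qquad \\text{first 3x3 in conv2: } 8 \\cdot 80 \\cdot 9 = 5760, \\qquad \\text{later 3x3 in conv2: } 80 \\cdot 80 \\cdot 9 = 57600,$$ \n\n$$\\text{first 1x1 shortcut: } 8 \\cdot 80 \\cdot 1 = 640, \\qquad \\text{later 1x1 shortcuts in conv2: } 80 \\cdot 80 \\cdot 1 = 6400, \\qquad \\text{linear: } 320 \\cdot 10 + 10 = 3210.$$ \n\nThe two stride-2 groups take the grid from 28x28 to 14x14 to 7x7, which is why `nn.AvgPool2d(7)` reduces each of the 320 feature maps to a single value."
},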
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(data, model, loss_func = nn.CrossEntropyLoss(), metrics=accuracy)",
"execution_count": 22,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "print(learn.summary())",
"execution_count": 23,
"outputs": [
{
"name": "stdout",
"text": "======================================================================\nLayer (type) Output Shape Param # Trainable \n======================================================================\nConv2d [1, 8, 28, 28] 72 True \n______________________________________________________________________\nBatchNorm2d [1, 8, 28, 28] 16 True \n______________________________________________________________________\nReLU [1, 8, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 5,760 True \n______________________________________________________________________\nBatchNorm2d [1, 80, 28, 28] 160 True \n______________________________________________________________________\nReLU [1, 80, 28, 28] 0 False \n______________________________________________________________________\nDropout [1, 80, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 57,600 True \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 640 True \n______________________________________________________________________\nBatchNorm2d [1, 80, 28, 28] 160 True \n______________________________________________________________________\nReLU [1, 80, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 57,600 True \n______________________________________________________________________\nBatchNorm2d [1, 80, 28, 28] 160 True \n______________________________________________________________________\nReLU [1, 80, 28, 28] 0 False \n______________________________________________________________________\nDropout [1, 80, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 57,600 True \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 6,400 True 
\n______________________________________________________________________\nBatchNorm2d [1, 80, 28, 28] 160 True \n______________________________________________________________________\nReLU [1, 80, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 57,600 True \n______________________________________________________________________\nBatchNorm2d [1, 80, 28, 28] 160 True \n______________________________________________________________________\nReLU [1, 80, 28, 28] 0 False \n______________________________________________________________________\nDropout [1, 80, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 57,600 True \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 6,400 True \n______________________________________________________________________\nBatchNorm2d [1, 80, 28, 28] 160 True \n______________________________________________________________________\nReLU [1, 80, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 57,600 True \n______________________________________________________________________\nBatchNorm2d [1, 80, 28, 28] 160 True \n______________________________________________________________________\nReLU [1, 80, 28, 28] 0 False \n______________________________________________________________________\nDropout [1, 80, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 57,600 True \n______________________________________________________________________\nConv2d [1, 80, 28, 28] 6,400 True \n______________________________________________________________________\nBatchNorm2d [1, 80, 28, 28] 160 True \n______________________________________________________________________\nReLU [1, 80, 28, 28] 0 False 
\n______________________________________________________________________\nConv2d [1, 160, 14, 14] 115,200 True \n______________________________________________________________________\nBatchNorm2d [1, 160, 14, 14] 320 True \n______________________________________________________________________\nReLU [1, 160, 14, 14] 0 False \n______________________________________________________________________\nDropout [1, 160, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 160, 14, 14] 230,400 True \n______________________________________________________________________\nConv2d [1, 160, 14, 14] 12,800 True \n______________________________________________________________________\nBatchNorm2d [1, 160, 14, 14] 320 True \n______________________________________________________________________\nReLU [1, 160, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 160, 14, 14] 230,400 True \n______________________________________________________________________\nBatchNorm2d [1, 160, 14, 14] 320 True \n______________________________________________________________________\nReLU [1, 160, 14, 14] 0 False \n______________________________________________________________________\nDropout [1, 160, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 160, 14, 14] 230,400 True \n______________________________________________________________________\nConv2d [1, 160, 14, 14] 25,600 True \n______________________________________________________________________\nBatchNorm2d [1, 160, 14, 14] 320 True \n______________________________________________________________________\nReLU [1, 160, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 160, 14, 14] 230,400 True \n______________________________________________________________________\nBatchNorm2d [1, 160, 14, 14] 320 True 
\n______________________________________________________________________\nReLU [1, 160, 14, 14] 0 False \n______________________________________________________________________\nDropout [1, 160, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 160, 14, 14] 230,400 True \n______________________________________________________________________\nConv2d [1, 160, 14, 14] 25,600 True \n______________________________________________________________________\nBatchNorm2d [1, 160, 14, 14] 320 True \n______________________________________________________________________\nReLU [1, 160, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 160, 14, 14] 230,400 True \n______________________________________________________________________\nBatchNorm2d [1, 160, 14, 14] 320 True \n______________________________________________________________________\nReLU [1, 160, 14, 14] 0 False \n______________________________________________________________________\nDropout [1, 160, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 160, 14, 14] 230,400 True \n______________________________________________________________________\nConv2d [1, 160, 14, 14] 25,600 True \n______________________________________________________________________\nBatchNorm2d [1, 160, 14, 14] 320 True \n______________________________________________________________________\nReLU [1, 160, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 320, 7, 7] 460,800 True \n______________________________________________________________________\nBatchNorm2d [1, 320, 7, 7] 640 True \n______________________________________________________________________\nReLU [1, 320, 7, 7] 0 False \n______________________________________________________________________\nDropout [1, 320, 7, 7] 0 False 
\n______________________________________________________________________\nConv2d [1, 320, 7, 7] 921,600 True \n______________________________________________________________________\nConv2d [1, 320, 7, 7] 51,200 True \n______________________________________________________________________\nBatchNorm2d [1, 320, 7, 7] 640 True \n______________________________________________________________________\nReLU [1, 320, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 320, 7, 7] 921,600 True \n______________________________________________________________________\nBatchNorm2d [1, 320, 7, 7] 640 True \n______________________________________________________________________\nReLU [1, 320, 7, 7] 0 False \n______________________________________________________________________\nDropout [1, 320, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 320, 7, 7] 921,600 True \n______________________________________________________________________\nConv2d [1, 320, 7, 7] 102,400 True \n______________________________________________________________________\nBatchNorm2d [1, 320, 7, 7] 640 True \n______________________________________________________________________\nReLU [1, 320, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 320, 7, 7] 921,600 True \n______________________________________________________________________\nBatchNorm2d [1, 320, 7, 7] 640 True \n______________________________________________________________________\nReLU [1, 320, 7, 7] 0 False \n______________________________________________________________________\nDropout [1, 320, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 320, 7, 7] 921,600 True \n______________________________________________________________________\nConv2d [1, 320, 7, 7] 102,400 True 
\n______________________________________________________________________\nBatchNorm2d [1, 320, 7, 7] 640 True \n______________________________________________________________________\nReLU [1, 320, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 320, 7, 7] 921,600 True \n______________________________________________________________________\nBatchNorm2d [1, 320, 7, 7] 640 True \n______________________________________________________________________\nReLU [1, 320, 7, 7] 0 False \n______________________________________________________________________\nDropout [1, 320, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 320, 7, 7] 921,600 True \n______________________________________________________________________\nConv2d [1, 320, 7, 7] 102,400 True \n______________________________________________________________________\nBatchNorm2d [1, 320, 7, 7] 640 True \n______________________________________________________________________\nReLU [1, 320, 7, 7] 0 False \n______________________________________________________________________\nAvgPool2d [1, 320, 1, 1] 0 False \n______________________________________________________________________\nFlatten [1, 320] 0 False \n______________________________________________________________________\nLinear [1, 10] 3,210 True \n______________________________________________________________________\n\nTotal params: 9,529,058\nTotal trainable params: 9,529,058\nTotal non-trainable params: 0\n\n",
"output_type": "stream"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.lr_find()\nlearn.recorder.plot()",
"execution_count": 24,
"outputs": [
{
"data": {
"text/html": "",
"text/plain": "<IPython.core.display.HTML object>"
},
"output_type": "display_data",
"metadata": {}
},
{
"name": "stdout",
"text": "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n",
"output_type": "stream"
},
{
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"output_type": "display_data",
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.fit_one_cycle(5, 5e-3)",
"execution_count": 25,
"outputs": [
{
"data": {
"text/html": "Total time: 05:30 <p><table style='width:375px; margin-bottom:10px'>\n <tr>\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>accuracy</th>\n <th>time</th>\n </tr>\n <tr>\n <th>0</th>\n <th>0.145045</th>\n <th>1.070395</th>\n <th>0.736500</th>\n <th>01:06</th>\n </tr>\n <tr>\n <th>1</th>\n <th>0.100319</th>\n <th>0.155969</th>\n <th>0.943800</th>\n <th>01:06</th>\n </tr>\n <tr>\n <th>2</th>\n <th>0.061960</th>\n <th>0.049387</th>\n <th>0.982200</th>\n <th>01:06</th>\n </tr>\n <tr>\n <th>3</th>\n <th>0.032360</th>\n <th>0.024816</th>\n <th>0.992700</th>\n <th>01:06</th>\n </tr>\n <tr>\n <th>4</th>\n <th>0.016214</th>\n <th>0.010952</th>\n <th>0.995900</th>\n <th>01:06</th>\n </tr>\n</table>\n",
"text/plain": "<IPython.core.display.HTML object>"
},
"output_type": "display_data",
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "terp = learn.interpret()\nterp.plot_top_losses(9, figsize=(7,7))",
"execution_count": 26,
"outputs": [
{
"data": {
"text/plain": "<Figure size 504x504 with 9 Axes>",
"image/png": "\n"
},
"output_type": "display_data",
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Trying WRN-40-2"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "At roughly 9.5M parameters, the network above is probably overkill for MNIST. Let's build a much smaller network instead, WRN-40-2 (depth 40, widening factor $k=2$):"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# WRN-40-2: depth 40 = 6*N + 4 with N=6 blocks per group, widening factor k=2\nmodel = nn.Sequential(\n nn.Conv2d(1, 8, 3, padding=1, bias=False), # conv1: 1 input channel (grayscale MNIST) -> 8\n *wresgroup(8, 8, k=2, N=6, dropout=0.3), # conv2: 8 -> 16 channels, 28x28\n *wresgroup(16, 16, k=2, N=6, stride=2, dropout=0.3), # conv3: 16 -> 32 channels, 14x14\n *wresgroup(32, 32, k=2, N=6, stride=2, dropout=0.3), # conv4: 32 -> 64 channels, 7x7\n nn.BatchNorm2d(64),\n nn.ReLU(inplace=True),\n nn.AvgPool2d(7), # the output of the last wresgroup is 7x7\n Flatten(),\n nn.Linear(64, 10)\n)",
"execution_count": 27,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(data, model, loss_func = nn.CrossEntropyLoss(), metrics=accuracy)",
"execution_count": 28,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "print(learn.summary())",
"execution_count": 29,
"outputs": [
{
"name": "stdout",
"text": "======================================================================\nLayer (type) Output Shape Param # Trainable \n======================================================================\nConv2d [1, 8, 28, 28] 72 True \n______________________________________________________________________\nBatchNorm2d [1, 8, 28, 28] 16 True \n______________________________________________________________________\nReLU [1, 8, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 1,152 True \n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False \n______________________________________________________________________\nDropout [1, 16, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 2,304 True \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 128 True \n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 2,304 True \n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False \n______________________________________________________________________\nDropout [1, 16, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 2,304 True \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 256 True 
\n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 2,304 True \n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False \n______________________________________________________________________\nDropout [1, 16, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 2,304 True \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 256 True \n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 2,304 True \n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False \n______________________________________________________________________\nDropout [1, 16, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 2,304 True \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 256 True \n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False 
\n______________________________________________________________________\nConv2d [1, 16, 28, 28] 2,304 True \n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False \n______________________________________________________________________\nDropout [1, 16, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 2,304 True \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 256 True \n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 2,304 True \n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False \n______________________________________________________________________\nDropout [1, 16, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 2,304 True \n______________________________________________________________________\nConv2d [1, 16, 28, 28] 256 True \n______________________________________________________________________\nBatchNorm2d [1, 16, 28, 28] 32 True \n______________________________________________________________________\nReLU [1, 16, 28, 28] 0 False \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 4,608 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True 
\n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nDropout [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 9,216 True \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 512 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True \n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 9,216 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True \n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nDropout [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 9,216 True \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 1,024 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True \n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 9,216 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True \n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nDropout [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d 
[1, 32, 14, 14] 9,216 True \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 1,024 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True \n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 9,216 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True \n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nDropout [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 9,216 True \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 1,024 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True \n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 9,216 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True \n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nDropout [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 9,216 True \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 1,024 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True 
\n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 9,216 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True \n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nDropout [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 9,216 True \n______________________________________________________________________\nConv2d [1, 32, 14, 14] 1,024 True \n______________________________________________________________________\nBatchNorm2d [1, 32, 14, 14] 64 True \n______________________________________________________________________\nReLU [1, 32, 14, 14] 0 False \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 18,432 True \n______________________________________________________________________\nBatchNorm2d [1, 64, 7, 7] 128 True \n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nDropout [1, 64, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 36,864 True \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 2,048 True \n______________________________________________________________________\nBatchNorm2d [1, 64, 7, 7] 128 True \n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 36,864 True \n______________________________________________________________________\nBatchNorm2d [1, 
64, 7, 7] 128 True \n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nDropout [1, 64, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 36,864 True \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 4,096 True \n______________________________________________________________________\nBatchNorm2d [1, 64, 7, 7] 128 True \n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 36,864 True \n______________________________________________________________________\nBatchNorm2d [1, 64, 7, 7] 128 True \n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nDropout [1, 64, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 36,864 True \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 4,096 True \n______________________________________________________________________\nBatchNorm2d [1, 64, 7, 7] 128 True \n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 36,864 True \n______________________________________________________________________\nBatchNorm2d [1, 64, 7, 7] 128 True \n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nDropout [1, 64, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 
64, 7, 7] 36,864 True \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 4,096 True \n______________________________________________________________________\nBatchNorm2d [1, 64, 7, 7] 128 True \n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 36,864 True \n______________________________________________________________________\nBatchNorm2d [1, 64, 7, 7] 128 True \n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nDropout [1, 64, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 36,864 True \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 4,096 True \n______________________________________________________________________\nBatchNorm2d [1, 64, 7, 7] 128 True \n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 36,864 True \n______________________________________________________________________\nBatchNorm2d [1, 64, 7, 7] 128 True \n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nDropout [1, 64, 7, 7] 0 False \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 36,864 True \n______________________________________________________________________\nConv2d [1, 64, 7, 7] 4,096 True \n______________________________________________________________________\nBatchNorm2d [1, 64, 7, 7] 128 True 
\n______________________________________________________________________\nReLU [1, 64, 7, 7] 0 False \n______________________________________________________________________\nAvgPool2d [1, 64, 1, 1] 0 False \n______________________________________________________________________\nFlatten [1, 64] 0 False \n______________________________________________________________________\nLinear [1, 10] 650 True \n______________________________________________________________________\n\nTotal params: 589,410\nTotal trainable params: 589,410\nTotal non-trainable params: 0\n\n",
"output_type": "stream"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.lr_find()\nlearn.recorder.plot()",
"execution_count": 30,
"outputs": [
{
"data": {
"text/html": "",
"text/plain": "<IPython.core.display.HTML object>"
},
"output_type": "display_data",
"metadata": {}
},
{
"name": "stdout",
"text": "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n",
"output_type": "stream"
},
{
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"output_type": "display_data",
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.fit_one_cycle(12, 1e-2)",
"execution_count": 31,
"outputs": [
{
"data": {
"text/html": "Total time: 04:48 <p><table style='width:375px; margin-bottom:10px'>\n <tr>\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>accuracy</th>\n <th>time</th>\n </tr>\n <tr>\n <th>0</th>\n <th>0.233642</th>\n <th>0.573115</th>\n <th>0.823400</th>\n <th>00:23</th>\n </tr>\n <tr>\n <th>1</th>\n <th>0.135311</th>\n <th>0.286319</th>\n <th>0.918000</th>\n <th>00:24</th>\n </tr>\n <tr>\n <th>2</th>\n <th>0.115002</th>\n <th>0.124288</th>\n <th>0.963800</th>\n <th>00:24</th>\n </tr>\n <tr>\n <th>3</th>\n <th>0.096996</th>\n <th>0.064454</th>\n <th>0.980400</th>\n <th>00:23</th>\n </tr>\n <tr>\n <th>4</th>\n <th>0.078536</th>\n <th>0.173576</th>\n <th>0.953800</th>\n <th>00:23</th>\n </tr>\n <tr>\n <th>5</th>\n <th>0.064685</th>\n <th>0.165209</th>\n <th>0.950100</th>\n <th>00:24</th>\n </tr>\n <tr>\n <th>6</th>\n <th>0.050055</th>\n <th>0.037562</th>\n <th>0.987500</th>\n <th>00:23</th>\n </tr>\n <tr>\n <th>7</th>\n <th>0.041066</th>\n <th>0.027783</th>\n <th>0.991500</th>\n <th>00:23</th>\n </tr>\n <tr>\n <th>8</th>\n <th>0.029880</th>\n <th>0.017404</th>\n <th>0.994600</th>\n <th>00:23</th>\n </tr>\n <tr>\n <th>9</th>\n <th>0.020542</th>\n <th>0.014314</th>\n <th>0.994400</th>\n <th>00:24</th>\n </tr>\n <tr>\n <th>10</th>\n <th>0.015938</th>\n <th>0.012311</th>\n <th>0.996000</th>\n <th>00:24</th>\n </tr>\n <tr>\n <th>11</th>\n <th>0.016676</th>\n <th>0.011243</th>\n <th>0.996300</th>\n <th>00:23</th>\n </tr>\n</table>\n",
"text/plain": "<IPython.core.display.HTML object>"
},
"output_type": "display_data",
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "terp = learn.interpret()\nterp.plot_top_losses(9, figsize=(7,7))",
"execution_count": 32,
"outputs": [
{
"data": {
"text/plain": "<Figure size 504x504 with 9 Axes>",
"image/png": "\n"
},
"output_type": "display_data",
"metadata": {
"needs_background": "light"
}
}
]
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"pygments_lexer": "ipython3",
"version": "3.7.1",
"nbconvert_exporter": "python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"name": "python",
"mimetype": "text/x-python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "Implementing WideResNet from scratch using fast.ai - MNIST",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}