
@SOVIETIC-BOSS88
Last active March 16, 2020 19:53
FAST.AI JOURNEY: COURSE V3. PART 1. LESSON 5. Documenting my fast.ai journey: CODE REVIEW. PYTORCH DEEP DIVE PROJECT. TORCH.NN.MODULES.LINEAR CLASS.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# FAST AI JOURNEY: COURSE V3. PART 1. LESSON 5.\n",
"## Documenting my fast.ai journey: CODE REVIEW. PYTORCH DEEP DIVE PROJECT. TORCH.NN.MODULES.LINEAR CLASS."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this new project, we will dive deeper into the implementation of the torch.nn.modules.linear class. We will use the [Official PyTorch Documentation](https://pytorch.org/docs/stable/index.html) as a guide (more than I would like to admit) and cover some concepts we have learned during class.\n",
"\n",
"Every notebook starts with the following lines: the two autoreload magics ensure that any edits you make to libraries are reloaded here automatically, `%matplotlib inline` ensures that any charts or images are shown in this notebook, and the last line imports fastai (its star import also provides `Path`, `gzip`, `pickle`, `plt`, `nn`, `optim` and the other names used below without an explicit import)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"from fastai import *"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# First, import the libraries we need.\n",
"import math\n",
"\n",
"import torch\n",
"from torch.nn.parameter import Parameter\n",
"\n",
"# The PyTorch source (torch/nn/modules/linear.py) uses package-relative imports:\n",
"#from .. import functional as F\n",
"#from .module import Module\n",
"# Outside the package, the equivalent absolute imports are:\n",
"from torch.nn import functional as F\n",
"from torch.nn.modules.module import Module"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## MNIST SGD"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get the 'pickled' MNIST dataset from http://deeplearning.net/data/mnist/mnist.pkl.gz. We're going to treat it as a standard flat dataset with fully connected layers, rather than using a CNN."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"path = Path('data/mnist')\n",
"path.mkdir(parents=True, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[PosixPath('data/mnist/mnist.pkl.gz')]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path.ls()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"with gzip.open(path/'mnist.pkl.gz', 'rb') as f:\n",
"    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(50000, 784)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(x_train[0].reshape((28,28)), cmap=\"gray\")\n",
"x_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([50000, 784]), tensor(0), tensor(9))"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_train,y_train,x_valid,y_valid = map(torch.tensor, (x_train,y_train,x_valid,y_valid))\n",
"n,c = x_train.shape\n",
"x_train.shape, y_train.min(), y_train.max()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In lesson2-sgd we did these things ourselves:\n",
"\n",
"```python\n",
"x = torch.ones(n,2) \n",
"def mse(y_hat, y): return ((y_hat-y)**2).mean()\n",
"y_hat = x@a\n",
"```\n",
"\n",
"Now instead we'll use PyTorch's functions to do it for us, and also to handle mini-batches (which we didn't do last time, since our dataset was so small)."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"bs=64\n",
"train_ds = TensorDataset(x_train, y_train)\n",
"valid_ds = TensorDataset(x_valid, y_valid)\n",
"data = DataBunch.create(train_ds, valid_ds, bs=bs)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 784]), torch.Size([64]))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = next(iter(data.train_dl))\n",
"x.shape,y.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## PyTorch Class: torch.nn.modules.linear."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# PyTorch's original is named `linear` (torch.nn.functional.linear); we rename ours to avoid clashing with it.\n",
"def my_linear(input, weight, bias=None):\n",
"    r\"\"\"\n",
"    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.\n",
"\n",
"    Shape:\n",
"\n",
"        - Input: :math:`(N, *, in\\_features)` where `*` means any number of\n",
"          additional dimensions\n",
"        - Weight: :math:`(out\\_features, in\\_features)`\n",
"        - Bias: :math:`(out\\_features)`\n",
"        - Output: :math:`(N, *, out\\_features)`\n",
"    \"\"\"\n",
"    # PyTorch special-cases 2-D input with a bias, because the fused op is marginally faster:\n",
"    #     if input.dim() == 2 and bias is not None:\n",
"    #         return torch.addmm(bias, input, weight.t())\n",
"    # We always take the general path instead: multiply by the transposed weight, then add the bias.\n",
"    output = input.matmul(weight.t())\n",
"    if bias is not None:\n",
"        output += bias\n",
"    return output\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# PyTorch's original is named `Linear`; ours is `my_Linear`.\n",
"class my_Linear(Module):\n",
"    r\"\"\"Applies a linear transformation to the incoming data: :math:`y = xA^T + b`\n",
"\n",
"    Args:\n",
"        in_features: size of each input sample\n",
"        out_features: size of each output sample\n",
"        bias: If set to False, the layer will not learn an additive bias.\n",
"            Default: ``True``\n",
"\n",
"    Shape:\n",
"        - Input: :math:`(N, *, in\\_features)` where :math:`*` means any number of\n",
"          additional dimensions\n",
"        - Output: :math:`(N, *, out\\_features)` where all but the last dimension\n",
"          are the same shape as the input.\n",
"\n",
"    Attributes:\n",
"        weight: the learnable weights of the module of shape\n",
"            `(out_features x in_features)`\n",
"        bias: the learnable bias of the module of shape `(out_features)`\n",
"\n",
"    Examples::\n",
"\n",
"        >>> m = my_Linear(20, 30)\n",
"        >>> input = torch.randn(128, 20)\n",
"        >>> output = m(input)\n",
"        >>> print(output.size())\n",
"        torch.Size([128, 30])\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_features, out_features, bias=True):\n",
"        super(my_Linear, self).__init__()  # or simply super().__init__() in Python 3\n",
"        self.in_features = in_features\n",
"        self.out_features = out_features\n",
"        # Allocate uninitialised storage; reset_parameters() fills it in below.\n",
"        self.weight = Parameter(torch.Tensor(out_features, in_features))\n",
"        if bias:\n",
"            self.bias = Parameter(torch.Tensor(out_features))\n",
"        else:\n",
"            # Registering None keeps `self.bias` defined, so forward() can test it.\n",
"            self.register_parameter('bias', None)\n",
"        self.reset_parameters()\n",
"\n",
"    def reset_parameters(self):\n",
"        # Uniform initialisation in [-1/sqrt(in_features), 1/sqrt(in_features)]\n",
"        stdv = 1. / math.sqrt(self.weight.size(1))\n",
"        self.weight.data.uniform_(-stdv, stdv)\n",
"        if self.bias is not None:\n",
"            self.bias.data.uniform_(-stdv, stdv)\n",
"\n",
"    def forward(self, input):\n",
"        # PyTorch's Linear calls F.linear here; we call our own my_linear instead.\n",
"        return my_linear(input, self.weight, self.bias)\n",
"\n",
"    def extra_repr(self):\n",
"        # Text shown inside the parentheses when the module is printed\n",
"        return 'in_features={}, out_features={}, bias={}'.format(\n",
"            self.in_features, self.out_features, self.bias is not None\n",
"        )"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"class Mnist_Logistic(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.lin = my_Linear(784, 10, bias=True)\n",
"\n",
"    def forward(self, xb): return self.lin(xb)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"model = Mnist_Logistic().cuda()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Mnist_Logistic(\n",
" (lin): my_Linear(in_features=784, out_features=10, bias=True)\n",
")"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"my_Linear(in_features=784, out_features=10, bias=True)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.lin"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([64, 10])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model(x).shape"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[torch.Size([10, 784]), torch.Size([10])]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[p.shape for p in model.parameters()]"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"lr=2e-2"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"loss_func = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def update(x,y,lr):\n",
"    wd = 1e-5\n",
"    y_hat = model(x)\n",
"    # weight decay\n",
"    w2 = 0.\n",
"    for p in model.parameters(): w2 += (p**2).sum()\n",
"    # add to regular loss\n",
"    loss = loss_func(y_hat, y) + w2*wd\n",
"    loss.backward()\n",
"    with torch.no_grad():\n",
"        for p in model.parameters():\n",
"            p.sub_(lr * p.grad)\n",
"            p.grad.zero_()\n",
"    return loss.item()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"losses = [update(x,y,lr) for x,y in data.train_dl]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(losses);"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"class Mnist_NN(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.lin1 = nn.Linear(784, 50, bias=True)\n",
"        self.lin2 = nn.Linear(50, 10, bias=True)\n",
"\n",
"    def forward(self, xb):\n",
"        x = self.lin1(xb)\n",
"        x = F.relu(x)\n",
"        return self.lin2(x)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"model = Mnist_NN().cuda()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"losses = [update(x,y,lr) for x,y in data.train_dl]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(losses);"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"model = Mnist_NN().cuda()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
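"# Create the optimizer once, outside update(): Adam keeps running averages of the\n",
"# gradients, and re-creating it on every call would reset them at each mini-batch.\n",
"opt = optim.Adam(model.parameters(), 1e-3)\n",
"\n",
"def update(x,y,lr):\n",
"    for group in opt.param_groups: group['lr'] = lr\n",
"    y_hat = model(x)\n",
"    loss = loss_func(y_hat, y)\n",
"    loss.backward()\n",
"    opt.step()\n",
"    opt.zero_grad()\n",
"    return loss.item()"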
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"losses = [update(x,y,1e-3) for x,y in data.train_dl]"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(losses);"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Mnist_NN(), loss_func=loss_func, metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()\n",
"learn.recorder.plot()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:03 <p><table style='width:300px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>0.147097</th>\n",
" <th>0.132961</th>\n",
" <th>0.961600</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(1, 1e-2)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 864x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_lr(show_moms=True)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_losses()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## fin"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
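
A quick sanity check for the `my_linear` function in the notebook above: it computes `y = x A^T + b` as `input.matmul(weight.t())` plus the bias, so it should agree with PyTorch's own `torch.nn.functional.linear`, also for inputs with extra dimensions (the `(N, *, in_features)` case from its docstring). This is a sketch assuming a standard PyTorch install; `manual_linear` is a hypothetical stand-in that repeats the body of `my_linear` so the block runs on its own, and the shapes are made up for illustration.

```python
import torch
import torch.nn.functional as F

def manual_linear(input, weight, bias=None):
    # Hypothetical re-statement of my_linear: y = x @ W^T (+ b)
    output = input.matmul(weight.t())
    if bias is not None:
        output = output + bias  # broadcasts over every leading dimension
    return output

torch.manual_seed(0)
weight = torch.randn(10, 784)   # (out_features, in_features)
bias = torch.randn(10)          # (out_features,)

x2d = torch.randn(64, 784)      # (N, in_features)
x3d = torch.randn(8, 5, 784)    # (N, *, in_features) with one extra dimension

out2d = manual_linear(x2d, weight, bias)
out3d = manual_linear(x3d, weight, bias)

print(out2d.shape)  # torch.Size([64, 10])
print(out3d.shape)  # torch.Size([8, 5, 10])
print(torch.allclose(out2d, F.linear(x2d, weight, bias), atol=1e-4))
print(torch.allclose(out3d, F.linear(x3d, weight, bias), atol=1e-4))
```

Because `matmul` broadcasts the transposed weight over every leading dimension, no reshaping is needed for the 3-D input, which is why the output keeps the `(N, *, out_features)` shape.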
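
`my_Linear.reset_parameters` in the notebook draws the weights and the bias from a uniform distribution on `[-stdv, stdv]` with `stdv = 1/sqrt(weight.size(1))`, and `weight.size(1)` is `in_features`. For the 784-input layer used on MNIST that bound is `1/28`. A minimal sketch that checks this, assuming a standard PyTorch install (the helper `init_like_my_linear` is hypothetical; it uses `torch.no_grad()` rather than `.data` to write into the parameters, which has the same effect here):

```python
import math
import torch
from torch.nn import Parameter

def init_like_my_linear(in_features, out_features):
    # Hypothetical helper mirroring my_Linear.__init__ + reset_parameters
    weight = Parameter(torch.empty(out_features, in_features))
    bias = Parameter(torch.empty(out_features))
    stdv = 1.0 / math.sqrt(weight.size(1))  # weight.size(1) == in_features
    with torch.no_grad():
        weight.uniform_(-stdv, stdv)
        bias.uniform_(-stdv, stdv)
    return weight, bias, stdv

torch.manual_seed(0)
weight, bias, stdv = init_like_my_linear(784, 10)
print(round(stdv, 4))  # 0.0357, i.e. 1/28
print(weight.abs().max().item(), bias.abs().max().item())  # both stay within stdv
```

Scaling the bound by `1/sqrt(in_features)` keeps the pre-activations of a fresh layer at roughly the same scale however many inputs it has.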
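
The first `update` function in the notebook adds `wd * sum(p**2)` to the loss. Since the gradient of `wd * p**2` with respect to `p` is `2 * wd * p`, this is the same as adding `2 * wd * p` to each parameter's gradient before the SGD step (the `weight_decay` argument of `torch.optim.SGD` adds `weight_decay * p`, so it corresponds to `2 * wd` in this formulation). A sketch that checks the equivalence with autograd on a toy one-layer model; the tensors, sizes and seed are made up for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
wd = 1e-5
x = torch.randn(16, 784)
y = torch.randint(0, 10, (16,))
weight = torch.randn(10, 784, requires_grad=True)
bias = torch.zeros(10, requires_grad=True)
params = [weight, bias]

def data_loss():
    return F.cross_entropy(F.linear(x, weight, bias), y)

# 1) Weight decay folded into the loss, as in the notebook's update()
w2 = sum((p ** 2).sum() for p in params)
(data_loss() + w2 * wd).backward()
grads_penalised = [p.grad.clone() for p in params]

# 2) Plain loss, then add 2 * wd * p to each gradient by hand
for p in params:
    p.grad = None
data_loss().backward()
grads_manual = [p.grad + 2 * wd * p.detach() for p in params]

print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(grads_penalised, grads_manual)))
```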
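
Building `optim.Adam` inside the update function, so that a new optimizer is created on every call, resets Adam's running averages at every mini-batch. On a fresh optimizer the bias-corrected first step is `lr * g / (|g| + eps)`, which is almost `lr * sign(g)` whatever the size of the gradient. The sketch below implements Adam's published update rule for a single scalar parameter in plain Python (a from-scratch illustration, not PyTorch's code; the defaults `betas=(0.9, 0.999)` and `eps=1e-8` match those of `torch.optim.Adam`) and compares restarting the state at each step with keeping it. The gradient values are made up.

```python
import math

def adam_step(p, g, state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter p with gradient g (from-scratch sketch)."""
    state["t"] += 1
    state["m"] = b1 * state["m"] + (1 - b1) * g
    state["v"] = b2 * state["v"] + (1 - b2) * g * g
    m_hat = state["m"] / (1 - b1 ** state["t"])   # bias correction
    v_hat = state["v"] / (1 - b2 ** state["t"])
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)

def fresh_state():
    return {"t": 0, "m": 0.0, "v": 0.0}

grads = [4.0, -0.01, 0.5]  # made-up gradients for three consecutive steps

# Optimizer re-created every step: the state is reset each time
p_reset, steps_reset = 1.0, []
for g in grads:
    new_p = adam_step(p_reset, g, fresh_state())
    steps_reset.append(new_p - p_reset)
    p_reset = new_p

# Optimizer created once: the state persists between steps
p_kept, state, steps_kept = 1.0, fresh_state(), []
for g in grads:
    new_p = adam_step(p_kept, g, state)
    steps_kept.append(new_p - p_kept)
    p_kept = new_p

print([round(s, 6) for s in steps_reset])  # each step is about -lr * sign(g)
print([round(s, 6) for s in steps_kept])
```

With the state reset, the tiny second gradient still produces a full-size step in the opposite direction; with the state kept, the accumulated momentum keeps the second step moving the same way as the first.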