GitHub gist PiotrCzapla/00c82fb193c9ebc20702ea22de2cb737 (last active November 25, 2022 11:20)
fastai 04_minibatch_training_withoutbody.ipynb
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 121, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt\n", | |
"from pathlib import Path\n", | |
"from torch import tensor,nn\n", | |
"import torch.nn.functional as F\n", | |
"from fastcore.test import test_close\n", | |
"\n", | |
"torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)\n", | |
"torch.manual_seed(1)\n", | |
"mpl.rcParams['image.cmap'] = 'gray'\n", | |
"\n", | |
"path_data = Path('data')\n", | |
"path_gz = path_data/'mnist.pkl.gz'\n", | |
"with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')\n", | |
"x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Initial setup" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"heading_collapsed": true | |
}, | |
"source": [ | |
"### Data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 122, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"n,m = x_train.shape\n", | |
"c = y_train.max()+1\n", | |
"nh = 50" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 123, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class Model(nn.Module):\n", | |
" def __init__(self, n_in, nh, n_out):\n", | |
" super().__init__()\n", | |
" self.layers = [nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out)]\n", | |
" \n", | |
" def __call__(self, x):\n", | |
" for l in self.layers: x = l(x)\n", | |
" return x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 124, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([50000, 10])" | |
] | |
}, | |
"execution_count": 124, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model = Model(m, nh, 10)\n", | |
"pred = model(x_train)\n", | |
"pred.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"heading_collapsed": true | |
}, | |
"source": [ | |
"### Cross entropy loss" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"First, we will need to compute the softmax of our activations. This is defined by:\n", | |
"\n", | |
"$$\\hbox{softmax(x)}_{i} = \\frac{e^{x_{i}}}{e^{x_{0}} + e^{x_{1}} + \\cdots + e^{x_{n-1}}}$$\n", | |
"\n", | |
"or more concisely:\n", | |
"\n", | |
"$$\\hbox{softmax(x)}_{i} = \\frac{e^{x_{i}}}{\\sum_{0 \\leq j \\leq n-1} e^{x_{j}}}$$ \n", | |
"\n", | |
"In practice, we will need the log of the softmax when we calculate the loss." | |
] | |
}, | |
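The softmax formula above can be sketched on a plain Python list of floats. This is only an illustration: the notebook works on PyTorch tensors, and `softmax_list` is a hypothetical helper name.

```python
import math

def softmax_list(xs):
    """Softmax of a list of floats: exp(x_i) divided by the sum of all exp(x_j)."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax_list([1.0, 2.0, 3.0])
print([round(p, 3) for p in probs])  # [0.09, 0.245, 0.665]
```

The outputs are positive and sum to 1, and larger activations get larger probabilities.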
{ | |
"cell_type": "code", | |
"execution_count": 125, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def log_softmax(x): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 126, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-2.37, -2.49, -2.36, ..., -2.31, -2.28, -2.22],\n", | |
" [-2.37, -2.44, -2.44, ..., -2.27, -2.26, -2.16],\n", | |
" [-2.48, -2.33, -2.28, ..., -2.30, -2.30, -2.27],\n", | |
" ...,\n", | |
" [-2.33, -2.52, -2.34, ..., -2.31, -2.21, -2.16],\n", | |
" [-2.38, -2.38, -2.33, ..., -2.29, -2.26, -2.17],\n", | |
" [-2.33, -2.55, -2.36, ..., -2.29, -2.27, -2.16]], grad_fn=<LogBackward0>)" | |
] | |
}, | |
"execution_count": 126, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"log_softmax(pred)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"Note that the formula \n", | |
"\n", | |
"$$\\log \\left ( \\frac{a}{b} \\right ) = \\log(a) - \\log(b)$$ \n", | |
"\n", | |
"gives a simplification when we compute the log softmax, which was previously defined as `(x.exp()/(x.exp().sum(-1,keepdim=True))).log()`" | |
] | |
}, | |
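With $a = e^{x_{i}}$ and $b = \sum_{j} e^{x_{j}}$, the identity gives $\log \hbox{softmax}(x)_{i} = x_{i} - \log \sum_{j} e^{x_{j}}$, so the division disappears. Here is a minimal stdlib-only sketch that checks both forms agree. It assumes plain Python lists rather than tensors, and the helper names are hypothetical.

```python
import math

def log_softmax_naive(xs):
    # log(exp(x_i) / sum_j exp(x_j)), computed literally
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [math.log(e / total) for e in exps]

def log_softmax_list(xs):
    # x_i - log(sum_j exp(x_j)), using log(a/b) = log(a) - log(b)
    log_total = math.log(sum(math.exp(x) for x in xs))
    return [x - log_total for x in xs]

xs = [0.5, -1.2, 2.0, 0.0]
for naive, simple in zip(log_softmax_naive(xs), log_softmax_list(xs)):
    assert math.isclose(naive, simple, rel_tol=1e-12, abs_tol=1e-12)
```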
{ | |
"cell_type": "code", | |
"execution_count": 127, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def log_softmax(x): ..." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"Next, we can compute the log of the sum of exponentials in a numerically more stable way using the [LogSumExp trick](https://en.wikipedia.org/wiki/LogSumExp). The idea is to use the following identity:\n",
"\n", | |
"$$\\log \\left ( \\sum_{j=0}^{n-1} e^{x_{j}} \\right ) = \\log \\left ( e^{a} \\sum_{j=0}^{n-1} e^{x_{j}-a} \\right ) = a + \\log \\left ( \\sum_{j=0}^{n-1} e^{x_{j}-a} \\right )$$\n",
"\n", | |
"where $a$ is the maximum of the $x_{j}$, so every exponent $x_{j}-a$ is at most 0 and each term $e^{x_{j}-a}$ is at most 1."
] | |
}, | |
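A minimal stdlib-only sketch of the trick, again on plain Python lists rather than tensors (`logsumexp_list` is a hypothetical name). With inputs around 1000, the naive formula overflows while the shifted one does not.

```python
import math

def logsumexp_list(xs):
    """log(sum_j exp(x_j)) computed with the max-shift (LogSumExp) trick."""
    a = max(xs)  # shift by the largest value
    # every x_j - a <= 0, so exp(x_j - a) <= 1 and cannot overflow
    return a + math.log(sum(math.exp(x - a) for x in xs))

big = [1000.0, 1000.0]
print(round(logsumexp_list(big), 4))  # 1000.6931, i.e. 1000 + log(2)

try:
    math.log(sum(math.exp(x) for x in big))  # naive version
except OverflowError:
    print("naive version overflows")
```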
{ | |
"cell_type": "code", | |
"execution_count": 128, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def logsumexp(x):\n", | |
" ... " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"This way, we avoid overflow when exponentiating a large activation. PyTorch already implements this as `logsumexp`, which we use below to test our version."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 129, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def log_softmax(x): ... # numerical stability achieved using logsumexp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 130, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-2.37, -2.49, -2.36, ..., -2.31, -2.28, -2.22],\n", | |
" [-2.37, -2.44, -2.44, ..., -2.27, -2.26, -2.16],\n", | |
" [-2.48, -2.33, -2.28, ..., -2.30, -2.30, -2.27],\n", | |
" ...,\n", | |
" [-2.33, -2.52, -2.34, ..., -2.31, -2.21, -2.16],\n", | |
" [-2.38, -2.38, -2.33, ..., -2.29, -2.26, -2.17],\n", | |
" [-2.33, -2.55, -2.36, ..., -2.29, -2.27, -2.16]], grad_fn=<SubBackward0>)" | |
] | |
}, | |
"execution_count": 130, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_close(logsumexp(pred), pred.logsumexp(-1))\n", | |
"sm_pred = log_softmax(pred)\n", | |
"sm_pred" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"The cross entropy loss for some target $x$ and some prediction $p(x)$ is given by:\n", | |
"\n", | |
"$$ -\\sum x\\, \\log p(x) $$\n", | |
"\n", | |
"But since our $x$s are one-hot encoded, this reduces to $-\\log(p_{i})$, where $i$ is the index of the desired target; the loss for a batch is the mean of these values.\n",
"\n",
"Picking out $\\log(p_{i})$ for every row at once can be done using NumPy-style [integer array indexing](https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#integer-array-indexing). Note that PyTorch supports the advanced indexing methods discussed in that link."
] | |
}, | |
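Since each target is a single class index, the loss picks one log-probability per row and averages the negatives. Below is a plain-Python sketch of that selection. `nll_list` is a hypothetical name; in the notebook the same selection is done with tensor integer-array indexing such as `sm_pred[[0,1,2], y_train[:3]]`.

```python
import math

def nll_list(log_probs, targets):
    """Mean negative log-likelihood: -mean of log_probs[i][targets[i]]."""
    picked = [row[t] for row, t in zip(log_probs, targets)]  # one entry per row
    return -sum(picked) / len(picked)

# Two samples, three classes; each row holds log-probabilities (log-softmax outputs)
log_probs = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],
    [math.log(0.1), math.log(0.8), math.log(0.1)],
]
targets = [0, 1]
print(round(nll_list(log_probs, targets), 4))  # 0.2899
```

With roughly uniform predictions over 10 classes, each picked log-probability is about $\log(1/10)$, so the loss is about $\log 10 \approx 2.30$. That is why the untrained model's loss in this notebook starts near 2.30.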
{ | |
"cell_type": "code", | |
"execution_count": 131, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([5, 0, 4])" | |
] | |
}, | |
"execution_count": 131, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_train[:3]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 132, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(tensor(-2.20, grad_fn=<SelectBackward0>),\n", | |
" tensor(-2.37, grad_fn=<SelectBackward0>),\n", | |
" tensor(-2.36, grad_fn=<SelectBackward0>))" | |
] | |
}, | |
"execution_count": 132, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sm_pred[0,5],sm_pred[1,0],sm_pred[2,4]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 133, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([-2.20, -2.37, -2.36], grad_fn=<IndexBackward0>)" | |
] | |
}, | |
"execution_count": 133, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sm_pred[[0,1,2], y_train[:3]]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 134, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def nll(input, target): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 135, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(2.30, grad_fn=<NegBackward0>)" | |
] | |
}, | |
"execution_count": 135, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"loss = nll(sm_pred, y_train)\n", | |
"loss" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"Now check our result against PyTorch's implementation."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 136, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"test_close(F.nll_loss(F.log_softmax(pred, -1), y_train), loss, 1e-3)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"In PyTorch, `F.log_softmax` and `F.nll_loss` are combined in one optimized function, `F.cross_entropy`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 137, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"test_close(F.cross_entropy(pred, y_train), loss, 1e-3)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"heading_collapsed": true | |
}, | |
"source": [ | |
"## Basic training loop" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"The training loop repeats the following steps:\n",
"- get the output of the model on a batch of inputs\n", | |
"- compare the output to the labels we have and compute a loss\n", | |
"- calculate the gradients of the loss with respect to every parameter of the model\n", | |
"- update said parameters with those gradients to make them a little bit better" | |
] | |
}, | |
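To make the four steps concrete without PyTorch, here is a hypothetical toy example: a one-feature linear model $y = wx + b$ trained with minibatch SGD on noise-free synthetic data. It makes two assumptions that differ from the notebook: it uses mean squared error instead of cross entropy, and it writes the gradients out by hand instead of using autograd.

```python
def train_linear(xs, ys, lr=0.1, epochs=200, bs=4):
    """Minibatch SGD for y = w*x + b, minimising mean squared error."""
    w, b = 0.0, 0.0
    for _ in range(epochs):
        for i in range(0, len(xs), bs):
            xb, yb = xs[i:i + bs], ys[i:i + bs]
            # 1. output of the model on a batch of inputs
            preds = [w * x + b for x in xb]
            # 2. compare with the labels: residuals of the MSE loss
            errs = [p - y for p, y in zip(preds, yb)]
            # 3. gradients of mean(err**2) with respect to w and b (by hand)
            grad_w = 2 * sum(e * x for e, x in zip(errs, xb)) / len(xb)
            grad_b = 2 * sum(errs) / len(xb)
            # 4. move the parameters a little bit against the gradients
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b

# Toy data on the line y = 2x + 1; each consecutive batch of 4 spans the
# input range, as a shuffled dataset would, which keeps plain SGD well-conditioned
xs = [(i // 4) / 10 + (i % 4) * 0.5 for i in range(20)]
ys = [2 * x + 1 for x in xs]
w, b = train_linear(xs, ys)
print(round(w, 4), round(b, 4))  # 2.0 1.0
```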
{ | |
"cell_type": "code", | |
"execution_count": 138, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"loss_func = F.cross_entropy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 139, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(tensor([-0.09, -0.21, -0.08, 0.10, -0.04, 0.08, -0.04, -0.03, 0.01, 0.06], grad_fn=<SelectBackward0>),\n", | |
" torch.Size([64, 10]))" | |
] | |
}, | |
"execution_count": 139, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bs=64 # batch size\n", | |
"\n", | |
"xb = x_train[0:bs] # a mini-batch from x\n", | |
"preds = model(xb) # predictions\n", | |
"preds[0], preds.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 140, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(2.30, grad_fn=<NllLossBackward0>)" | |
] | |
}, | |
"execution_count": 140, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"yb = y_train[0:bs]\n", | |
"loss_func(preds, yb)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 141, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([3, 9, 3, 8, 5, 9, 3, 9, 3, 9, 5, 3, 9, 9, 3, 9, 9, 5, 8, 7, 9, 5, 3, 8, 9, 5, 9, 5, 5, 9, 3, 5, 9, 7, 5, 7, 9, 9, 3, 9, 3, 5, 3, 8,\n", | |
" 3, 5, 9, 5, 9, 5, 3, 9, 3, 8, 9, 5, 9, 5, 9, 5, 8, 8, 9, 8])" | |
] | |
}, | |
"execution_count": 141, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.argmax(preds, dim=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 142, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def accuracy(out, yb): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 143, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(0.09)" | |
] | |
}, | |
"execution_count": 143, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"accuracy(preds, yb)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 144, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"lr = 0.5 # learning rate\n", | |
"epochs = 3 # how many epochs to train for" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 145, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.3036487102508545 0.09375\n", | |
"0.12374822050333023 0.96875\n", | |
"0.09232541173696518 0.96875\n" | |
] | |
} | |
], | |
"source": [ | |
"for epoch in range(epochs):\n", | |
" for i in range(0, n, bs):\n", | |
" ... " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"heading_collapsed": true | |
}, | |
"source": [ | |
"## Using parameters and optim" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"heading_collapsed": true, | |
"hidden": true | |
}, | |
"source": [ | |
"### Parameters" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"Assign the layers as attributes so that `nn.Module.__setattr__` registers them:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 146, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class Model(nn.Module):\n", | |
" def __init__(self, n_in, nh, n_out):\n", | |
" ... \n", | |
" \n", | |
" def __call__(self, x): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 147, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model = Model(m, nh, 10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 148, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"l1: Linear(in_features=784, out_features=50, bias=True)\n", | |
"l2: Linear(in_features=50, out_features=10, bias=True)\n", | |
"relu: ReLU()\n" | |
] | |
} | |
], | |
"source": [ | |
"for name,l in model.named_children(): print(f\"{name}: {l}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 149, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Model(\n", | |
" (l1): Linear(in_features=784, out_features=50, bias=True)\n", | |
" (l2): Linear(in_features=50, out_features=10, bias=True)\n", | |
" (relu): ReLU()\n", | |
")" | |
] | |
}, | |
"execution_count": 149, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 150, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Linear(in_features=784, out_features=50, bias=True)" | |
] | |
}, | |
"execution_count": 150, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.l1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 151, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def fit():\n", | |
" for epoch in range(epochs):\n", | |
" for i in range(0, n, bs):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 152, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.309433937072754 0.0625\n", | |
"0.20068740844726562 0.953125\n", | |
"0.18196897208690643 0.9375\n" | |
] | |
} | |
], | |
"source": [ | |
"fit()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"Behind the scenes, `nn.Module` overrides `__setattr__` so that submodules assigned as attributes are registered as children of the model, and their parameters are included in `model.parameters()`."
] | |
}, | |
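The registration idea can be sketched without PyTorch as a toy base class whose `__setattr__` records any attribute that looks like a layer. This is not PyTorch's implementation; the real `nn.Module.__setattr__` is more involved and also handles parameters and buffers. `Registry`, `FakeLinear` and `TwoLayer` are hypothetical names, and "looks like a layer" is assumed here to mean "has a `parameters()` method".

```python
class Registry:
    """Toy stand-in for nn.Module: attributes that look like layers get recorded."""
    def __init__(self):
        # Bypass our own __setattr__ so the registry dict itself is not treated as a layer
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Register public attributes that expose a parameters() method
        if not name.startswith("_") and hasattr(value, "parameters"):
            self._modules[name] = value
        super().__setattr__(name, value)

    def parameters(self):
        for module in self._modules.values():
            yield from module.parameters()

class FakeLinear:
    """Hypothetical layer holding a weight matrix and a bias as plain lists."""
    def __init__(self, n_in, n_out):
        self.weight = [[0.0] * n_in for _ in range(n_out)]
        self.bias = [0.0] * n_out

    def parameters(self):
        return [self.weight, self.bias]

class TwoLayer(Registry):
    def __init__(self):
        super().__init__()
        self.l1 = FakeLinear(3, 2)
        self.l2 = FakeLinear(2, 1)
        self.note = "not a layer"  # plain attribute, not registered

net = TwoLayer()
print(list(net._modules))           # ['l1', 'l2']
print(len(list(net.parameters())))  # 4
```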
{ | |
"cell_type": "code", | |
"execution_count": 153, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class DummyModule():\n", | |
" def __init__(self, n_in, nh, n_out):\n", | |
" ... \n", | |
" \n", | |
" def __setattr__(self,k,v):\n", | |
" ...\n", | |
" \n", | |
" def __repr__(self): return f'{self._modules}'\n", | |
" \n", | |
" def parameters(self):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 154, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'l1': Linear(in_features=784, out_features=50, bias=True), 'l2': Linear(in_features=50, out_features=10, bias=True)}" | |
] | |
}, | |
"execution_count": 154, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mdl = DummyModule(m,nh,10)\n", | |
"mdl" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 155, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[torch.Size([50, 784]),\n", | |
" torch.Size([50]),\n", | |
" torch.Size([10, 50]),\n", | |
" torch.Size([10])]" | |
] | |
}, | |
"execution_count": 155, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"[o.shape for o in mdl.parameters()]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"heading_collapsed": true, | |
"hidden": true | |
}, | |
"source": [ | |
"### Registering modules" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"We can still use the original `layers` list approach, but then we have to register each module explicitly (for example with `add_module`)."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 156, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"layers = [nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 157, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class Model(nn.Module):\n", | |
" def __init__(self, layers):\n", | |
" ...\n", | |
" \n", | |
" def __call__(self, x):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 158, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model = Model(layers)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 159, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Model(\n", | |
" (layer_0): Linear(in_features=784, out_features=50, bias=True)\n", | |
" (layer_1): ReLU()\n", | |
" (layer_2): Linear(in_features=50, out_features=10, bias=True)\n", | |
")" | |
] | |
}, | |
"execution_count": 159, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"heading_collapsed": true, | |
"hidden": true | |
}, | |
"source": [ | |
"### nn.ModuleList" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"`nn.ModuleList` does this for us." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 160, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class SequentialModel(nn.Module):\n", | |
" def __init__(self, layers):\n", | |
" ...\n", | |
" \n", | |
" def __call__(self, x):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 161, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model = SequentialModel(layers)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 162, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"SequentialModel(\n", | |
" (layers): ModuleList(\n", | |
" (0): Linear(in_features=784, out_features=50, bias=True)\n", | |
" (1): ReLU()\n", | |
" (2): Linear(in_features=50, out_features=10, bias=True)\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 162, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 163, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.3222596645355225 0.03125\n", | |
"0.14386768639087677 0.96875\n", | |
"0.08797654509544373 0.96875\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(tensor(0.02, grad_fn=<NllLossBackward0>), tensor(1.))" | |
] | |
}, | |
"execution_count": 163, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"fit()\n", | |
"loss_func(model(xb), yb), accuracy(model(xb), yb)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"heading_collapsed": true, | |
"hidden": true | |
}, | |
"source": [ | |
"### nn.Sequential" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"`nn.Sequential` is a convenience class that does the same as `SequentialModel` above: it registers the layers and calls them in order."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 164, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 165, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.3087124824523926 0.09375\n", | |
"0.20316630601882935 0.90625\n", | |
"0.20330585539340973 0.921875\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(tensor(0.01, grad_fn=<NllLossBackward0>), tensor(1.))" | |
] | |
}, | |
"execution_count": 165, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"fit()\n", | |
"loss_func(model(xb), yb), accuracy(model(xb), yb)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 166, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Sequential(\n", | |
" (0): Linear(in_features=784, out_features=50, bias=True)\n", | |
" (1): ReLU()\n", | |
" (2): Linear(in_features=50, out_features=10, bias=True)\n", | |
")" | |
] | |
}, | |
"execution_count": 166, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"heading_collapsed": true, | |
"hidden": true | |
}, | |
"source": [ | |
"### optim" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"Let's replace our previous manually coded optimization step:\n", | |
"\n", | |
"```python\n", | |
"with torch.no_grad():\n", | |
" for p in model.parameters(): p -= p.grad * lr\n", | |
" model.zero_grad()\n", | |
"```\n", | |
"\n", | |
"and instead use just:\n", | |
"\n", | |
"```python\n", | |
"opt.step()\n", | |
"opt.zero_grad()\n", | |
"```" | |
] | |
}, | |
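The `step`/`zero_grad` interface can be sketched with plain Python scalars standing in for tensors. `Param` and `SimpleSGD` are hypothetical names, and the gradient is assigned by hand where the notebook would call `loss.backward()`.

```python
class Param:
    """Hypothetical scalar parameter holding a value and its gradient."""
    def __init__(self, data):
        self.data = data
        self.grad = 0.0

class SimpleSGD:
    """Plain SGD: p.data -= lr * p.grad, mirroring the manual step above."""
    def __init__(self, params, lr=0.5):
        self.params, self.lr = list(params), lr

    def step(self):
        for p in self.params:
            p.data -= p.grad * self.lr

    def zero_grad(self):
        # Reset gradients; PyTorch accumulates them across backward calls
        for p in self.params:
            p.grad = 0.0

# Minimise f(w) = (w - 3)**2, whose gradient is 2 * (w - 3)
w = Param(0.0)
opt = SimpleSGD([w], lr=0.1)
for _ in range(100):
    w.grad = 2 * (w.data - 3)  # stand-in for loss.backward()
    opt.step()
    opt.zero_grad()
print(round(w.data, 4))  # 3.0
```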
{ | |
"cell_type": "code", | |
"execution_count": 167, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class Optimizer():\n", | |
" def __init__(self, params, lr=0.5): self.params,self.lr=list(params),lr\n", | |
" \n", | |
" def step(self):\n", | |
" ...\n", | |
"\n", | |
" def zero_grad(self):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 168, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 169, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"opt = Optimizer(model.parameters())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 170, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.30017352104187 0.0625\n", | |
"0.13068024814128876 0.96875\n", | |
"0.11748667806386948 0.96875\n" | |
] | |
} | |
], | |
"source": [ | |
"for epoch in range(epochs):\n", | |
" for i in range(0, n, bs):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"PyTorch already provides this functionality in `optim.SGD`, which also handles extras such as momentum (we'll look at those later)."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 171, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from torch import optim" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 172, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def get_model():\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 173, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(2.30, grad_fn=<NllLossBackward0>)" | |
] | |
}, | |
"execution_count": 173, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model,opt = get_model()\n", | |
"loss_func(model(xb), yb)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 174, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.312685012817383 0.078125\n", | |
"0.21422098577022552 0.90625\n", | |
"0.17829009890556335 0.921875\n" | |
] | |
} | |
], | |
"source": [ | |
"for epoch in range(epochs):\n", | |
" for i in range(0, n, bs):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Dataset and DataLoader" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"heading_collapsed": true | |
}, | |
"source": [ | |
"### Dataset" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"hidden": true | |
}, | |
"source": [ | |
"It's clunky to iterate through minibatches of x and y values separately:\n", | |
"\n", | |
"```python\n", | |
" xb = x_train[s]\n", | |
" yb = y_train[s]\n", | |
"```\n", | |
"\n", | |
"Instead, let's do these two steps together, by introducing a `Dataset` class:\n", | |
"\n", | |
"```python\n", | |
" xb,yb = train_ds[s]\n", | |
"```" | |
] | |
}, | |
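A minimal pure-Python sketch of such a dataset, assuming the inputs and labels are sequences that support integer indexing and slicing (`ListDataset` is a hypothetical name; the notebook's version wraps tensors):

```python
class ListDataset:
    """Hypothetical pure-Python Dataset pairing inputs with labels."""
    def __init__(self, x, y):
        assert len(x) == len(y), "inputs and labels must align"
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        # Works for an int or a slice: returns the matching (inputs, labels)
        return self.x[i], self.y[i]

ds = ListDataset([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], [1, 0, 1])
xb, yb = ds[0:2]
print(len(ds), xb, yb)  # 3 [[0.1, 0.2], [0.3, 0.4]] [1, 0]
```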
{ | |
"cell_type": "code", | |
"execution_count": 175, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class Dataset():\n", | |
" def __init__(self, x, y): ...\n", | |
" def __len__(self): return ...\n", | |
" def __getitem__(self, i): return ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 176, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)\n", | |
"assert len(train_ds)==len(x_train)\n", | |
"assert len(valid_ds)==len(x_valid)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 177, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(tensor([[0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.]]),\n", | |
" tensor([5, 0, 4, 1, 9]))" | |
] | |
}, | |
"execution_count": 177, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xb,yb = train_ds[0:5]\n", | |
"assert xb.shape==(5,28*28)\n", | |
"assert yb.shape==(5,)\n", | |
"xb,yb" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 178, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model,opt = get_model()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 179, | |
"metadata": { | |
"hidden": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.2979648113250732 0.09375\n", | |
"0.21198976039886475 0.953125\n", | |
"0.17290501296520233 0.921875\n" | |
] | |
} | |
], | |
"source": [ | |
"for epoch in range(epochs):\n", | |
" for i in range(0, n, bs):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### DataLoader" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Previously, our loop iterated over batches (xb, yb) like this:\n", | |
"\n", | |
"```python\n", | |
"for i in range(0, n, bs):\n", | |
" xb,yb = train_ds[i:min(n,i+bs)]\n", | |
" ...\n", | |
"```\n", | |
"\n", | |
"Let's make our loop much cleaner, using a data loader:\n", | |
"\n", | |
"```python\n", | |
"for xb,yb in train_dl:\n", | |
" ...\n", | |
"```" | |
] | |
}, | |
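A minimal sketch of such a loader in plain Python follows. It assumes only that the dataset supports `len()` and slicing, like the `Dataset` class above. `SimpleDataLoader` and `PairDataset` are hypothetical names, and unlike PyTorch's `DataLoader`, this sketch neither shuffles nor collates.

```python
class SimpleDataLoader:
    """Hypothetical loader yielding consecutive (xb, yb) minibatches."""
    def __init__(self, ds, bs):
        self.ds, self.bs = ds, bs

    def __iter__(self):
        # A generator, so each new `for` loop starts again from the first batch
        for i in range(0, len(self.ds), self.bs):
            yield self.ds[i:i + self.bs]  # the last batch may be smaller

class PairDataset:
    """Tiny dataset returning (inputs, labels) for an index or a slice."""
    def __init__(self, x, y): self.x, self.y = x, y
    def __len__(self): return len(self.x)
    def __getitem__(self, i): return self.x[i], self.y[i]

dl = SimpleDataLoader(PairDataset(list(range(10)), list("abcdefghij")), bs=4)
for xb, yb in dl:
    print(xb, yb)
# [0, 1, 2, 3] ['a', 'b', 'c', 'd']
# [4, 5, 6, 7] ['e', 'f', 'g', 'h']
# [8, 9] ['i', 'j']
```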
{ | |
"cell_type": "code", | |
"execution_count": 180, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class DataLoader():\n", | |
" def __init__(self, ds, bs): self.ds,self.bs = ds,bs\n", | |
" def __iter__(self): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 181, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_dl = DataLoader(train_ds, bs)\n", | |
"valid_dl = DataLoader(valid_ds, bs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 182, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([64, 784])" | |
] | |
}, | |
"execution_count": 182, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xb,yb = next(iter(valid_dl))\n", | |
"xb.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 183, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([3, 8, 6, 9, 6, 4, 5, 3, 8, 4, 5, 2, 3, 8, 4, 8, 1, 5, 0, 5, 9, 7, 4, 1, 0, 3, 0, 6, 2, 9, 9, 4, 1, 3, 6, 8, 0, 7, 7, 6, 8, 9, 0, 3,\n", | |
" 8, 3, 7, 7, 8, 4, 4, 1, 2, 9, 8, 1, 1, 0, 6, 6, 5, 0, 1, 1])" | |
] | |
}, | |
"execution_count": 183, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"yb" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 184, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(3)" | |
] | |
}, | |
"execution_count": 184, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(xb[0].view(28,28))\n",
"yb[0]"
]
},
{
"cell_type": "code",
"execution_count": 185,
"metadata": {},
"outputs": [],
"source": [
"model,opt = get_model()"
]
},
{
"cell_type": "code",
"execution_count": 186,
"metadata": {},
"outputs": [],
"source": [
"def fit():\n",
"    for epoch in range(epochs):\n",
"        for xb,yb in train_dl:\n",
"            ... "
]
},
{
"cell_type": "code",
"execution_count": 187,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.05, grad_fn=<NllLossBackward0>), tensor(1.))"
]
},
"execution_count": 187,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fit()\n",
"loss_func(model(xb), yb), accuracy(model(xb), yb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Random sampling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We want our training set to be in a random order, and that order should be different each time we iterate through it (i.e. each epoch). The validation set, however, shouldn't be randomized."
]
},
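{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of that idea (the `index_order` helper is hypothetical and not part of this notebook): yield the indices `0..n-1` in order when `shuffle=False`, and a freshly shuffled permutation for every new pass when `shuffle=True`, so each epoch sees a different order."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"def index_order(n, shuffle=False):\n",
"    # Hypothetical helper: yields every index in 0..n-1 exactly once\n",
"    idxs = list(range(n))\n",
"    if shuffle: random.shuffle(idxs)  # reshuffled for every new pass over the data\n",
"    yield from idxs\n",
"\n",
"print(list(index_order(5)))                # [0, 1, 2, 3, 4]\n",
"print(list(index_order(5, shuffle=True)))  # some random permutation of 0..4"
]
},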
{
"cell_type": "code",
"execution_count": 188,
"metadata": {},
"outputs": [],
"source": [
"import random"
]
},
{
"cell_type": "code",
"execution_count": 189,
"metadata": {},
"outputs": [],
"source": [
"class Sampler():\n",
"    def __init__(self, ds, shuffle=False): self.n,self.shuffle = len(ds),shuffle\n",
"    def __iter__(self):\n",
"        ... "
]
},
{
"cell_type": "code",
"execution_count": 190,
"metadata": {},
"outputs": [],
"source": [
"from itertools import islice"
]
},
{
"cell_type": "code",
"execution_count": 191,
"metadata": {},
"outputs": [],
"source": [
"ss = Sampler(train_ds)"
]
},
{
"cell_type": "code",
"execution_count": 192,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"1\n",
"2\n",
"3\n",
"4\n"
]
}
],
"source": [
"it = iter(ss)\n",
"for o in range(5): print(next(it))"
]
},
{
"cell_type": "code",
"execution_count": 193,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0, 1, 2, 3, 4]"
]
},
"execution_count": 193,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(islice(ss, 5))"
]
},
{
"cell_type": "code",
"execution_count": 194,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[2468, 34785, 22293, 22313, 36680]"
]
},
"execution_count": 194,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ss = Sampler(train_ds, shuffle=True)\n",
"list(islice(ss, 5))"
]
},
{
"cell_type": "code",
"execution_count": 195,
"metadata": {},
"outputs": [],
"source": [
"import fastcore.all as fc"
]
},
{
"cell_type": "code",
"execution_count": 196,
"metadata": {},
"outputs": [],
"source": [
"class BatchSampler():\n",
"    def __init__(self, sampler, bs, drop_last=False): fc.store_attr()\n",
"    def __iter__(self): ... "
]
},
{
"cell_type": "code",
"execution_count": 197,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[3338, 3713, 47999, 33349],\n",
" [1382, 28497, 19584, 35095],\n",
" [33760, 20524, 1959, 7968],\n",
" [40952, 25061, 32207, 20443],\n",
" [11419, 11479, 45286, 40070]]"
]
},
"execution_count": 197,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batchs = BatchSampler(ss, 4)\n",
"list(islice(batchs, 5))"
]
},
{
"cell_type": "code",
"execution_count": 198,
"metadata": {},
"outputs": [],
"source": [
"def collate(b):\n",
"    ... "
]
},
{
"cell_type": "code",
"execution_count": 199,
"metadata": {},
"outputs": [],
"source": [
"class DataLoader():\n",
"    def __init__(self, ds, batchs, collate_fn=collate): fc.store_attr()\n",
"    def __iter__(self): ... "
]
},
{
"cell_type": "code",
"execution_count": 200,
"metadata": {},
"outputs": [],
"source": [
"train_samp = BatchSampler(Sampler(train_ds, shuffle=True ), bs)\n",
"valid_samp = BatchSampler(Sampler(valid_ds, shuffle=False), bs)"
]
},
{
"cell_type": "code",
"execution_count": 201,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, batchs=train_samp, collate_fn=collate)\n",
"valid_dl = DataLoader(valid_ds, batchs=valid_samp, collate_fn=collate)"
]
},
{
"cell_type": "code",
"execution_count": 202,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(3)"
]
},
"execution_count": 202,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"xb,yb = next(iter(valid_dl))\n",
"plt.imshow(xb[0].view(28,28))\n",
"yb[0]"
]
},
{
"cell_type": "code",
"execution_count": 203,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 784]), torch.Size([64]))"
]
},
"execution_count": 203,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb.shape,yb.shape"
]
},
{
"cell_type": "code",
"execution_count": 204,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.04, grad_fn=<NllLossBackward0>), tensor(1.))"
]
},
"execution_count": 204,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model,opt = get_model()\n",
"fit()\n",
"\n",
"loss_func(model(xb), yb), accuracy(model(xb), yb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multiprocessing DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 205,
"metadata": {},
"outputs": [],
"source": [
"import torch.multiprocessing as mp\n",
"from fastcore.basics import store_attr"
]
},
{
"cell_type": "code",
"execution_count": 206,
"metadata": {},
"outputs": [],
"source": [
"class DataLoader():\n",
"    def __init__(self, ds, batchs, n_workers=1, collate_fn=collate): fc.store_attr()\n",
"    def __iter__(self):\n",
"        ... "
]
},
{
"cell_type": "code",
"execution_count": 207,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, batchs=train_samp, collate_fn=collate, n_workers=2)\n",
"it = iter(train_dl)"
]
},
{
"cell_type": "code",
"execution_count": 208,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 784]), torch.Size([64]))"
]
},
"execution_count": 208,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb,yb = next(it)\n",
"xb.shape,yb.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### PyTorch DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 209,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, BatchSampler"
]
},
{
"cell_type": "code",
"execution_count": 210,
"metadata": {},
"outputs": [],
"source": [
"train_samp = BatchSampler(RandomSampler(train_ds), bs, drop_last=False)\n",
"valid_samp = BatchSampler(SequentialSampler(valid_ds), bs, drop_last=False)"
]
},
{
"cell_type": "code",
"execution_count": 211,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, batch_sampler=train_samp, collate_fn=collate)\n",
"valid_dl = DataLoader(valid_ds, batch_sampler=valid_samp, collate_fn=collate)"
]
},
{
"cell_type": "code",
"execution_count": 212,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.09, grad_fn=<NllLossBackward0>), tensor(0.98))"
]
},
"execution_count": 212,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model,opt = get_model()\n",
"fit()\n",
"loss_func(model(xb), yb), accuracy(model(xb), yb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch can auto-generate the BatchSampler for us if we pass it a batch size and a sampler:"
]
},
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, bs, sampler=RandomSampler(train_ds), collate_fn=collate)\n",
"valid_dl = DataLoader(valid_ds, bs, sampler=SequentialSampler(valid_ds), collate_fn=collate)"
]
},
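{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check on a toy dataset (a plain `range`, which works here because it supports `len` and integer indexing): when we pass a batch size together with a sampler, the `DataLoader` wraps that sampler in a `BatchSampler` and stores it in its `batch_sampler` attribute."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader, SequentialSampler, BatchSampler\n",
"\n",
"toy = range(10)  # toy map-style dataset: supports len() and toy[i]\n",
"toy_dl = DataLoader(toy, 3, sampler=SequentialSampler(toy))\n",
"print(isinstance(toy_dl.batch_sampler, BatchSampler))  # True\n",
"print(list(toy_dl.batch_sampler))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]"
]
},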
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch can also create the `SequentialSampler`/`RandomSampler` for us, based on the `shuffle` argument:"
]
},
{
"cell_type": "code",
"execution_count": 214,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, bs, shuffle=True, drop_last=True, num_workers=2)\n",
"valid_dl = DataLoader(valid_ds, bs, shuffle=False, num_workers=2)"
]
},
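{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check on a toy `range` dataset of which sampler the `DataLoader` creates for each value of `shuffle`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
"\n",
"toy = range(10)  # toy map-style dataset\n",
"shuf_dl = DataLoader(toy, 3, shuffle=True)\n",
"seq_dl  = DataLoader(toy, 3, shuffle=False)\n",
"print(isinstance(shuf_dl.sampler, RandomSampler))     # True\n",
"print(isinstance(seq_dl.sampler, SequentialSampler))  # True"
]
},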
{
"cell_type": "code",
"execution_count": 215,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.05, grad_fn=<NllLossBackward0>), tensor(0.98))"
]
},
"execution_count": 215,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model,opt = get_model()\n",
"fit()\n",
"\n",
"loss_func(model(xb), yb), accuracy(model(xb), yb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our dataset already knows how to return a whole batch at once when it is indexed with a list of indices:"
]
},
{
"cell_type": "code",
"execution_count": 216,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]]),\n",
" tensor([9, 1, 3]))"
]
},
"execution_count": 216,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_ds[[4,6,7]]"
]
},
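{
"cell_type": "markdown",
"metadata": {},
"source": [
"This works because indexing a tensor with a list of integers selects all of those rows at once. Below is a minimal sketch that uses a hypothetical `ToyDS` class. We assume it mirrors the notebook's `Dataset` by returning `(x[i], y[i])`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"class ToyDS:\n",
"    # Hypothetical stand-in for the notebook's Dataset\n",
"    def __init__(self, x, y): self.x,self.y = x,y\n",
"    def __len__(self): return len(self.x)\n",
"    def __getitem__(self, i): return self.x[i],self.y[i]\n",
"\n",
"toy_ds = ToyDS(torch.arange(12.).view(6, 2), torch.tensor([5, 0, 4, 1, 9, 2]))\n",
"xs,ys = toy_ds[[4, 1, 3]]  # a list of indices selects several rows at once\n",
"print(xs.shape, ys)        # torch.Size([3, 2]) tensor([9, 0, 1])"
]
},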
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...that means that we can skip the `batch_sampler` and `collate_fn` entirely and pass our `BatchSampler` as the `sampler`. Note that the `DataLoader` still applies its default `batch_size=1` on top of this. That is where the extra leading dimension of size 1 in the shapes below comes from. Passing `batch_size=None` disables this automatic batching."
]
},
{
"cell_type": "code",
"execution_count": 217,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, sampler=train_samp)\n",
"valid_dl = DataLoader(valid_ds, sampler=valid_samp)"
]
},
{
"cell_type": "code",
"execution_count": 218,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 64, 784]), torch.Size([1, 64]))"
]
},
"execution_count": 218,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb,yb = next(iter(train_dl))\n",
"xb.shape,yb.shape"
]
},
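{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a small sketch on toy data showing what `batch_size` changes when a `BatchSampler` is passed as `sampler`. The `ToyDS` class is hypothetical. With the default `batch_size=1`, each list of indices gains an extra leading batch dimension. With `batch_size=None`, the dataset's own output is returned as-is."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader, BatchSampler, SequentialSampler\n",
"\n",
"class ToyDS:\n",
"    # Hypothetical dataset whose __getitem__ also accepts a list of indices\n",
"    def __init__(self, x, y): self.x,self.y = x,y\n",
"    def __len__(self): return len(self.x)\n",
"    def __getitem__(self, i): return self.x[i],self.y[i]\n",
"\n",
"toy_ds = ToyDS(torch.randn(10, 3), torch.arange(10))\n",
"toy_samp = BatchSampler(SequentialSampler(toy_ds), 4, drop_last=False)\n",
"\n",
"xb1,yb1 = next(iter(DataLoader(toy_ds, sampler=toy_samp)))  # default batch_size=1\n",
"print(xb1.shape)  # torch.Size([1, 4, 3])\n",
"xb2,yb2 = next(iter(DataLoader(toy_ds, sampler=toy_samp, batch_size=None)))\n",
"print(xb2.shape)  # torch.Size([4, 3])"
]
},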
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should **always** have a [validation set](http://www.fast.ai/2017/11/13/validation-sets/), so that you can tell whether you are overfitting.\n",
"\n",
"We will calculate and print the validation loss and accuracy at the end of each epoch.\n",
"\n",
"(Note that we always call `model.train()` before training and `model.eval()` before inference. Layers such as `nn.BatchNorm2d` and `nn.Dropout` behave differently in these two phases, and they rely on this mode flag to choose the appropriate behaviour.)"
]
},
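{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate why the mode matters, here is a small example with `nn.Dropout`, chosen only as an example layer. In training mode it randomly zeroes elements and scales the remaining ones by `1/(1-p)`. In evaluation mode it returns its input unchanged."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"drop = nn.Dropout(p=0.5)\n",
"inp = torch.ones(8)\n",
"\n",
"drop.train()                        # training mode: random zeroing plus rescaling\n",
"print(drop(inp))                    # each entry is either 0. or 2. (= 1/(1-0.5))\n",
"drop.eval()                         # evaluation mode: dropout becomes the identity\n",
"print(torch.equal(drop(inp), inp))  # True"
]
},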
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
"def fit(epochs, model, loss_func, opt, train_dl, valid_dl):\n",
"    for epoch in range(epochs):\n",
"        ...\n",
"        print(epoch, tot_loss/count, tot_acc/count)\n",
"    return tot_loss/count, tot_acc/count"
]
},
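{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last batch can be smaller than `bs`, so averaging the per-batch means directly would give it too much weight. Below is a minimal sketch of size-weighted accumulation. It assumes that `count` is the number of samples seen, and the per-batch numbers are made up purely for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Made-up per-batch results: (mean loss over the batch, number of samples in it)\n",
"batches = [(0.30, 64), (0.20, 64), (0.50, 16)]\n",
"\n",
"tot_loss,count = 0.,0\n",
"for batch_loss,n in batches:\n",
"    tot_loss += batch_loss*n  # turn the batch mean back into a sum\n",
"    count += n\n",
"\n",
"print(tot_loss/count)  # per-sample average, about 0.278\n",
"print(sum(bl for bl,_ in batches)/len(batches))  # naive mean of batch means, about 0.333"
]
},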
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"def get_dls(train_ds, valid_ds, bs, **kwargs):\n",
"    return ... "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, our whole process of obtaining the data loaders and fitting the model can be run in 3 lines of code:"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 0.2599395067691803 0.9152\n",
"1 0.1375726773455739 0.9599\n",
"2 0.10632649689242243 0.9696\n",
"3 0.11931819728948176 0.9643\n",
"4 0.15087997979782522 0.9549\n"
]
}
],
"source": [
"train_dl,valid_dl = get_dls(train_ds, valid_ds, bs)\n",
"model,opt = get_model()\n",
"loss,acc = fit(5, model, loss_func, opt, train_dl, valid_dl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}