{
"cells": [
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"from torch import tensor,nn\n",
"import torch.nn.functional as F\n",
"from fastcore.test import test_close\n",
"\n",
"torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)\n",
"torch.manual_seed(1)\n",
"mpl.rcParams['image.cmap'] = 'gray'\n",
"\n",
"path_data = Path('data')\n",
"path_gz = path_data/'mnist.pkl.gz'\n",
"with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')\n",
"x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initial setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true
},
"source": [
"### Data"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"n,m = x_train.shape\n",
"c = y_train.max()+1\n",
"nh = 50"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, n_in, nh, n_out):\n",
" super().__init__()\n",
" self.layers = [nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out)]\n",
" \n",
" def __call__(self, x):\n",
" for l in self.layers: x = l(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([50000, 10])"
]
},
"execution_count": 124,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = Model(m, nh, 10)\n",
"pred = model(x_train)\n",
"pred.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true
},
"source": [
"### Cross entropy loss"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"First, we will need to compute the softmax of our activations. This is defined by:\n",
"\n",
"$$\\hbox{softmax(x)}_{i} = \\frac{e^{x_{i}}}{e^{x_{0}} + e^{x_{1}} + \\cdots + e^{x_{n-1}}}$$\n",
"\n",
"or more concisely:\n",
"\n",
"$$\\hbox{softmax(x)}_{i} = \\frac{e^{x_{i}}}{\\sum_{0 \\leq j \\leq n-1} e^{x_{j}}}$$ \n",
"\n",
"In practice, we will need the log of the softmax when we calculate the loss."
]
},
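{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"As a quick sanity check, we can define `softmax` directly from the formula and confirm that each row sums to 1 (a throwaway helper just for this check):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"# ad-hoc helper, straight from the definition above\n",
"def softmax(x): return x.exp() / x.exp().sum(-1, keepdim=True)\n",
"softmax(pred).sum(-1)  # every entry should be ~1"
]
},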
{
"cell_type": "code",
"execution_count": 125,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"def log_softmax(x): ..."
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-2.37, -2.49, -2.36, ..., -2.31, -2.28, -2.22],\n",
" [-2.37, -2.44, -2.44, ..., -2.27, -2.26, -2.16],\n",
" [-2.48, -2.33, -2.28, ..., -2.30, -2.30, -2.27],\n",
" ...,\n",
" [-2.33, -2.52, -2.34, ..., -2.31, -2.21, -2.16],\n",
" [-2.38, -2.38, -2.33, ..., -2.29, -2.26, -2.17],\n",
" [-2.33, -2.55, -2.36, ..., -2.29, -2.27, -2.16]], grad_fn=<LogBackward0>)"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"log_softmax(pred)"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"Note that the formula \n",
"\n",
"$$\\log \\left ( \\frac{a}{b} \\right ) = \\log(a) - \\log(b)$$ \n",
"\n",
"gives a simplification when we compute the log softmax, which was previously defined as `(x.exp()/(x.exp().sum(-1,keepdim=True))).log()`"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"def log_softmax(x): ..."
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"Then, there is a way to compute the log of the sum of exponentials in a more stable way, called the [LogSumExp trick](https://en.wikipedia.org/wiki/LogSumExp). The idea is to use the following formula:\n",
"\n",
"$$\\log \\left ( \\sum_{j=1}^{n} e^{x_{j}} \\right ) = \\log \\left ( e^{a} \\sum_{j=1}^{n} e^{x_{j}-a} \\right ) = a + \\log \\left ( \\sum_{j=1}^{n} e^{x_{j}-a} \\right )$$\n",
"\n",
"where a is the maximum of the $x_{j}$."
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"def logsumexp(x):\n",
" ... "
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"This way, we will avoid an overflow when taking the exponential of a big activation. In PyTorch, this is already implemented for us. "
]
},
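{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"To see why the trick matters, try a large activation (a small throwaway check): exponentiating directly overflows to `inf` in float32, while `torch.logsumexp` stays finite."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"big = tensor([[100., 100.]])\n",
"big.exp().sum(-1).log(), big.logsumexp(-1)  # naive: inf; stable: ~100.69"
]
},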
{
"cell_type": "code",
"execution_count": 129,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"def log_softmax(x): ... # numerical stability achieved using logsumexp"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-2.37, -2.49, -2.36, ..., -2.31, -2.28, -2.22],\n",
" [-2.37, -2.44, -2.44, ..., -2.27, -2.26, -2.16],\n",
" [-2.48, -2.33, -2.28, ..., -2.30, -2.30, -2.27],\n",
" ...,\n",
" [-2.33, -2.52, -2.34, ..., -2.31, -2.21, -2.16],\n",
" [-2.38, -2.38, -2.33, ..., -2.29, -2.26, -2.17],\n",
" [-2.33, -2.55, -2.36, ..., -2.29, -2.27, -2.16]], grad_fn=<SubBackward0>)"
]
},
"execution_count": 130,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_close(logsumexp(pred), pred.logsumexp(-1))\n",
"sm_pred = log_softmax(pred)\n",
"sm_pred"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"The cross entropy loss for some target $x$ and some prediction $p(x)$ is given by:\n",
"\n",
"$$ -\\sum x\\, \\log p(x) $$\n",
"\n",
"But since our $x$s are 1-hot encoded, this can be rewritten as $-\\log(p_{i})$ where i is the index of the desired target.\n",
"\n",
"This can be done using numpy-style [integer array indexing](https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#integer-array-indexing). Note that PyTorch supports all the tricks in the advanced indexing methods discussed in that link."
]
},
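{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"For example, on a small toy tensor, a list of row indices paired with a list of column indices picks out one element per row:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"t = tensor([[10, 11], [20, 21], [30, 31]])\n",
"t[[0, 1, 2], [1, 0, 1]]  # -> tensor([11, 20, 31])"
]
},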
{
"cell_type": "code",
"execution_count": 131,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([5, 0, 4])"
]
},
"execution_count": 131,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train[:3]"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(-2.20, grad_fn=<SelectBackward0>),\n",
" tensor(-2.37, grad_fn=<SelectBackward0>),\n",
" tensor(-2.36, grad_fn=<SelectBackward0>))"
]
},
"execution_count": 132,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm_pred[0,5],sm_pred[1,0],sm_pred[2,4]"
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-2.20, -2.37, -2.36], grad_fn=<IndexBackward0>)"
]
},
"execution_count": 133,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm_pred[[0,1,2], y_train[:3]]"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"def nll(input, target): ..."
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor(2.30, grad_fn=<NegBackward0>)"
]
},
"execution_count": 135,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss = nll(sm_pred, y_train)\n",
"loss"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"Then use PyTorch's implementation."
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"test_close(F.nll_loss(F.log_softmax(pred, -1), y_train), loss, 1e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"In PyTorch, `F.log_softmax` and `F.nll_loss` are combined in one optimized function, `F.cross_entropy`."
]
},
{
"cell_type": "code",
"execution_count": 137,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"test_close(F.cross_entropy(pred, y_train), loss, 1e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true
},
"source": [
"## Basic training loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"Basically the training loop repeats over the following steps:\n",
"- get the output of the model on a batch of inputs\n",
"- compare the output to the labels we have and compute a loss\n",
"- calculate the gradients of the loss with respect to every parameter of the model\n",
"- update said parameters with those gradients to make them a little bit better"
]
},
{
"cell_type": "code",
"execution_count": 138,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"loss_func = F.cross_entropy"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([-0.09, -0.21, -0.08, 0.10, -0.04, 0.08, -0.04, -0.03, 0.01, 0.06], grad_fn=<SelectBackward0>),\n",
" torch.Size([64, 10]))"
]
},
"execution_count": 139,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bs=64 # batch size\n",
"\n",
"xb = x_train[0:bs] # a mini-batch from x\n",
"preds = model(xb) # predictions\n",
"preds[0], preds.shape"
]
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor(2.30, grad_fn=<NllLossBackward0>)"
]
},
"execution_count": 140,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yb = y_train[0:bs]\n",
"loss_func(preds, yb)"
]
},
{
"cell_type": "code",
"execution_count": 141,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([3, 9, 3, 8, 5, 9, 3, 9, 3, 9, 5, 3, 9, 9, 3, 9, 9, 5, 8, 7, 9, 5, 3, 8, 9, 5, 9, 5, 5, 9, 3, 5, 9, 7, 5, 7, 9, 9, 3, 9, 3, 5, 3, 8,\n",
" 3, 5, 9, 5, 9, 5, 3, 9, 3, 8, 9, 5, 9, 5, 9, 5, 8, 8, 9, 8])"
]
},
"execution_count": 141,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.argmax(preds, dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 142,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"def accuracy(out, yb): ..."
]
},
{
"cell_type": "code",
"execution_count": 143,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.09)"
]
},
"execution_count": 143,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy(preds, yb)"
]
},
{
"cell_type": "code",
"execution_count": 144,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"lr = 0.5 # learning rate\n",
"epochs = 3 # how many epochs to train for"
]
},
{
"cell_type": "code",
"execution_count": 145,
"metadata": {
"hidden": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.3036487102508545 0.09375\n",
"0.12374822050333023 0.96875\n",
"0.09232541173696518 0.96875\n"
]
}
],
"source": [
"for epoch in range(epochs):\n",
" for i in range(0, n, bs):\n",
" ... "
]
},
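{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"One possible body for the loop above (a sketch: at this point the model keeps its layers in a plain list, so `model.parameters()` is empty and we update each `Linear` layer's `weight` and `bias` by hand; the per-epoch reporting line is our assumption, matching the printed loss/accuracy pairs):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"for epoch in range(epochs):\n",
"    for i in range(0, n, bs):\n",
"        s = slice(i, min(n, i+bs))\n",
"        xb,yb = x_train[s],y_train[s]\n",
"        preds = model(xb)\n",
"        loss = loss_func(preds, yb)\n",
"        loss.backward()\n",
"        with torch.no_grad():\n",
"            for l in model.layers:\n",
"                if hasattr(l, 'weight'):\n",
"                    l.weight -= l.weight.grad * lr\n",
"                    l.bias   -= l.bias.grad * lr\n",
"                    l.weight.grad.zero_()\n",
"                    l.bias.grad.zero_()\n",
"    print(loss.item(), accuracy(preds, yb).item())  # last-batch loss and accuracy"
]
},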
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true
},
"source": [
"## Using parameters and optim"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true
},
"source": [
"### Parameters"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"Use `nn.Module.__setattr__`:"
]
},
{
"cell_type": "code",
"execution_count": 146,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, n_in, nh, n_out):\n",
" ... \n",
" \n",
" def __call__(self, x): ..."
]
},
{
"cell_type": "code",
"execution_count": 147,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"model = Model(m, nh, 10)"
]
},
{
"cell_type": "code",
"execution_count": 148,
"metadata": {
"hidden": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"l1: Linear(in_features=784, out_features=50, bias=True)\n",
"l2: Linear(in_features=50, out_features=10, bias=True)\n",
"relu: ReLU()\n"
]
}
],
"source": [
"for name,l in model.named_children(): print(f\"{name}: {l}\")"
]
},
{
"cell_type": "code",
"execution_count": 149,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"Model(\n",
" (l1): Linear(in_features=784, out_features=50, bias=True)\n",
" (l2): Linear(in_features=50, out_features=10, bias=True)\n",
" (relu): ReLU()\n",
")"
]
},
"execution_count": 149,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model"
]
},
{
"cell_type": "code",
"execution_count": 150,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"Linear(in_features=784, out_features=50, bias=True)"
]
},
"execution_count": 150,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.l1"
]
},
{
"cell_type": "code",
"execution_count": 151,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"def fit():\n",
" for epoch in range(epochs):\n",
" for i in range(0, n, bs):\n",
" ..."
]
},
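{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"One possible completion of `fit` (a sketch: with the submodules registered we can step through `model.parameters()` under `torch.no_grad()`, exactly as quoted in the *optim* section below; the reporting line is our assumption):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"def fit():\n",
"    for epoch in range(epochs):\n",
"        for i in range(0, n, bs):\n",
"            s = slice(i, min(n, i+bs))\n",
"            xb,yb = x_train[s],y_train[s]\n",
"            loss = loss_func(model(xb), yb)\n",
"            loss.backward()\n",
"            with torch.no_grad():\n",
"                for p in model.parameters(): p -= p.grad * lr\n",
"                model.zero_grad()\n",
"        print(loss.item(), accuracy(model(xb), yb).item())  # last-batch loss and accuracy"
]
},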
{
"cell_type": "code",
"execution_count": 152,
"metadata": {
"hidden": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.309433937072754 0.0625\n",
"0.20068740844726562 0.953125\n",
"0.18196897208690643 0.9375\n"
]
}
],
"source": [
"fit()"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"Behind the scenes, PyTorch overrides the `__setattr__` function in `nn.Module` so that the submodules you define are properly registered as parameters of the model."
]
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"class DummyModule():\n",
" def __init__(self, n_in, nh, n_out):\n",
" ... \n",
" \n",
" def __setattr__(self,k,v):\n",
" ...\n",
" \n",
" def __repr__(self): return f'{self._modules}'\n",
" \n",
" def parameters(self):\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": 154,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"{'l1': Linear(in_features=784, out_features=50, bias=True), 'l2': Linear(in_features=50, out_features=10, bias=True)}"
]
},
"execution_count": 154,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mdl = DummyModule(m,nh,10)\n",
"mdl"
]
},
{
"cell_type": "code",
"execution_count": 155,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"[torch.Size([50, 784]),\n",
" torch.Size([50]),\n",
" torch.Size([10, 50]),\n",
" torch.Size([10])]"
]
},
"execution_count": 155,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[o.shape for o in mdl.parameters()]"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true
},
"source": [
"### Registering modules"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"We can use the original `layers` approach, but we have to register the modules."
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"layers = [nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10)]"
]
},
{
"cell_type": "code",
"execution_count": 157,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, layers):\n",
" ...\n",
" \n",
" def __call__(self, x):\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": 158,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"model = Model(layers)"
]
},
{
"cell_type": "code",
"execution_count": 159,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"Model(\n",
" (layer_0): Linear(in_features=784, out_features=50, bias=True)\n",
" (layer_1): ReLU()\n",
" (layer_2): Linear(in_features=50, out_features=10, bias=True)\n",
")"
]
},
"execution_count": 159,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true
},
"source": [
"### nn.ModuleList"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"`nn.ModuleList` does this for us."
]
},
{
"cell_type": "code",
"execution_count": 160,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"class SequentialModel(nn.Module):\n",
" def __init__(self, layers):\n",
" ...\n",
" \n",
" def __call__(self, x):\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": 161,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"model = SequentialModel(layers)"
]
},
{
"cell_type": "code",
"execution_count": 162,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"SequentialModel(\n",
" (layers): ModuleList(\n",
" (0): Linear(in_features=784, out_features=50, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=50, out_features=10, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 162,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model"
]
},
{
"cell_type": "code",
"execution_count": 163,
"metadata": {
"hidden": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.3222596645355225 0.03125\n",
"0.14386768639087677 0.96875\n",
"0.08797654509544373 0.96875\n"
]
},
{
"data": {
"text/plain": [
"(tensor(0.02, grad_fn=<NllLossBackward0>), tensor(1.))"
]
},
"execution_count": 163,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fit()\n",
"loss_func(model(xb), yb), accuracy(model(xb), yb)"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true
},
"source": [
"### nn.Sequential"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"`nn.Sequential` is a convenient class which does the same as the above:"
]
},
{
"cell_type": "code",
"execution_count": 164,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))"
]
},
{
"cell_type": "code",
"execution_count": 165,
"metadata": {
"hidden": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.3087124824523926 0.09375\n",
"0.20316630601882935 0.90625\n",
"0.20330585539340973 0.921875\n"
]
},
{
"data": {
"text/plain": [
"(tensor(0.01, grad_fn=<NllLossBackward0>), tensor(1.))"
]
},
"execution_count": 165,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fit()\n",
"loss_func(model(xb), yb), accuracy(model(xb), yb)"
]
},
{
"cell_type": "code",
"execution_count": 166,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"Sequential(\n",
" (0): Linear(in_features=784, out_features=50, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=50, out_features=10, bias=True)\n",
")"
]
},
"execution_count": 166,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true
},
"source": [
"### optim"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"Let's replace our previous manually coded optimization step:\n",
"\n",
"```python\n",
"with torch.no_grad():\n",
" for p in model.parameters(): p -= p.grad * lr\n",
" model.zero_grad()\n",
"```\n",
"\n",
"and instead use just:\n",
"\n",
"```python\n",
"opt.step()\n",
"opt.zero_grad()\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 167,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"class Optimizer():\n",
" def __init__(self, params, lr=0.5): self.params,self.lr=list(params),lr\n",
" \n",
" def step(self):\n",
" ...\n",
"\n",
" def zero_grad(self):\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": 168,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))"
]
},
{
"cell_type": "code",
"execution_count": 169,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"opt = Optimizer(model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 170,
"metadata": {
"hidden": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.30017352104187 0.0625\n",
"0.13068024814128876 0.96875\n",
"0.11748667806386948 0.96875\n"
]
}
],
"source": [
"for epoch in range(epochs):\n",
" for i in range(0, n, bs):\n",
" ..."
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"PyTorch already provides this exact functionality in `optim.SGD` (it also handles stuff like momentum, which we'll look at later)"
]
},
{
"cell_type": "code",
"execution_count": 171,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"from torch import optim"
]
},
{
"cell_type": "code",
"execution_count": 172,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"def get_model():\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": 173,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor(2.30, grad_fn=<NllLossBackward0>)"
]
},
"execution_count": 173,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model,opt = get_model()\n",
"loss_func(model(xb), yb)"
]
},
{
"cell_type": "code",
"execution_count": 174,
"metadata": {
"hidden": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.312685012817383 0.078125\n",
"0.21422098577022552 0.90625\n",
"0.17829009890556335 0.921875\n"
]
}
],
"source": [
"for epoch in range(epochs):\n",
" for i in range(0, n, bs):\n",
" ..."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset and DataLoader"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true
},
"source": [
"### Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true
},
"source": [
"It's clunky to iterate through minibatches of x and y values separately:\n",
"\n",
"```python\n",
" xb = x_train[s]\n",
" yb = y_train[s]\n",
"```\n",
"\n",
"Instead, let's do these two steps together, by introducing a `Dataset` class:\n",
"\n",
"```python\n",
" xb,yb = train_ds[s]\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 175,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"class Dataset():\n",
" def __init__(self, x, y): ...\n",
" def __len__(self): return ...\n",
" def __getitem__(self, i): return ..."
]
},
{
"cell_type": "code",
"execution_count": 176,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)\n",
"assert len(train_ds)==len(x_train)\n",
"assert len(valid_ds)==len(x_valid)"
]
},
{
"cell_type": "code",
"execution_count": 177,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]]),\n",
" tensor([5, 0, 4, 1, 9]))"
]
},
"execution_count": 177,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb,yb = train_ds[0:5]\n",
"assert xb.shape==(5,28*28)\n",
"assert yb.shape==(5,)\n",
"xb,yb"
]
},
{
"cell_type": "code",
"execution_count": 178,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"model,opt = get_model()"
]
},
{
"cell_type": "code",
"execution_count": 179,
"metadata": {
"hidden": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.2979648113250732 0.09375\n",
"0.21198976039886475 0.953125\n",
"0.17290501296520233 0.921875\n"
]
}
],
"source": [
"for epoch in range(epochs):\n",
" for i in range(0, n, bs):\n",
" ..."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### DataLoader"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Previously, our loop iterated over batches (xb, yb) like this:\n",
"\n",
"```python\n",
"for i in range(0, n, bs):\n",
" xb,yb = train_ds[i:min(n,i+bs)]\n",
" ...\n",
"```\n",
"\n",
"Let's make our loop much cleaner, using a data loader:\n",
"\n",
"```python\n",
"for xb,yb in train_dl:\n",
" ...\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 180,
"metadata": {},
"outputs": [],
"source": [
"class DataLoader():\n",
" def __init__(self, ds, bs): self.ds,self.bs = ds,bs\n",
" def __iter__(self): ..."
]
},
{
"cell_type": "code",
"execution_count": 181,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, bs)\n",
"valid_dl = DataLoader(valid_ds, bs)"
]
},
{
"cell_type": "code",
"execution_count": 182,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([64, 784])"
]
},
"execution_count": 182,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb,yb = next(iter(valid_dl))\n",
"xb.shape"
]
},
{
"cell_type": "code",
"execution_count": 183,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([3, 8, 6, 9, 6, 4, 5, 3, 8, 4, 5, 2, 3, 8, 4, 8, 1, 5, 0, 5, 9, 7, 4, 1, 0, 3, 0, 6, 2, 9, 9, 4, 1, 3, 6, 8, 0, 7, 7, 6, 8, 9, 0, 3,\n",
" 8, 3, 7, 7, 8, 4, 4, 1, 2, 9, 8, 1, 1, 0, 6, 6, 5, 0, 1, 1])"
]
},
"execution_count": 183,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yb"
]
},
{
"cell_type": "code",
"execution_count": 184,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(3)"
]
},
"execution_count": 184,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAANeElEQVR4nO3df6hc9ZnH8c9HTTExQaNBTdJo2hv/2GUxZhVZMSzVYnFFiBVcGnDJxsCtUKHVVVayQkUphGVbBf+IpBiSXbuWmtg1VCWKhPUXFOOP1djY+INsEnNzgwY0otKNPvvHPVmuyT3fuZlfZ/Y+7xdcZuY8c855GPLJOTPfM/N1RAjA1HdS0w0A6A/CDiRB2IEkCDuQBGEHkjilnzuzzUf/QI9FhCda3tGR3fbVtv9o+13bd3ayLQC95XbH2W2fLGmXpKsk7ZP0sqTlEfGHwjoc2YEe68WR/VJJ70bE+xHxJ0m/lrSsg+0B6KFOwj5f0t5xj/dVy77G9rDt7ba3d7AvAB3q5AO6iU4VjjtNj4h1ktZJnMYDTerkyL5P0oJxj78paX9n7QDolU7C/rKkC2x/y/Y3JP1A0pbutAWg29o+jY+II7ZvkbRV0smS1kfEW13rDEBXtT301tbOeM8O9FxPLqoB8P8HYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9HXKZrRn8eLFxfqtt95aWxsaGiquO2PGjGJ99erVxfrpp59erD/11FO1tcOHDxfXRXdxZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJJjFdQDMnDmzWN+zZ0+xfsYZZ3Sxm+764IMPamul6wMkadOmTd1uJ4W6WVw7uqjG9m5JhyV9KelIRFzSyfYA9E43rqC7IiI+7MJ2APQQ79mBJDoNe0h62vYrtocneoLtYdvbbW/vcF8AOtDpafzlEbHf9tmSnrH9dkQ8N/4JEbFO0jqJD+iAJnV0ZI+I/dXtQUm/lXRpN5oC0H1th932abZnHb0v6XuSdnSrMQDd1fY4u+1va+xoLo29Hfj3iPhZi3U4jZ/ArFmzivUnn3yyWP/oo49qa6+99lpx3SVLlhTr559/frG+YMGCYn369Om1tdHR0eK6l112WbHeav2suj7OHhHvSyr/qgKAgcHQG5AEYQeSIOxAEoQdSIKwA0nwFVd0ZM6cOcX6HXfc0VZNklauXFmsb9y4sVjPqm7ojSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBlM3oyIcfln9r9MUXX6yttRpnb/X1W8bZTwxHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2dGT27NnF+urVq9ve9rx589peF8fjyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSfC78ShavLg8Ue+jjz5arC9atKi2tmvXruK6V111VbG+d+/eYj2rtn833vZ62wdt7xi37Ezbz9h+p7otX1kBoHGTOY3fIOnqY5bdKenZiLhA0rPVYwADrGXYI+I5SYeOWbxM0tHfBNoo6brutgWg29q9Nv6ciBiRpIgYsX123RNtD0sabnM/ALqk51+EiYh1ktZJfEAHNKndobdR23Mlqbo92L2WAPRCu2HfImlFdX+FpMe70w6AXmk5zm77EUnfkTRH0qikn0r6D0m/kXSepD2SboiIYz/Em2hbnMYPmBUrVhTr99xzT7G+YMGCYv3zzz+vrV177bXFdbdt21asY2J14+wt37NHxPKa0nc76ghAX3G5LJAEYQeSIOxAEoQdSIKwA0nwU9JTwMyZM2trt99+e3Hdu+66q1g/6aTy8eDQofKI69KlS2trb7/9dnFddBdHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2KWDDhg21teuvv76jbW/atKlYv//++4t1xtIHB0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYpYGhoqGfbXrt2bbH+0ksv9Wzf6C6O7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsU8DTTz9dW1u8eHHPti21Hodfs2ZNbW3//v1t9YT2tDyy215v+6DtHeOW3W37A9uvV3/X9LZNAJ2azGn8BklXT7D8voi4qPp7srttAei2lmGPiOcklef4ATDwOvmA7hbbb1Sn+bPrnmR72PZ229s72BeADrUb9rWShiRdJGlE0s/rnhgR6yLikoi4pM19AeiCtsIeEaMR8WVEfCXpl5Iu7W5bALqtrbDbnjvu4fcl7ah7LoDB4IgoP8F+RNJ3JM2RNCrpp9XjiySFpN2SfhgRIy13Zpd3hrZMnz69tvbwww8X17344ouL9fPOO6+tno46cOBAbW3lypXFdbdu3drRvrOKCE+0vOVFNRGxfILFD3XcEYC+4nJZIAnCDiRB2IEkCDuQBGEHkmg59NbVnTH01nennnpqsX7KKeUBmU8++aSb7XzNF198UazfdtttxfqDDz7YzXamjLqhN47sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+wouvDCC4v1++67r1i/4oor2t73nj17ivWFCxe2ve2pjHF2IDnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYBMGPGjGL9s88+61MnJ2727NqZvyRJ69evr60tW7aso33Pnz+/WB8Zafnr5lMS4+xAcoQdSIKwA0kQdiAJwg4kQdiBJAg7kETLWVzRuaGhoWL9hRdeKNafeOKJYn3Hjh21tVZjzatWrSrWp02bVqy3GutetGhRsV7y3nvvFetZx9Hb1fLIbnuB7W22d9p+y/aPq+Vn2n7G9jvVbfnqCgCNmsxp/BFJ/xARfybpryT9yPafS7pT0rMRcYGkZ6vHAAZUy7BHxEhEvFrdPyxpp6T5kpZJ2lg9baOk63rUI4AuOKH37LYXSloi6feSzomIEWnsPwTbZ9esMyxpuMM+AXRo0mG3PVPSZkk/iYhP7AmvtT9ORKyTtK7aBl+EARoyqaE329M0FvRfRcRj1eJR23Or+lxJB3vTIoBuaHlk99gh/CFJOyPiF+NKWyStkLSmun28Jx1OATfccEOxfu655xbrN910UzfbOSGtzuA6+Yr0p59+WqzffPPNbW8bx5vMafzlkv5O0pu2X6+WrdZYyH9je5WkPZLK/6IBNKpl2CPiBUl1/71/t7vtAOgVLpcFkiDsQBKEHUiCsANJEHYgCb7i2gdnnXVW0y30zObNm4v1e++9t7Z28GD5OqwDBw601RMmxpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgyuY+aPVzzFdeeWWxfuONNxbr8+bNq619/PHHxXVbeeCBB4r1559/vlg/cuRIR/vHiWPKZiA5wg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2YIphnB1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkmgZdtsLbG+zvdP2W7Z/XC2/2/YHtl+v/q7pfbsA2tXyohrbcyXNjYhXbc+S9Iqk6yT9raRPI+JfJ
r0zLqoBeq7uoprJzM8+Immkun/Y9k5J87vbHoBeO6H37LYXSloi6ffVoltsv2F7ve3ZNesM295ue3tnrQLoxKSvjbc9U9J/SvpZRDxm+xxJH0oKSfdq7FT/phbb4DQe6LG60/hJhd32NEm/k7Q1In4xQX2hpN9FxF+02A5hB3qs7S/C2LakhyTtHB/06oO7o74vaUenTQLoncl8Gr9U0vOS3pT0VbV4taTlki7S2Gn8bkk/rD7MK22LIzvQYx2dxncLYQd6j++zA8kRdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmj5g5Nd9qGk/x73eE61bBANam+D2pdEb+3qZm/n1xX6+n3243Zub4+ISxproGBQexvUviR6a1e/euM0HkiCsANJNB32dQ3vv2RQexvUviR6a1dfemv0PTuA/mn6yA6gTwg7kEQjYbd9te0/2n7X9p1N9FDH9m7bb1bTUDc6P101h95B2zvGLTvT9jO236luJ5xjr6HeBmIa78I0442+dk1Pf9739+y2T5a0S9JVkvZJelnS8oj4Q18bqWF7t6RLIqLxCzBs/7WkTyX969GptWz/s6RDEbGm+o9ydkT844D0drdOcBrvHvVWN83436vB166b05+3o4kj+6WS3o2I9yPiT5J+LWlZA30MvIh4TtKhYxYvk7Sxur9RY/9Y+q6mt4EQESMR8Wp1/7Cko9OMN/raFfrqiybCPl/S3nGP92mw5nsPSU/bfsX2cNPNTOCco9NsVbdnN9zPsVpO491Px0wzPjCvXTvTn3eqibBPNDXNII3/XR4RfynpbyT9qDpdxeSslTSksTkARyT9vMlmqmnGN0v6SUR80mQv403QV19etybCvk/SgnGPvylpfwN9TCgi9le3ByX9VmNvOwbJ6NEZdKvbgw33838iYjQivoyIryT9Ug2+dtU045sl/SoiHqsWN/7aTdRXv163JsL+sqQLbH/L9jck/UDSlgb6OI7t06oPTmT7NEnf0+BNRb1F0orq/gpJjzfYy9cMyjTeddOMq+HXrvHpzyOi73+SrtHYJ/LvSfqnJnqo6evbkv6r+nur6d4kPaKx07r/0dgZ0SpJZ0l6VtI71e2ZA9Tbv2lsau83NBasuQ31tlRjbw3fkPR69XdN069doa++vG5cLgskwRV0QBKEHUiCsANJEHYgCcIOJEHYgSQIO5DE/wLwpj8ONnyk5wAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(xb[0].view(28,28))\n",
"yb[0]"
]
},
{
"cell_type": "code",
"execution_count": 185,
"metadata": {},
"outputs": [],
"source": [
"model,opt = get_model()"
]
},
{
"cell_type": "code",
"execution_count": 186,
"metadata": {},
"outputs": [],
"source": [
"def fit():\n",
" for epoch in range(epochs):\n",
" for xb,yb in train_dl:\n",
" ... "
]
},
{
"cell_type": "code",
"execution_count": 187,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.05, grad_fn=<NllLossBackward0>), tensor(1.))"
]
},
"execution_count": 187,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fit()\n",
"loss_func(model(xb), yb), accuracy(model(xb), yb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Random sampling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We want our training set to be in a random order, and that order should differ each iteration. But the validation set shouldn't be randomized."
]
},
{
"cell_type": "code",
"execution_count": 188,
"metadata": {},
"outputs": [],
"source": [
"import random"
]
},
{
"cell_type": "code",
"execution_count": 189,
"metadata": {},
"outputs": [],
"source": [
"class Sampler():\n",
" def __init__(self, ds, shuffle=False): self.n,self.shuffle = len(ds),shuffle\n",
" def __iter__(self):\n",
" ... "
]
},
{
"cell_type": "code",
"execution_count": 190,
"metadata": {},
"outputs": [],
"source": [
"from itertools import islice"
]
},
{
"cell_type": "code",
"execution_count": 191,
"metadata": {},
"outputs": [],
"source": [
"ss = Sampler(train_ds)"
]
},
{
"cell_type": "code",
"execution_count": 192,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"1\n",
"2\n",
"3\n",
"4\n"
]
}
],
"source": [
"it = iter(ss)\n",
"for o in range(5): print(next(it))"
]
},
{
"cell_type": "code",
"execution_count": 193,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0, 1, 2, 3, 4]"
]
},
"execution_count": 193,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(islice(ss, 5))"
]
},
{
"cell_type": "code",
"execution_count": 194,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[2468, 34785, 22293, 22313, 36680]"
]
},
"execution_count": 194,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ss = Sampler(train_ds, shuffle=True)\n",
"list(islice(ss, 5))"
]
},
{
"cell_type": "code",
"execution_count": 195,
"metadata": {},
"outputs": [],
"source": [
"import fastcore.all as fc"
]
},
{
"cell_type": "code",
"execution_count": 196,
"metadata": {},
"outputs": [],
"source": [
"class BatchSampler():\n",
" def __init__(self, sampler, bs, drop_last=False): fc.store_attr()\n",
" def __iter__(self): ... "
]
},
{
"cell_type": "code",
"execution_count": 197,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[3338, 3713, 47999, 33349],\n",
" [1382, 28497, 19584, 35095],\n",
" [33760, 20524, 1959, 7968],\n",
" [40952, 25061, 32207, 20443],\n",
" [11419, 11479, 45286, 40070]]"
]
},
"execution_count": 197,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batchs = BatchSampler(ss, 4)\n",
"list(islice(batchs, 5))"
]
},
{
"cell_type": "code",
"execution_count": 198,
"metadata": {},
"outputs": [],
"source": [
"def collate(b):\n",
" ... "
]
},
{
"cell_type": "code",
"execution_count": 199,
"metadata": {},
"outputs": [],
"source": [
"class DataLoader():\n",
" def __init__(self, ds, batchs, collate_fn=collate): fc.store_attr()\n",
" def __iter__(self): ... "
]
},
{
"cell_type": "code",
"execution_count": 200,
"metadata": {},
"outputs": [],
"source": [
"train_samp = BatchSampler(Sampler(train_ds, shuffle=True ), bs)\n",
"valid_samp = BatchSampler(Sampler(valid_ds, shuffle=False), bs)"
]
},
{
"cell_type": "code",
"execution_count": 201,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, batchs=train_samp, collate_fn=collate)\n",
"valid_dl = DataLoader(valid_ds, batchs=valid_samp, collate_fn=collate)"
]
},
{
"cell_type": "code",
"execution_count": 202,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(3)"
]
},
"execution_count": 202,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAANeElEQVR4nO3df6hc9ZnH8c9HTTExQaNBTdJo2hv/2GUxZhVZMSzVYnFFiBVcGnDJxsCtUKHVVVayQkUphGVbBf+IpBiSXbuWmtg1VCWKhPUXFOOP1djY+INsEnNzgwY0otKNPvvHPVmuyT3fuZlfZ/Y+7xdcZuY8c855GPLJOTPfM/N1RAjA1HdS0w0A6A/CDiRB2IEkCDuQBGEHkjilnzuzzUf/QI9FhCda3tGR3fbVtv9o+13bd3ayLQC95XbH2W2fLGmXpKsk7ZP0sqTlEfGHwjoc2YEe68WR/VJJ70bE+xHxJ0m/lrSsg+0B6KFOwj5f0t5xj/dVy77G9rDt7ba3d7AvAB3q5AO6iU4VjjtNj4h1ktZJnMYDTerkyL5P0oJxj78paX9n7QDolU7C/rKkC2x/y/Y3JP1A0pbutAWg29o+jY+II7ZvkbRV0smS1kfEW13rDEBXtT301tbOeM8O9FxPLqoB8P8HYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9HXKZrRn8eLFxfqtt95aWxsaGiquO2PGjGJ99erVxfrpp59erD/11FO1tcOHDxfXRXdxZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJJjFdQDMnDmzWN+zZ0+xfsYZZ3Sxm+764IMPamul6wMkadOmTd1uJ4W6WVw7uqjG9m5JhyV9KelIRFzSyfYA9E43rqC7IiI+7MJ2APQQ79mBJDoNe0h62vYrtocneoLtYdvbbW/vcF8AOtDpafzlEbHf9tmSnrH9dkQ8N/4JEbFO0jqJD+iAJnV0ZI+I/dXtQUm/lXRpN5oC0H1th932abZnHb0v6XuSdnSrMQDd1fY4u+1va+xoLo29Hfj3iPhZi3U4jZ/ArFmzivUnn3yyWP/oo49qa6+99lpx3SVLlhTr559/frG+YMGCYn369Om1tdHR0eK6l112WbHeav2suj7OHhHvSyr/qgKAgcHQG5AEYQeSIOxAEoQdSIKwA0nwFVd0ZM6cOcX6HXfc0VZNklauXFmsb9y4sVjPqm7ojSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBlM3oyIcfln9r9MUXX6yttRpnb/X1W8bZTwxHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2dGT27NnF+urVq9ve9rx589peF8fjyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSfC78ShavLg8Ue+jjz5arC9atKi2tmvXruK6V111VbG+d+/eYj2rtn833vZ62wdt7xi37Ezbz9h+p7otX1kBoHGTOY3fIOnqY5bdKenZiLhA0rPVYwADrGXYI+I5SYeOWbxM0tHfBNoo6brutgWg29q9Nv6ciBiRpIgYsX123RNtD0sabnM/ALqk51+EiYh1ktZJfEAHNKndobdR23Mlqbo92L2WAPRCu2HfImlFdX+FpMe70w6AXmk5zm77EUnfkTRH0qikn0r6D0m/kXSepD2SboiIYz/Em2hbnMYPmBUrVhTr99xzT7G+YMGCYv3zzz+vrV177bXFdbdt21asY2J14+wt37NHxPKa0nc76ghAX3G5LJAEYQeSIOxAEoQdSIKwA0nwU9JTwMyZM2trt99+e3Hdu+66q1g/6aTy8eDQofKI69KlS2trb7/9dnFddBdHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2KWDDhg21teuvv76jbW/atKlYv//++4t1xtIHB0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYpYGhoqGfbXrt2bbH+0ksv9Wzf6C6O7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsU8DTTz9dW1u8eHHPti21Hodfs2ZNbW3//v1t9YT2tDyy215v+6DtHeOW3W37A9uvV3/X9LZNAJ2azGn8BklXT7D8voi4qPp7srttAei2lmGPiOcklef4ATDwOvmA7hbbb1Sn+bPrnmR72PZ229s72BeADrUb9rWShiRdJGlE0s/rnhgR6yLikoi4pM19AeiCtsIeEaMR8WVEfCXpl5Iu7W5bALqtrbDbnjvu4fcl7ah7LoDB4IgoP8F+RNJ3JM2RNCrpp9XjiySFpN2SfhgRIy13Zpd3hrZMnz69tvbwww8X17344ouL9fPOO6+tno46cOBAbW3lypXFdbdu3drRvrOKCE+0vOVFNRGxfILFD3XcEYC+4nJZIAnCDiRB2IEkCDuQBGEHkmg59NbVnTH01nennnpqsX7KKeUBmU8++aSb7XzNF198UazfdtttxfqDDz7YzXamjLqhN47sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+wouvDCC4v1++67r1i/4oor2t73nj17ivWFCxe2ve2pjHF2IDnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYBMGPGjGL9s88+61MnJ2727NqZvyRJ69evr60tW7aso33Pnz+/WB8Zafnr5lMS4+xAcoQdSIKwA0kQdiAJwg4kQdiBJAg7kETLWVzRuaGhoWL9hRdeKNafeOKJYn3Hjh21tVZjzatWrSrWp02bVqy3GutetGhRsV7y3nvvFetZx9Hb1fLIbnuB7W22d9p+y/aPq+Vn2n7G9jvVbfnqCgCNmsxp/BFJ/xARfybpryT9yPafS7pT0rMRcYGkZ6vHAAZUy7BHxEhEvFrdPyxpp6T5kpZJ2lg9baOk63rUI4AuOKH37LYXSloi6feSzomIEWnsPwTbZ9esMyxpuMM+AXRo0mG3PVPSZkk/iYhP7AmvtT9ORKyTtK7aBl+EARoyqaE329M0FvRfRcRj1eJR23Or+lxJB3vTIoBuaHlk99gh/CFJOyPiF+NKWyStkLSmun28Jx1OATfccEOxfu655xbrN910UzfbOSGtzuA6+Yr0p59+WqzffPPNbW8bx5vMafzlkv5O0pu2X6+WrdZYyH9je5WkPZLK/6IBNKpl2CPiBUl1/71/t7vtAOgVLpcFkiDsQBKEHUiCsANJEHYgCb7i2gdnnXVW0y30zObNm4v1e++9t7Z28GD5OqwDBw601RMmxpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgyuY+aPVzzFdeeWWxfuONNxbr8+bNq619/PHHxXVbeeCBB4r1559/vlg/cuRIR/vHiWPKZiA5wg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2YIphnB1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkmgZdtsLbG+zvdP2W7Z/XC2/2/YHtl+v/q7pfbsA2tXyohrbcyXNjYhXbc+S9Iqk6yT9raRPI+JfJ
r0zLqoBeq7uoprJzM8+Immkun/Y9k5J87vbHoBeO6H37LYXSloi6ffVoltsv2F7ve3ZNesM295ue3tnrQLoxKSvjbc9U9J/SvpZRDxm+xxJH0oKSfdq7FT/phbb4DQe6LG60/hJhd32NEm/k7Q1In4xQX2hpN9FxF+02A5hB3qs7S/C2LakhyTtHB/06oO7o74vaUenTQLoncl8Gr9U0vOS3pT0VbV4taTlki7S2Gn8bkk/rD7MK22LIzvQYx2dxncLYQd6j++zA8kRdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmj5g5Nd9qGk/x73eE61bBANam+D2pdEb+3qZm/n1xX6+n3243Zub4+ISxproGBQexvUviR6a1e/euM0HkiCsANJNB32dQ3vv2RQexvUviR6a1dfemv0PTuA/mn6yA6gTwg7kEQjYbd9te0/2n7X9p1N9FDH9m7bb1bTUDc6P101h95B2zvGLTvT9jO236luJ5xjr6HeBmIa78I0442+dk1Pf9739+y2T5a0S9JVkvZJelnS8oj4Q18bqWF7t6RLIqLxCzBs/7WkTyX969GptWz/s6RDEbGm+o9ydkT844D0drdOcBrvHvVWN83436vB166b05+3o4kj+6WS3o2I9yPiT5J+LWlZA30MvIh4TtKhYxYvk7Sxur9RY/9Y+q6mt4EQESMR8Wp1/7Cko9OMN/raFfrqiybCPl/S3nGP92mw5nsPSU/bfsX2cNPNTOCco9NsVbdnN9zPsVpO491Px0wzPjCvXTvTn3eqibBPNDXNII3/XR4RfynpbyT9qDpdxeSslTSksTkARyT9vMlmqmnGN0v6SUR80mQv403QV19etybCvk/SgnGPvylpfwN9TCgi9le3ByX9VmNvOwbJ6NEZdKvbgw33838iYjQivoyIryT9Ug2+dtU045sl/SoiHqsWN/7aTdRXv163JsL+sqQLbH/L9jck/UDSlgb6OI7t06oPTmT7NEnf0+BNRb1F0orq/gpJjzfYy9cMyjTeddOMq+HXrvHpzyOi73+SrtHYJ/LvSfqnJnqo6evbkv6r+nur6d4kPaKx07r/0dgZ0SpJZ0l6VtI71e2ZA9Tbv2lsau83NBasuQ31tlRjbw3fkPR69XdN069doa++vG5cLgskwRV0QBKEHUiCsANJEHYgCcIOJEHYgSQIO5DE/wLwpj8ONnyk5wAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"xb,yb = next(iter(valid_dl))\n",
"plt.imshow(xb[0].view(28,28))\n",
"yb[0]"
]
},
{
"cell_type": "code",
"execution_count": 203,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 784]), torch.Size([64]))"
]
},
"execution_count": 203,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb.shape,yb.shape"
]
},
{
"cell_type": "code",
"execution_count": 204,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.04, grad_fn=<NllLossBackward0>), tensor(1.))"
]
},
"execution_count": 204,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model,opt = get_model()\n",
"fit()\n",
"\n",
"loss_func(model(xb), yb), accuracy(model(xb), yb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multiprocessing DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 205,
"metadata": {},
"outputs": [],
"source": [
"import torch.multiprocessing as mp\n",
"from fastcore.basics import store_attr"
]
},
{
"cell_type": "code",
"execution_count": 206,
"metadata": {},
"outputs": [],
"source": [
"class DataLoader():\n",
" def __init__(self, ds, batchs, n_workers=1, collate_fn=collate): fc.store_attr()\n",
" def __iter__(self):\n",
" ... "
]
},
{
"cell_type": "code",
"execution_count": 207,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, batchs=train_samp, collate_fn=collate, n_workers=2)\n",
"it = iter(train_dl)"
]
},
{
"cell_type": "code",
"execution_count": 208,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 784]), torch.Size([64]))"
]
},
"execution_count": 208,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb,yb = next(it)\n",
"xb.shape,yb.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### PyTorch DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 209,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, BatchSampler"
]
},
{
"cell_type": "code",
"execution_count": 210,
"metadata": {},
"outputs": [],
"source": [
"train_samp = BatchSampler(RandomSampler(train_ds), bs, drop_last=False)\n",
"valid_samp = BatchSampler(SequentialSampler(valid_ds), bs, drop_last=False)"
]
},
{
"cell_type": "code",
"execution_count": 211,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, batch_sampler=train_samp, collate_fn=collate)\n",
"valid_dl = DataLoader(valid_ds, batch_sampler=valid_samp, collate_fn=collate)"
]
},
{
"cell_type": "code",
"execution_count": 212,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.09, grad_fn=<NllLossBackward0>), tensor(0.98))"
]
},
"execution_count": 212,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model,opt = get_model()\n",
"fit()\n",
"loss_func(model(xb), yb), accuracy(model(xb), yb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch can auto-generate the BatchSampler for us:"
]
},
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, bs, sampler=RandomSampler(train_ds), collate_fn=collate)\n",
"valid_dl = DataLoader(valid_ds, bs, sampler=SequentialSampler(valid_ds), collate_fn=collate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch can also generate the Sequential/RandomSamplers too:"
]
},
{
"cell_type": "code",
"execution_count": 214,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, bs, shuffle=True, drop_last=True, num_workers=2)\n",
"valid_dl = DataLoader(valid_ds, bs, shuffle=False, num_workers=2)"
]
},
{
"cell_type": "code",
"execution_count": 215,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.05, grad_fn=<NllLossBackward0>), tensor(0.98))"
]
},
"execution_count": 215,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model,opt = get_model()\n",
"fit()\n",
"\n",
"loss_func(model(xb), yb), accuracy(model(xb), yb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our dataset actually already knows how to sample a batch of indices all at once:"
]
},
{
"cell_type": "code",
"execution_count": 216,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]]),\n",
" tensor([9, 1, 3]))"
]
},
"execution_count": 216,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_ds[[4,6,7]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...that means that we can actually skip the batch_sampler and collate_fn entirely:"
]
},
{
"cell_type": "code",
"execution_count": 217,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(train_ds, sampler=train_samp)\n",
"valid_dl = DataLoader(valid_ds, sampler=valid_samp)"
]
},
{
"cell_type": "code",
"execution_count": 218,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 64, 784]), torch.Size([1, 64]))"
]
},
"execution_count": 218,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb,yb = next(iter(train_dl))\n",
"xb.shape,yb.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You **always** should also have a [validation set](http://www.fast.ai/2017/11/13/validation-sets/), in order to identify if you are overfitting.\n",
"\n",
"We will calculate and print the validation loss at the end of each epoch.\n",
"\n",
"(Note that we always call `model.train()` before training, and `model.eval()` before inference, because these are used by layers such as `nn.BatchNorm2d` and `nn.Dropout` to ensure appropriate behaviour for these different phases.)"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
"def fit(epochs, model, loss_func, opt, train_dl, valid_dl):\n",
" for epoch in range(epochs):\n",
" ...\n",
" print(epoch, tot_loss/count, tot_acc/count)\n",
" return tot_loss/count, tot_acc/count"
]
},
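{
"cell_type": "markdown",
"metadata": {},
"source": [
"One possible completion (a sketch following the description above: train under `model.train()`, then evaluate under `model.eval()` and `torch.no_grad()`, accumulating validation loss and accuracy weighted by batch size):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def fit(epochs, model, loss_func, opt, train_dl, valid_dl):\n",
"    for epoch in range(epochs):\n",
"        model.train()\n",
"        for xb,yb in train_dl:\n",
"            loss = loss_func(model(xb), yb)\n",
"            loss.backward()\n",
"            opt.step()\n",
"            opt.zero_grad()\n",
"\n",
"        model.eval()\n",
"        with torch.no_grad():\n",
"            tot_loss,tot_acc,count = 0.,0.,0\n",
"            for xb,yb in valid_dl:\n",
"                pred = model(xb)\n",
"                nb = len(xb)  # nb rather than n, to avoid shadowing the global n\n",
"                count += nb\n",
"                tot_loss += loss_func(pred, yb).item()*nb\n",
"                tot_acc  += accuracy(pred, yb).item()*nb\n",
"        print(epoch, tot_loss/count, tot_acc/count)\n",
"    return tot_loss/count, tot_acc/count"
]
},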
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"def get_dls(train_ds, valid_ds, bs, **kwargs):\n",
" return ... "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, our whole process of obtaining the data loaders and fitting the model can be run in 3 lines of code:"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 0.2599395067691803 0.9152\n",
"1 0.1375726773455739 0.9599\n",
"2 0.10632649689242243 0.9696\n",
"3 0.11931819728948176 0.9643\n",
"4 0.15087997979782522 0.9549\n"
]
}
],
"source": [
"train_dl,valid_dl = get_dls(train_ds, valid_ds, bs)\n",
"model,opt = get_model()\n",
"loss,acc = fit(5, model, loss_func, opt, train_dl, valid_dl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}