Last active
November 25, 2022 11:19
-
-
Save PiotrCzapla/c84091e1cf51abcb469b8c21497ffa37 to your computer and use it in GitHub Desktop.
fast ai 03 backprop without body
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## The forward and backward passes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl, numpy as np\n", | |
"from pathlib import Path\n", | |
"from torch import tensor\n", | |
"from fastcore.test import test_close\n", | |
"torch.manual_seed(42)\n", | |
"\n", | |
"mpl.rcParams['image.cmap'] = 'gray'\n", | |
"torch.set_printoptions(precision=2, linewidth=125, sci_mode=False)\n", | |
"np.set_printoptions(precision=2, linewidth=125)\n", | |
"\n", | |
"path_data = Path('data')\n", | |
"path_gz = path_data/'mnist.pkl.gz'\n", | |
"with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')\n", | |
"x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Foundations version" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Basic architecture" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(50000, 784, tensor(10))" | |
] | |
}, | |
"execution_count": null, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"n,m = x_train.shape\n", | |
"c = y_train.max()+1\n", | |
"n,m,c" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# num hidden\n", | |
"nh = 50" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"w1 = ...\n", | |
"b1 = ...\n", | |
"w2 = ...\n", | |
"b2 = ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def lin(x, w, b): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([10000, 50])" | |
] | |
}, | |
"execution_count": null, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"t = lin(x_valid, w1, b1)\n", | |
"t.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def relu(x): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.00, 11.87, 0.00, ..., 5.48, 2.14, 15.30],\n", | |
" [ 5.38, 10.21, 0.00, ..., 0.88, 0.08, 20.23],\n", | |
" [ 3.31, 0.12, 3.10, ..., 16.89, 0.00, 24.74],\n", | |
" ...,\n", | |
" [ 4.01, 10.35, 0.00, ..., 0.23, 0.00, 18.28],\n", | |
" [10.62, 0.00, 10.72, ..., 0.00, 0.00, 18.23],\n", | |
" [ 2.84, 0.00, 1.43, ..., 0.00, 5.75, 2.12]])" | |
] | |
}, | |
"execution_count": null, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"t = relu(t)\n", | |
"t" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def model(xb):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([10000, 1])" | |
] | |
}, | |
"execution_count": null, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"res = model(x_valid)\n", | |
"res.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Loss function: MSE" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"(Of course, `mse` is not a suitable loss function for multi-class classification; we'll use a better loss function soon. We'll use `mse` for now to keep things simple.)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def mse(output, targ): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(4308.76)" | |
] | |
}, | |
"execution_count": null, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# targets as float for the MSE losses used below; .float() is a no-op on an already-float tensor, so re-running is safe\n", | |
"y_train,y_valid = y_train.float(),y_valid.float()\n", | |
"preds = model(x_train)\n", | |
"mse(preds, y_train)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Gradients and backward pass" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle 2 x$" | |
], | |
"text/plain": [ | |
"2*x" | |
] | |
}, | |
"execution_count": null, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from sympy import symbols,diff\n", | |
"x,y = symbols('x y')\n", | |
"diff(x**2, x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle 6 x$" | |
], | |
"text/plain": [ | |
"6*x" | |
] | |
}, | |
"execution_count": null, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"diff(3*x**2+9, x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def lin_grad(inp, out, w, b):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def forward_and_backward(inp, targ):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"forward_and_backward(x_train, y_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Save for testing against later\n", | |
"def get_grad(x): return ...\n", | |
"chks = w1,w2,b1,b2,x_train\n", | |
"grads = w1g,w2g,b1g,b2g,ig = [*map(get_grad, chks)]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We cheat a little bit and use PyTorch autograd to check our results." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def mkgrad(x): return x.clone().requires_grad_(True)\n", | |
"ptgrads = w12,w22,b12,b22,xt2 = [*map(mkgrad, chks)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def forward(inp, targ):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"loss = forward(xt2, y_train)\n", | |
"loss.backward()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# grads holds our manually computed gradients; ptgrads holds the autograd leaves, whose gradients are in .grad\n", | |
"for a,b in zip(grads, ptgrads): test_close(a, b.grad, eps=0.01)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Refactor model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Layers as classes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Relu():\n", | |
" def __call__(self, inp):\n", | |
" ...\n", | |
" \n", | |
" def backward(self): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Lin():\n", | |
" def __init__(self, w, b): ...\n", | |
" \n", | |
" def __call__(self, inp):\n", | |
" ...\n", | |
"\n", | |
" def backward(self):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Mse():\n", | |
" def __call__(self, inp, targ):\n", | |
" ...\n", | |
" \n", | |
" def backward(self):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Model():\n", | |
" def __init__(self, w1, b1, w2, b2):\n", | |
" ...\n", | |
" \n", | |
" def __call__(self, x, targ):\n", | |
" ...\n", | |
" \n", | |
" def backward(self):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = Model(w1, b1, w2, b2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"loss = model(x_train, y_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model.backward()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_close(w2g, w2.g, eps=0.01)\n", | |
"test_close(b2g, b2.g, eps=0.01)\n", | |
"test_close(w1g, w1.g, eps=0.01)\n", | |
"test_close(b1g, b1.g, eps=0.01)\n", | |
"test_close(ig, x_train.g, eps=0.01)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Module.forward()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Module():\n", | |
"    def __call__(self, *args):\n", | |
"        ...\n", | |
"\n", | |
"    # forward/bwd are hooks that subclasses override; NotImplementedError is still an Exception\n", | |
"    def forward(self): raise NotImplementedError('not implemented')\n", | |
"    def backward(self): ...\n", | |
"    def bwd(self): raise NotImplementedError('not implemented')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Relu(Module):\n", | |
" def forward(self, inp): ...\n", | |
" def bwd(self, out, inp): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Lin(Module):\n", | |
" def __init__(self, w, b): ...\n", | |
" def forward(self, inp): ...\n", | |
" def bwd(self, out, inp):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Mse(Module):\n", | |
" def forward (self, inp, targ): ... \n", | |
" def bwd(self, out, inp, targ): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = Model(w1, b1, w2, b2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"loss = model(x_train, y_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model.backward()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_close(w2g, w2.g, eps=0.01)\n", | |
"test_close(b2g, b2.g, eps=0.01)\n", | |
"test_close(w1g, w1.g, eps=0.01)\n", | |
"test_close(b1g, b1.g, eps=0.01)\n", | |
"test_close(ig, x_train.g, eps=0.01)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Autograd" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torch import nn\n", | |
"import torch.nn.functional as F" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Linear(nn.Module):\n", | |
" def __init__(self, n_in, n_out):\n", | |
" ...\n", | |
" def forward(self, inp): ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Model(nn.Module):\n", | |
" def __init__(self, n_in, nh, n_out):\n", | |
" ...\n", | |
" \n", | |
" def __call__(self, x, targ):\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = Model(m, nh, 1)\n", | |
"loss = model(x_train, y_train)\n", | |
"loss.backward()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([-19.60, -2.40, -0.12, 1.99, 12.78, -15.32, -18.45, 0.35, 3.75, 14.67, 10.81, 12.20, -2.95, -28.33,\n", | |
" 0.76, 69.15, -21.86, 49.78, -7.08, 1.45, 25.20, 11.27, -18.15, -13.13, -17.69, -10.42, -0.13, -18.89,\n", | |
" -34.81, -0.84, 40.89, 4.45, 62.35, 31.70, 55.15, 45.13, 3.25, 12.75, 12.45, -1.41, 4.55, -6.02,\n", | |
" -62.51, -1.89, -1.41, 7.00, 0.49, 18.72, -4.84, -6.52])" | |
] | |
}, | |
"execution_count": null, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"l0 = model.layers[0]\n", | |
"l0.b.grad" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3.10.6 64-bit", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.10.6" | |
}, | |
"vscode": { | |
"interpreter": { | |
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment