Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save anand086/6146185ededa0f4834eb77e03e89de18 to your computer and use it in GitHub Desktop.
Save anand086/6146185ededa0f4834eb77e03e89de18 to your computer and use it in GitHub Desktop.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"!pip install -Uqq fastbook\n",
"import fastbook\n",
"fastbook.setup_book()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from fastai.vision.all import *\n",
"from fastbook import *\n",
"\n",
"matplotlib.rc('image', cmap='Greys')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The MNIST Loss Function"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Untar MNIST dataset provided by Fastai\n",
"path = untar_data(URLs.MNIST_SAMPLE)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#3) [Path('/home/ec2-user/.fastai/data/mnist_sample/train'),Path('/home/ec2-user/.fastai/data/mnist_sample/labels.csv'),Path('/home/ec2-user/.fastai/data/mnist_sample/valid')]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path.ls()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#2) [Path('/home/ec2-user/.fastai/data/mnist_sample/train/7'),Path('/home/ec2-user/.fastai/data/mnist_sample/train/3')]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(path/'train').ls()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#6265) [Path('/home/ec2-user/.fastai/data/mnist_sample/train/7/10002.png'),Path('/home/ec2-user/.fastai/data/mnist_sample/train/7/1001.png'),Path('/home/ec2-user/.fastai/data/mnist_sample/train/7/10014.png'),Path('/home/ec2-user/.fastai/data/mnist_sample/train/7/10019.png'),Path('/home/ec2-user/.fastai/data/mnist_sample/train/7/10039.png'),Path('/home/ec2-user/.fastai/data/mnist_sample/train/7/10046.png'),Path('/home/ec2-user/.fastai/data/mnist_sample/train/7/10050.png'),Path('/home/ec2-user/.fastai/data/mnist_sample/train/7/10063.png'),Path('/home/ec2-user/.fastai/data/mnist_sample/train/7/10077.png'),Path('/home/ec2-user/.fastai/data/mnist_sample/train/7/10086.png')...]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"threes = (path/'train'/'3').ls().sorted()\n",
"sevens = (path/'train'/'7').ls().sorted()\n",
"sevens"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(6131, 6265)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seven_tensors = [tensor(Image.open(o)) for o in sevens] #6265\n",
"three_tensors = [tensor(Image.open(o)) for o in threes] #6131\n",
"len(three_tensors),len(seven_tensors)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([6131, 28, 28])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stacked_sevens = torch.stack(seven_tensors).float()/255\n",
"stacked_threes = torch.stack(three_tensors).float()/255\n",
"stacked_threes.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stacked_threes.ndim"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([6131, 784])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stacked_threes.view(-1, 28*28).shape"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stacked_threes.view(-1, 28*28).ndim"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# converting from 3-D to 2-D. \"-1\" is a special parameter to view that means \n",
"# \"make this axis as big as necessary to fit all the data\"\n",
"train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([12396, 784])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_x.shape"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([12396, 784]), torch.Size([12396, 1]))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# label the data 1 for 3s and 0 for 7s\n",
"train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)\n",
"train_x.shape,train_y.shape"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# A Dataset in PyTorch is required to return a tuple of (x,y) when indexed.\n",
"dset = list(zip(train_x,train_y))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"12396"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(dset)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([784]), tensor([1]))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# shape of 0th element of the dataset which in this case is a number 3\n",
"# as expected its size is 784 and label 1\n",
"x,y = dset[0]\n",
"x.shape,y"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1010, 28, 28]), torch.Size([1028, 28, 28]))"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"valid_3_tens = torch.stack([tensor(Image.open(o)) \n",
" for o in (path/'valid'/'3').ls()])\n",
"valid_3_tens = valid_3_tens.float()/255\n",
"valid_7_tens = torch.stack([tensor(Image.open(o)) \n",
" for o in (path/'valid'/'7').ls()])\n",
"valid_7_tens = valid_7_tens.float()/255\n",
"valid_3_tens.shape,valid_7_tens.shape"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# create validation dataset\n",
"valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)\n",
"valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)\n",
"valid_dset = list(zip(valid_x,valid_y))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Create one random weight per input, e.g. one per PIXEL of a 28*28 = 784-pixel image\n",
"def init_params(size, std=1.0):\n",
"    \"\"\"Return a tensor of shape `size` drawn from N(0, std**2) that tracks gradients.\"\"\"\n",
"    noise = torch.randn(size)\n",
"    scaled = noise * std\n",
"    return scaled.requires_grad_()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"weights = init_params((28*28,1))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"784"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(weights)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# the function weights*pixels will always be 0 when the pixel is 0 (i.e. there is no intercept). \n",
"# y = w*x+b, so let's add a bias b\n",
"# the weights and bias together make up the parameters of the neural network"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"bias = init_params(1)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.3472], requires_grad=True)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bias"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-6.2330], grad_fn=<AddBackward0>)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# let's calculate the prediction for 1 image - in this case for an image of 3\n",
"(train_x[0]*weights.T).sum() + bias"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# A linear model: a weighted sum of the pixels plus a bias, for every image in the batch\n",
"def linear1(xb):\n",
"    \"\"\"Return `xb @ weights + bias`, reading the global `weights` and `bias` parameters.\n",
"\n",
"    With `xb` of shape (n, 784) and `weights` of shape (784, 1) the result has shape (n, 1).\n",
"    \"\"\"\n",
"    weighted_sum = xb @ weights\n",
"    return weighted_sum + bias"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ -6.2330],\n",
" [-10.6388],\n",
" [-20.8865],\n",
" ...,\n",
" [-15.9176],\n",
" [ -1.6866],\n",
" [-11.3568]], grad_fn=<AddBackward0>)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# perform predictions \n",
"preds = linear1(train_x)\n",
"preds"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5379961133003235"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# calculate accuracy of the model\n",
"corrects = (preds>0.0).float() == train_y\n",
"#corrects\n",
"corrects.float().mean().item()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# We need gradients in order to improve the model using SGD, and in order to calculate gradients we need\n",
"# a loss function that represents how good our model is. The gradients are a measure of how that \n",
"# loss function changes with small tweaks to the weights."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# At this point you might ask: why not use model accuracy as our loss function? In that case, \n",
"# we would calculate our prediction for each image, collect these values to calculate an overall accuracy\n",
"# and then calculate the gradient of each weight with respect to that overall accuracy."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# Well, we have a significant technical problem. The gradient of a function is its slope, which is also defined\n",
"# as rise over run. Mathematically - (y_new - y_old)/(x_new - x_old)\n",
"# This gives a good approximation of the gradient when x_new is very similar to x_old, meaning that their difference\n",
"# is very small. But accuracy changes only when a prediction changes from a 3 to a 7 or vice versa. The problem is that\n",
"# a small change in weights from x_old to x_new isn't likely to cause any prediction to change, so (y_new - y_old)\n",
"# will almost always be 0. So the gradient is 0 almost everywhere. A very small change in the value of a weight will often \n",
"# not change the accuracy at all. This means it is not useful to use accuracy as a loss function."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"# example data\n",
"trgts = tensor([1,0,1])\n",
"prds = tensor([0.9, 0.4, 0.2])"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# torch.where(a, b, c) picks b[i] where a[i] is True and c[i] otherwise.\n",
"# Score each prediction by how far it is from 1 when the target is 1,\n",
"# use the prediction itself when the target is 0, then average the scores.\n",
"def mnist_loss(predictions, targets):\n",
"    \"\"\"Mean of (1 - prediction) where target == 1 and prediction elsewhere.\"\"\"\n",
"    per_item = torch.where(targets == 1, 1 - predictions, predictions)\n",
"    return per_item.mean()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.1000, 0.4000, 0.8000])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.where(trgts==1, 1-prds, prds)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.9000, 0.6000, 0.2000])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.where(trgts==0, 1-prds, prds)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.4333)"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mnist_loss(prds,trgts)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.2333)"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mnist_loss(tensor([0.9, 0.4, 0.8]),trgts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sigmoid"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"# squash any real number into a value between 0 and 1\n",
"def sigmoid(x):\n",
"    \"\"\"Elementwise logistic function 1 / (1 + exp(-x)).\"\"\"\n",
"    denominator = 1 + torch.exp(-x)\n",
"    return 1 / denominator"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ec2-user/SageMaker/.env/fastai/lib/python3.6/site-packages/fastbook/__init__.py:73: UserWarning: Not providing a value for linspace's steps is deprecated and will throw a runtime error in a future release. This warning will appear only once per process. (Triggered internally at /pytorch/aten/src/ATen/native/RangeFactories.cpp:25.)\n",
" x = torch.linspace(min,max)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAi4klEQVR4nO3deXRb9Z338fdX8pbN2eysdjYSIGEPBlJogRYISdphaQsN0w3aKW1nmGmfdmYOTHuYHtp5TkvP9JnplE5Lp1DaAoEutBkIkLCVpQQSGrI5CTHZbGexszlxvEr6Pn9IBmFkIjuyryR/Xufo6Oren3S/vrr++Pp3N3N3REQk94WCLkBERDJDgS4ikicU6CIieUKBLiKSJxToIiJ5QoEuIpInFOiSd8zsk2a2PNvma2bPmdnfDGRNMrgo0CVnmdn7zezPZtZkZgfN7CUzO8/d73f3+QNdT1DzFelSEHQBIn1hZqXAo8CXgYeBIuADQHuQdYkESVvokqtOBnD3B9096u6t7r7c3deZ2Y1m9mJXQzObb2ZbElvyPzazP3V1fSTavmRm/8/MDpvZNjO7MDG+1swazOyzSZ810sx+aWaNZrbTzL5pZqGkz0qe7xVmtjkx3x8BNmBLRwYlBbrkqjeAqJndZ2YLzWx0qkZmVgb8FrgNGAtsAS7s1uwCYF1i+gPAEuA8YCbwKeBHZjY80fa/gJHADOAS4DPATT3M9/fAN4Ey4E3gor7+sCLpUKBLTnL3I8D7AQd+BjSa2VIzG9+t6SJgo7v/3t0jwA+Bvd3abHf3e909CjwEVAJ3uHu7uy8HOoCZZhYGFgO3uftRd98B/Dvw6RQlds33t+7eCfxHivmKZJQCXXKWu29y9xvdvQI4HZhEPDiTTQJqk97jQF23NvuShlsT7bqPG058S7sQ2Jk0bScwOUV5qeZbm6KdSMYo0CUvuPtm4BfEgz3ZHqCi64WZWfLrXtoPdAJTk8ZNAepTtN1DfEs/eb6VKdqJZIwCXXKSmZ1qZl83s4rE60rgBmBlt6aPAWeY2TVmVgD8HTChL/NMdMk8DPybmY0ws6nA14Bfp2j+GHCamX00Md9/6Ot8RdKlQJdcdZT4zsxXzOwY8SDfAHw9uZG77weuA+4EDgBzgNX0/fDGvweOAduAF4nvRL2ne6Ok+X43Md9ZwEt9nKdIWkw3uJDBJHGIYR3wSXd/Nuh6RDJJW+iS98zsSjMbZWbFwL8QPx68e9eMSM5ToMtg8D7ix4HvB/4KuMbdW4MtSSTz1OUiIpIntIUuIpInArs4V1lZmU+bNi2o2YuI5KTXXnttv7uXp5oWWKBPmzaN1atXBzV7EZGcZGY7e5qmLhcRkTyhQBcRyRMKdBGRPKFAFxHJEwp0EZE8cdxAN7N7Erfh2tDDdDOzH5pZjZmtM7O5mS9TRESOJ50t9F8AC95j+kLiV5KbBdwM/PeJlyUiIr113OPQ3f15M5v2Hk2uBn6ZuCPLysRFkCa6+55MFSki+cfdicSc9kiMjkiM9kiUzojTEY3SEXEisRidUScSjRGNOZ0xJxqLEY3x9rM7sZgTcycac9wh5k4s8ezvGI4/x+edGJcYBoi/evt1V41vT3932+7t3/HzvfOHfce0y2aP56zKUX1abu8lEycWTeadt9aqS4x7V6Cb2c3Et+KZMmVKBmYtIkGJRGMcONbB/uZ2Dh7r4OCxDg4d66CpNUJTaydH2zppbo9wtC1Cc3uE1o4oLZ3x59aOKG2ReFAPFmZvD48rLcnaQE+bu98N3A1QVVU1eL5JkRwUjTm7D7eybf8xtjc2U3uolfpDrdQfbmXvkTYONLfTUx4PKwozckghw0sKGF5cwIiSAsaXFjOsqICSojBDCuOP4oIQxYUhigvCFIZDFBXEH4UhozAcoiCceA4ZBWEjHAoRNiMc6npAyIxQYlwoZBgQDhlmYMTHG/FANesaH39fV5vksKWrLZY03DXekoaT23f7gI
BkItDreee9EitIfY9FEclS7ZEoG+qPsLb2MJv2HGHz3qO8se8o7ZHYW21KCkNMHjWEyaOHMmdiKeNLiykvLaF8eBFjhhUzZlgRo4cWUjqkkMKwDqALQiYCfSlwi5ktIX5LsCb1n4tkt/ZIlNd2HuLFrft5edsBNtYfoSMaD++y4cXMnjiCT8+bysxxw5leNozp5cMoH16cNVuiktpxA93MHgQuBcrMrA74V6AQwN1/AiwDFgE1QAtwU38VKyJ9d6Stk2c2NbBs/R6e39pIW2eMcMg4u3IUN100jXOmjGbulFGMKy0JulTpo3SOcrnhONOd+J3URSTLxGLOCzX7eWjVLp6qbqAjGmNCaQnXV1Vy8axyLpgxhhElhUGXKRkS2OVzRaT/NLdHuH/lTn758k7qD7cyemghn5o3lQ+fOZFzKkcRCqnrJB8p0EXyyOGWDu55aQf3/XkHTa2dzJsxhlsXnsr808ZTXBAOujzpZwp0kTzQGY3x65U7+Y+nttLU2sn8OeP52w/O5Ox+ONZZspcCXSTHvbC1kX9dupFtjcd4/8wyvvHh2cyeWBp0WRIABbpIjmrpiPB/l23i1yt3Mb1sGD//bBUfOnWcDi0cxBToIjnotZ2H+NrDr7PrYAt/8/7p/OOVp1BSqD7ywU6BLpJjHnx1F7f/cQPjS0t48AvzmDdjbNAlSZZQoIvkiM5ojO88Ws19L+/k4pPL+a8bzmHkEB1DLm9ToIvkgJaOCF/81Wu8sHU/X/jAdG5dOJuwjiWXbhToIlmuuT3C5+5dxeqdB7nzY2dy/XmVx3+TDEoKdJEs1tTayY33vsq6uiZ+eMM5fOTMSUGXJFlMgS6SpY61R/jMz1+hes8RfvzJuVx52oSgS5Isp0AXyUKRaIy/f3AN6+ub+Omnq7hizvigS5IcoEAXyTLuzr8u3cgzmxv4t2tPV5hL2nRbEZEs89Pnt3H/K7v40iUn8ckLpgZdjuQQBbpIFnlhayPfe2IzHzlzIv985SlBlyM5RoEukiUajrTxfx56nZnlw/n+x8/SNcul19SHLpIFojHnK0tep7k9wgNfmMeQIl2XRXpPgS6SBX70TA0vbzvAnR87k5PHjwi6HMlR6nIRCdjrtYf5z6ff4JqzJ3FdVUXQ5UgOU6CLBKgzGuPW362jfEQxd1xzuq5lLidEXS4iAbr7+W1s3nuUuz99LqUlunKinBhtoYsEZFtjM//59FYWnTGB+TqtXzJAgS4SAHfntt+vp6QgxLeuOi3ociRPKNBFArB07W5e2X6Qf1k0m3EjSoIuR/KEAl1kgLV1RrnziS3MmVjK9VW6trlkjgJdZID94s87qD/cyjc/PFtng0pGKdBFBtCB5nbueqaGy04dx4Uzy4IuR/KMAl1kAP3w6a20dEa5bdGpQZcieUiBLjJAtu8/xv2v7GLxeZXMHKfT+yXzFOgiA+SuZ2sIh4yvXD4r6FIkT6UV6Ga2wMy2mFmNmd2aYvoUM3vWzNaY2TozW5T5UkVyV+3BFh5ZU89fXzBFhylKvzluoJtZGLgLWAjMAW4wszndmn0TeNjdzwEWAz/OdKEiuezHz9UQNuOLF58UdCmSx9LZQj8fqHH3be7eASwBru7WxoHSxPBIYHfmShTJbfWHW/nta3Vcf14FE0Zq61z6TzqBPhmoTXpdlxiX7FvAp8ysDlgG/H2qDzKzm81stZmtbmxs7EO5Irnnp396E3f40iXaOpf+lamdojcAv3D3CmAR8Csze9dnu/vd7l7l7lXl5eUZmrVI9mo40saSVbV8/NwKKkYPDbocyXPpBHo9kHx+ckViXLLPAw8DuPvLQAmgsyZk0Lvv5R10RmN8+VJtnUv/SyfQVwGzzGy6mRUR3+m5tFubXcBlAGY2m3igq09FBrW2zigPvLKLK2aPZ+rYYUGXI4PAcQPd3SPALcCTwCbiR7NsNLM7zOyqRLOvA18ws7XAg8CN7u79VbRILvjDmnoOtXRy00XTgy5FBom07l
jk7suI7+xMHnd70nA1cFFmSxPJXe7OPS9tZ/bEUubNGBN0OTJI6ExRkX7w5zcP8Ma+Zm66aJruEyoDRoEu0g/ueXE7Y4cVcdVZk4IuRQYRBbpIhu3Yf4xntjTwyQumUFIYDrocGUQU6CIZ9sCruwib8al5U4MuRQYZBbpIBnVEYvzutToumz2OcaU6zV8GlgJdJINWVO/jwLEOFp8/JehSZBBSoItk0JJVu5g8aggXz9KlLWTgKdBFMqT2YAsvbN3PdVUVhHXzZwmAAl0kQx5eXYsZXF9VefzGIv1AgS6SAZFojIdX13LJyeVMGjUk6HJkkFKgi2TAn95oZN+Rdhafp52hEhwFukgG/O4vdYwdVsRls8cFXYoMYgp0kRPU1NrJU5sa+KuzJlEY1q+UBEdrn8gJenz9HjoiMa49p/udGUUGlgJd5AQ9sqaeGWXDOLNiZNClyCCnQBc5AXWHWnhl+0GuPWeyLpMrgVOgi5yAP76+G4Br1N0iWUCBLtJH7s4ja+o5b9poKscMDbocEQW6SF9t3H2EmoZmbZ1L1lCgi/TRH9bUUxg2PnzGxKBLEQEU6CJ9Eos5j63fw8Wzyhk1tCjockQABbpIn6ypPcSepjY+cpa2ziV7KNBF+uDRdXsoKghx+ezxQZci8hYFukgvxWLOsvV7uOTkckaUFAZdjshbFOgivbR65yH2HWnnI2equ0WyiwJdpJceW7eb4oIQl6m7RbKMAl2kF6IxZ9mGvXzo1HEMLy4IuhyRd1Cgi/TCq9sP0ni0nQ+ru0WykAJdpBceW7+bksIQHzpVN7KQ7KNAF0lTLOY8uXEfHzxlHEOL1N0i2SetQDezBWa2xcxqzOzWHtpcb2bVZrbRzB7IbJkiwVtTe4jGo+0sOH1C0KWIpHTczQwzCwN3AVcAdcAqM1vq7tVJbWYBtwEXufshM9P/o5J3ntiwl6Kwulske6WzhX4+UOPu29y9A1gCXN2tzReAu9z9EIC7N2S2TJFguTtPbNzLRTPH6mQiyVrpBPpkoDbpdV1iXLKTgZPN7CUzW2lmC1J9kJndbGarzWx1Y2Nj3yoWCUD1niPUHmxVd4tktUztFC0AZgGXAjcAPzOzUd0bufvd7l7l7lXl5eUZmrVI/3tiw15Chq7dIlktnUCvByqTXlckxiWrA5a6e6e7bwfeIB7wInnhiQ17uWD6WMYOLw66FJEepRPoq4BZZjbdzIqAxcDSbm3+QHzrHDMrI94Fsy1zZYoEp6ahma0Nzepukax33EB39whwC/AksAl42N03mtkdZnZVotmTwAEzqwaeBf7J3Q/0V9EiA+nJjXsBmH+aulsku6V1doS7LwOWdRt3e9KwA19LPETyyvKNezmrchQTRw4JuhSR96QzRUXew96mNtbWNXGlts4lByjQRd7Dik37AJg/R4Eu2U+BLvIelm/cy4yyYZxUPjzoUkSOS4Eu0oMjbZ2s3HaAK+aMx8yCLkfkuBToIj14bksjnVHX0S2SMxToIj1YvnEvZcOLObtydNCliKRFgS6SQnskynNbGrl89jjCIXW3SG5QoIuksHLbQZrbI+pukZyiQBdJYfnGvQwtCnPhSWVBlyKSNgW6SDexmLOieh+XnFxOSWE46HJE0qZAF+lmfX0TDUfbuUInE0mOUaCLdLOieh/hkOlWc5JzFOgi3ayo3sd500YzamhR0KWI9IoCXSTJrgMtbNl3lCvm6NrnknsU6CJJllcnrn2u/nPJQQp0kSQrqvdx6oQRVI4ZGnQpIr2mQBdJOHSsg1U7DuroFslZCnSRhGc2NxBzFOiSsxToIgkrqvcxobSEMyaPDLoUkT5RoIsAbZ1Rnt/ayOVzxuna55KzFOgiwJ/f3E9LR5T5OlxRcpgCXQRYvnEfI4oLmDdjbNCliPSZAl0GvWjMeWrTPi49dRxFBfqVkNyltVcGvTW7DrG/uUMnE0nOU6DLoLeieh+FYePSU8qDLkXkhCjQZVBzd57cuJf3nVTGiJ
LCoMsROSEKdBnUahqa2XGgRd0tkhcU6DKoLa/eB+jsUMkPCnQZ1JZX7+OsylGMLy0JuhSRE6ZAl0Fr9+FW1tYeVneL5A0FugxayzfGr32+8HSdHSr5Ia1AN7MFZrbFzGrM7Nb3aPcxM3Mzq8pciSL944mNezl5/HBmlA8PuhSRjDhuoJtZGLgLWAjMAW4wszkp2o0AvgK8kukiRTLtQHM7r24/yILTtHUu+SOdLfTzgRp33+buHcAS4OoU7b4NfA9oy2B9Iv3iqU37iDlcqe4WySPpBPpkoDbpdV1i3FvMbC5Q6e6PvdcHmdnNZrbazFY3Njb2uliRTHliw16mjBnKnImlQZcikjEnvFPUzELAD4CvH6+tu9/t7lXuXlVertOsJRhH2jp5sWY/C06foGufS15JJ9Drgcqk1xWJcV1GAKcDz5nZDmAesFQ7RiVbPbu5gc6oc6X6zyXPpBPoq4BZZjbdzIqAxcDSronu3uTuZe4+zd2nASuBq9x9db9ULHKCntiwl3EjijmnclTQpYhk1HED3d0jwC3Ak8Am4GF332hmd5jZVf1doEgmtXREeG5LI1eeNoFQSN0tkl8K0mnk7suAZd3G3d5D20tPvCyR/vHs5kZaO6MsOmNi0KWIZJzOFJVB5dF1uykfUcz508cEXYpIxinQZdA41h7hmc0NLDp9AmF1t0geUqDLoPH05gbaIzE+fOakoEsR6RcKdBk0Hl27m/GlxVRNHR10KSL9QoEug8LRtk6ee6ORRWdM1NEtkrcU6DIoPL2pgY5IjI+cqaNbJH8p0GVQeHTdbiaNLOGcSnW3SP5SoEvea2rp5Pk39rNQ3S2S5xTokveWbdhDRzTGNWdPPn5jkRymQJe898hf6jmpfBinT9alciW/KdAlr9UebOHVHQf56NwKXSpX8p4CXfLaH1+PX+n5qrN0MpHkPwW65C1355E19Zw/bQyVY4YGXY5Iv1OgS95aX9/Em43HuHaudobK4KBAl7z1yJp6isIhFp2uk4lkcFCgS16KRGP879rdXDZ7HCOHFgZdjsiAUKBLXnpmcwP7mzu49hx1t8jgoUCXvPTQqlrKRxTzwVPHBV2KyIBRoEve2dPUyrNbGrju3AoKw1rFZfDQ2i555zer64g5fOK8yqBLERlQCnTJK7GY89CqWi6aOZapY4cFXY7IgFKgS155oWY/9YdbWXzelKBLERlwCnTJKw+t2sXooYXMP2180KWIDDgFuuSNxqPtrKjex0fnVlBcEA66HJEBp0CXvPHAK7vojDp/fYG6W2RwUqBLXuiIxPj1Kzu59JRyTiofHnQ5IoFQoEteeGz9bhqPtnPTRdODLkUkMAp0yXnuzj0v7mDmuOFcPKss6HJEAqNAl5z32s5DrK9v4sYLp+muRDKoKdAl59370g5GDinko7ruuQxyaQW6mS0wsy1mVmNmt6aY/jUzqzazdWb2tJlNzXypIu9Wf7iVJzbuZfH5lQwtKgi6HJFAHTfQzSwM3AUsBOYAN5jZnG7N1gBV7n4m8FvgzkwXKpLKT//0JiGDz75vWtCliAQunS3084Ead9/m7h3AEuDq5Abu/qy7tyRergQqMlumyLvtO9LGklW1fPzcCiaNGhJ0OSKBSyfQJwO1Sa/rEuN68nng8VQTzOxmM1ttZqsbGxvTr1IkhZ/+aRvRmPPlS2YGXYpIVsjoTlEz+xRQBXw/1XR3v9vdq9y9qry8PJOzlkFmf3M7D7y6k2vOnsyUsUODLkckK6SzF6keSL6wdEVi3DuY2eXAN4BL3L09M+WJpPazF7bREYnxdx88KehSRLJGOlvoq4BZZjbdzIqAxcDS5AZmdg7wU+Aqd2/IfJkibzt0rINfvbyTvzprEjN0mr/IW44b6O4eAW4BngQ2AQ+7+0Yzu8PMrko0+z4wHPiNmb1uZkt7+DiRE3bXszW0dka55YPqOxdJltaBu+6+DFjWbdztScOXZ7gukZR2HjjGfS/v4PpzK5k1fkTQ5Y
hkFZ0pKjnlzie2UBAK8bX5JwddikjWUaBLznht50EeW7+HL14yg/GlJUGXI5J1FOiSE9yd7zy2iXEjirn54hlBlyOSlRTokhOWrt3Nml2H+cf5p+iaLSI9UKBL1jvc0sG3H63mzIqRfOxcXVVCpCfa1JGs92+PbeJQSye//NwFhEO63rlIT7SFLlntxa37+c1rdXzx4hnMmVQadDkiWU2BLlmrtSPKvzyynullw/iHy2YFXY5I1lOXi2St7z6+iV0HW1hy8zxKCsNBlyOS9bSFLlnpiQ17uO/lnXzuounMmzE26HJEcoICXbJO7cEW/um36zirYiS3Ljw16HJEcoYCXbJKRyTGLQ+uAeBHfz2XogKtoiLpUh+6ZA1359uPVrO29jD//cm5VI7RjStEekObP5I1fv7idn61cic3XzyDhWdMDLockZyjQJessGz9Hr7z2CYWnTGBWxeo31ykLxToErjVOw7y1Yde59ypo/nB9WcT0tmgIn2iQJdArdpxkBvvXcXkUUP42WeqdLy5yAlQoEtg/vzmfj7z81cZV1rMg1+Yx5hhRUGXJJLTFOgSiOe2NHDTvauoGD2EJTfPY8JI3bBC5ETpsEUZUO7OvS/t4DuPVXPKhFJ+/fnzGTu8OOiyRPKCAl0GTHskyjcf2cBvXqtj/pzx/OATZzO8WKugSKbot0kGxJuNzXztoddZW9fEP3xoJl+9/GQdzSKSYQp06VexmHPfyzv47uObGVIU5iefmsuC03XSkEh/UKBLv6nefYRv/e9GXt1+kA+eUs73PnYm40q181OkvyjQJeMaj7bzgxVbWLKqlpFDCvnuR8/gE+dVYqYuFpH+pECXjNnb1Mb/vLCNB17dRUckxk0XTucrl81i5NDCoEsTGRQU6HJC3J319U3cv3IXj6ypJ+rOVWdN4pYPzeSk8uFBlycyqCjQpU8ajrbx+Pq9PLSqluo9RygpDHFdVQVfuuQkXfZWJCAKdEmLu/NmYzN/emM/T2zYw+qdh3CH0yaV8u1rTueqsyYxcoi6VkSCpECXlGIxZ2tDM3/ZdYjVOw7xUs1+9h5pA+DUCSP4ymWzWHj6RE6ZMCLgSkWkiwJ9kHN3Gpvb2d54jDcbj7F57xE27TnCpj1HaW6PADB6aCEXnlTGRTPL+MCsMnWpiGSptALdzBYA/wmEgf9x9+92m14M/BI4FzgAfMLdd2S2VOmtaMw51NLBwWMd7G9up+FIO/uOtLGnqY36w63UHWql7mALRxPBDTC8uIBTJ4zg2nMmc3blKOZOHc20sUN1yKFIDjhuoJtZGLgLuAKoA1aZ2VJ3r05q9nngkLvPNLPFwPeAT/RHwbnK3YnGnGjXc+IRiTmRqNMZjSWGY7RHYnRGY3REYnQkntsjMdo6o7R1xmjtjNLaEaGlI0pLR5Tm9gjNbRGa2yMcaevkcEsnTa2dHGnrxP3dtQwrClMxeiiTRw/hvGmjmV42jBnlw5lRNoyK0UMU3iI5Kp0t9POBGnffBmBmS4CrgeRAvxr4VmL4t8CPzMzcU8XJiXl4VS13v7Dtrdc9zcJ7eNE16O5Jw9D1yp13hGCqdrG32sSHY+54t+eYO7FYfDiaGJ9pBSFjSFGYEcUFDC8pYHhxAWOGFTG9bBgjhxQyamgRY4cVMWZYEWOHFzG+tITxpSW6IJZInkrnN3syUJv0ug64oKc27h4xsyZgLLA/uZGZ3QzcDDBlypQ+FTx6WBGnjO+2I66HDcrk0clbnfbWuORhe7u9QderrjZdbzeMUCgxZBA2e6tNKGSEEp8TDhlmRsjiwyEzwqGkhxkFYaMgZIRDIQrCRmHYKAiFKCoIURQOURgOUVwYorggPm5IYZiSwjAlBWGGFIUpKtDl7EXkbQO6qebudwN3A1RVVfVpm/WKOeO5Ys74jNYlIpIP0tnEqwcqk15XJMalbGNmBcBI4jtHRURkgKQT6KuAWWY23cyKgMXA0m5tlgKfTQx/HHimP/rPRUSkZ8ftck
n0id8CPEn8sMV73H2jmd0BrHb3pcDPgV+ZWQ1wkHjoi4jIAEqrD93dlwHLuo27PWm4Dbgus6WJiEhv6DAJEZE8oUAXEckTCnQRkTyhQBcRyRMW1NGFZtYI7Ozj28vodhZqllBdvaO6ei9ba1NdvXMidU119/JUEwIL9BNhZqvdvSroOrpTXb2junovW2tTXb3TX3Wpy0VEJE8o0EVE8kSuBvrdQRfQA9XVO6qr97K1NtXVO/1SV072oYuIyLvl6ha6iIh0o0AXEckTWRvoZnadmW00s5iZVXWbdpuZ1ZjZFjO7sof3TzezVxLtHkpc+jfTNT5kZq8nHjvM7PUe2u0ws/WJdqszXUeK+X3LzOqTalvUQ7sFiWVYY2a3DkBd3zezzWa2zsweMbNRPbQbkOV1vJ/fzIoT33FNYl2a1l+1JM2z0syeNbPqxPr/lRRtLjWzpqTv9/ZUn9UPtb3n92JxP0wsr3VmNncAajolaTm8bmZHzOyr3doM2PIys3vMrMHMNiSNG2NmK8xsa+J5dA/v/WyizVYz+2yqNsfl7ln5AGYDpwDPAVVJ4+cAa4FiYDrwJhBO8f6HgcWJ4Z8AX+7nev8duL2HaTuAsgFcdt8C/vE4bcKJZTcDKEos0zn9XNd8oCAx/D3ge0Etr3R+fuBvgZ8khhcDDw3AdzcRmJsYHgG8kaKuS4FHB2p9Svd7ARYBjxO/I+M84JUBri8M7CV+4k0gywu4GJgLbEgadydwa2L41lTrPTAG2JZ4Hp0YHt3b+WftFrq7b3L3LSkmXQ0scfd2d98O1BC/kfVbLH4D0Q8Rv2E1wH3ANf1Va2J+1wMP9tc8+sFbN/929w6g6+bf/cbdl7t7JPFyJfG7XwUlnZ//auLrDsTXpcss+ea0/cDd97j7XxLDR4FNxO/ZmwuuBn7pcSuBUWY2cQDnfxnwprv39Qz0E+buzxO/J0Sy5PWopyy6Eljh7gfd/RCwAljQ2/lnbaC/h1Q3re6+wo8FDieFR6o2mfQBYJ+7b+1hugPLzey1xI2yB8ItiX977+nhX7x0lmN/+hzxrblUBmJ5pfPzv+Pm50DXzc8HRKKL5xzglRST32dma83scTM7bYBKOt73EvQ6tZieN6qCWF5dxrv7nsTwXiDVTZEzsuwG9CbR3ZnZU8CEFJO+4e5/HOh6Ukmzxht4763z97t7vZmNA1aY2ebEX/J+qQv4b+DbxH8Bv028O+hzJzK/TNTVtbzM7BtABLi/h4/J+PLKNWY2HPgd8FV3P9Jt8l+Idys0J/aP/AGYNQBlZe33kthHdhVwW4rJQS2vd3F3N7N+O1Y80EB398v78LZ0blp9gPi/ewWJLatUbTJSo8Vviv1R4Nz3+Iz6xHODmT1C/N/9E/pFSHfZmdnPgEdTTEpnOWa8LjO7EfgIcJknOg9TfEbGl1cKvbn5eZ0N4M3PzayQeJjf7+6/7z49OeDdfZmZ/djMyty9Xy9Clcb30i/rVJoWAn9x933dJwS1vJLsM7OJ7r4n0QXVkKJNPfG+/i4VxPcf9koudrksBRYnjkCYTvwv7avJDRJB8SzxG1ZD/AbW/bXFfzmw2d3rUk00s2FmNqJrmPiOwQ2p2mZKt37La3uYXzo3/850XQuAfwaucveWHtoM1PLKypufJ/rofw5scvcf9NBmQldfvpmdT/z3uF//0KT5vSwFPpM42mUe0JTU1dDfevwvOYjl1U3yetRTFj0JzDez0Yku0vmJcb0zEHt++/IgHkR1QDuwD3gyado3iB+hsAVYmDR+GTApMTyDeNDXAL8Bivupzl8AX+o2bhKwLKmOtYnHRuJdD/297H4FrAfWJVamid3rSrxeRPwoijcHqK4a4v2EryceP+le10Aur1Q/P3AH8T84ACWJdacmsS7NGIBl9H7iXWXrkpbTIuBLXesZcEti2awlvnP5wgGoK+X30q0uA+5KLM/1JB2d1s+1DSMe0COTxgWyvIj/UdkDdCby6/PE97s8DWwFng
LGJNpWAf+T9N7PJda1GuCmvsxfp/6LiOSJXOxyERGRFBToIiJ5QoEuIpInFOgiInlCgS4ikicU6CIieUKBLiKSJ/4/4vYQ7a0vu78AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_function(torch.sigmoid, title='Sigmoid', min=-10, max=10)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.0759, 0.7109, 0.5987, 0.5498, 0.7311, 0.8808, 0.9526, 0.9998])"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tensor([-2.5, 0.9, 0.4, 0.2, 1.0, 2.0, 3.0, 8.5]).sigmoid()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"# Updated loss: pass the raw predictions through sigmoid first so they lie between 0 and 1;\n",
"# a higher raw prediction then means higher confidence that the image is a 3 (label 1)\n",
"def mnist_loss(predictions, targets):\n",
"    \"\"\"Apply sigmoid to the predictions, then average 1-p where target == 1 and p elsewhere.\"\"\"\n",
"    probs = predictions.sigmoid()\n",
"    per_item = torch.where(targets == 1, 1 - probs, probs)\n",
"    return per_item.mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SGD and Mini-Batches"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# PyTorch and fastai provide a class that will do the shuffling and mini-batch collation for you, called DataLoader\n",
"coll = range(15)\n",
"list(coll)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[tensor([ 3, 12, 8, 10, 2]),\n",
" tensor([ 9, 4, 7, 14, 5]),\n",
" tensor([ 1, 13, 0, 6, 11])]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dl = DataLoader(coll, batch_size=5, shuffle=True)\n",
"list(dl)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[tensor([ 8, 9, 5, 11, 4]),\n",
" tensor([ 7, 13, 14, 2, 3]),\n",
" tensor([ 6, 12, 10, 1, 0])]"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dl = DataLoader(coll, batch_size=5, shuffle=True)\n",
"list(dl)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#26) [(0, 'a'),(1, 'b'),(2, 'c'),(3, 'd'),(4, 'e'),(5, 'f'),(6, 'g'),(7, 'h'),(8, 'i'),(9, 'j')...]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# A collection that contains tuples of independent and dependent variables is known as Dataset in PyTorch\n",
"ds = L(enumerate(string.ascii_lowercase))\n",
"ds"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(tensor([19, 14, 0, 24]), ('t', 'o', 'a', 'y')),\n",
" (tensor([20, 12, 23, 8]), ('u', 'm', 'x', 'i')),\n",
" (tensor([ 9, 3, 16, 6]), ('j', 'd', 'q', 'g')),\n",
" (tensor([ 4, 7, 1, 13]), ('e', 'h', 'b', 'n')),\n",
" (tensor([ 2, 22, 5, 17]), ('c', 'w', 'f', 'r')),\n",
" (tensor([18, 10, 11, 15]), ('s', 'k', 'l', 'p')),\n",
" (tensor([25, 21]), ('z', 'v'))]"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dl = DataLoader(ds, batch_size=4, shuffle=True)\n",
"list(dl)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Putting It All Together"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"weights = init_params((28*28,1))\n",
"bias = init_params(1)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([256, 784]), torch.Size([256, 1]))"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dl = DataLoader(dset, batch_size=256)\n",
"xb,yb = first(dl)\n",
"xb.shape,yb.shape"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"valid_dl = DataLoader(valid_dset, batch_size=256)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 784])"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# create a mini-batch of size 4 for testing\n",
"batch = train_x[:4]\n",
"batch.shape"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 2.5739],\n",
" [-1.5014],\n",
" [ 7.9317],\n",
" [-2.5406]], grad_fn=<AddBackward0>)"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"# define a linear function \n",
"def linear1(xb): \n",
" return xb@weights + bias\n",
"\"\"\"\n",
"preds = linear1(batch)\n",
"preds"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.4540, grad_fn=<MeanBackward0>)"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss = mnist_loss(preds, train_y[:4])\n",
"loss"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([784, 1]), tensor(-0.0106), tensor([-0.0707]))"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# calculate the gradients\n",
"loss.backward()\n",
"weights.grad.shape, weights.grad.mean(), bias.grad"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"def calc_grad(xb, yb, model):\n",
"    \"\"\"Run `model` on `xb`, compute `mnist_loss` against `yb`, and backpropagate.\n",
"\n",
"    PyTorch adds the new gradients to each parameter's existing `.grad`, so calling\n",
"    this twice without zeroing accumulates them; callers must zero grads between steps.\n",
"    \"\"\"\n",
"    loss = mnist_loss(model(xb), yb)\n",
"    loss.backward()"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(-0.0213), tensor([-0.1415]))"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"calc_grad(batch, train_y[:4], linear1)\n",
"weights.grad.mean(),bias.grad"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(-0.0319), tensor([-0.2122]))"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"calc_grad(batch, train_y[:4], linear1)\n",
"weights.grad.mean(),bias.grad"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"weights.grad.zero_()\n",
"bias.grad.zero_();"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(model, lr, params):\n",
"    \"\"\"Run one epoch of plain SGD over the batches of the global `dl`.\n",
"\n",
"    For each `(xb, yb)` batch: compute gradients with `calc_grad`, then move every\n",
"    tensor in `params` by `-lr * grad` and reset its gradient to zero (gradients\n",
"    accumulate otherwise).\n",
"    \"\"\"\n",
"    for xb, yb in dl:\n",
"        calc_grad(xb, yb, model)\n",
"        # Update the parameters outside autograd instead of via the legacy `.data` attribute.\n",
"        with torch.no_grad():\n",
"            for p in params:\n",
"                p -= p.grad * lr\n",
"                p.grad.zero_()"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ True],\n",
" [False],\n",
" [ True],\n",
" [False]])"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(preds>0.0).float() == train_y[:4]"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"def batch_accuracy(xb, yb):\n",
"    \"\"\"Fraction of predictions that match the 0/1 targets `yb`.\n",
"\n",
"    `xb` holds raw model outputs (logits, e.g. `linear1(batch)`); a sample counts as\n",
"    class 1 when its sigmoid exceeds 0.5.\n",
"    \"\"\"\n",
"    is_positive = xb.sigmoid() > 0.5\n",
"    return (is_positive == yb).float().mean()"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.5000)"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch_accuracy(linear1(batch), train_y[:4])"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"def validate_epoch(model):\n",
"    \"\"\"Return `model`'s accuracy over the global `valid_dl`, rounded to 4 decimals.\n",
"\n",
"    Each batch's accuracy is weighted by its size, so a smaller final batch does not\n",
"    skew the result (a plain mean of per-batch accuracies would). Gradients are not\n",
"    needed for evaluation, so autograd is disabled.\n",
"    \"\"\"\n",
"    n_correct, n_seen = 0.0, 0\n",
"    with torch.no_grad():\n",
"        for xb, yb in valid_dl:\n",
"            batch_size = len(yb)\n",
"            n_correct += batch_accuracy(model(xb), yb).item() * batch_size\n",
"            n_seen += batch_size\n",
"    if n_seen == 0:\n",
"        raise ValueError(\"valid_dl yielded no batches\")\n",
"    return round(n_correct / n_seen, 4)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.6799"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"validate_epoch(linear1)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.7215"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lr = 1.\n",
"params = weights,bias\n",
"train_epoch(linear1, lr, params)\n",
"validate_epoch(linear1)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.8743 0.9281 0.9452 0.9525 0.9574 0.9598 0.9622 0.9627 0.9652 0.9661 0.9661 0.9681 0.9691 0.9711 0.9716 0.9716 0.9725 0.972 0.9735 0.9755 "
]
}
],
"source": [
"for i in range(20):\n",
" train_epoch(linear1, lr, params)\n",
" print(validate_epoch(linear1), end=' ')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating an Optimizer"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"# In PyTorch, the two functions below (init_params and linear) are combined into nn.Linear\n",
"\n",
"\"\"\"\n",
"# lets define a function to initialize random weights for every PIXEL in an image which is 784 pixels\n",
"def init_params(size, std=1.0): \n",
" return (torch.randn(size)*std).requires_grad_()\n",
"\n",
"# define a linear function \n",
"def linear(xb): \n",
" return xb@weights + bias\n",
"\"\"\"\n",
"\n",
"linear_model = nn.Linear(28*28,1)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Linear(in_features=784, out_features=1, bias=True)"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# this can be thought of as a model with a single layer that has 784 inputs and 1 output\n",
"linear_model"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 784]), torch.Size([1]))"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"w,b = linear_model.parameters()\n",
"w.shape, b.shape"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.nn.parameter.Parameter, torch.nn.parameter.Parameter)"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(w) , type(b)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([-0.0027], requires_grad=True)"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"# create an optimizer\n",
"class BasicOptim:\n",
"    \"\"\"Minimal SGD optimizer: each step does `p -= lr * p.grad` for every parameter.\"\"\"\n",
"\n",
"    def __init__(self, params, lr):\n",
"        # Materialise the iterable (e.g. the generator from `model.parameters()`) so it can be reused every step.\n",
"        self.params, self.lr = list(params), lr\n",
"\n",
"    def step(self, *args, **kwargs):\n",
"        \"\"\"Apply one SGD update. Parameters whose `.grad` is None are skipped\n",
"        (e.g. right after `zero_grad`, or parameters unused by the loss).\"\"\"\n",
"        with torch.no_grad():\n",
"            for p in self.params:\n",
"                if p.grad is not None:\n",
"                    p -= p.grad * self.lr\n",
"\n",
"    def zero_grad(self, *args, **kwargs):\n",
"        \"\"\"Clear gradients by setting `.grad` to None rather than zeroing in place.\"\"\"\n",
"        for p in self.params:\n",
"            p.grad = None"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lr"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"# pass model's parameter to the optimizer\n",
"opt = BasicOptim(linear_model.parameters(), lr)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"# The training loop can then be simplified to:\n",
"def train_epoch(model):\n",
"    \"\"\"Run one epoch over the global `dl`, delegating updates to the global optimizer `opt`.\n",
"\n",
"    NOTE: this redefines (and shadows) the earlier `train_epoch(model, lr, params)`.\n",
"    \"\"\"\n",
"    for batch_x, batch_y in dl:\n",
"        calc_grad(batch_x, batch_y, model)\n",
"        opt.step()\n",
"        opt.zero_grad()"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.3314"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"def validate_epoch(model):\n",
" accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]\n",
" return round(torch.stack(accs).mean().item(), 4)\n",
"\"\"\"\n",
"validate_epoch(linear_model)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"# Wrap the whole training loop in a function\n",
"def train_model(model, epochs):\n",
"    \"\"\"Train `model` for `epochs` epochs, printing the validation accuracy after each\n",
"    one (space-separated, on a single line).\"\"\"\n",
"    for _ in range(epochs):\n",
"        train_epoch(model)\n",
"        epoch_acc = validate_epoch(model)\n",
"        print(epoch_acc, end=' ')"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4932 0.8018 0.8462 0.9136 0.9316 0.9468 0.9555 0.9624 0.9653 0.9668 0.9697 0.9717 0.9726 0.9746 0.9761 0.9765 0.9775 0.978 0.9785 0.9785 "
]
}
],
"source": [
"train_model(linear_model, 20)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4932 0.8242 0.8472 0.9141 0.9341 0.9482 0.9555 0.9624 0.9658 0.9678 0.9697 0.9717 0.9736 0.9751 0.9761 0.9765 0.9775 0.978 0.9785 0.9785 "
]
}
],
"source": [
"# fastai provides the SGD class, which does the same thing as BasicOptim\n",
"\n",
"linear_model = nn.Linear(28*28,1)\n",
"opt = SGD(linear_model.parameters(), lr)\n",
"train_model(linear_model, 20)"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"# fastai also provides Learner.fit, which can be used instead of train_model\n",
"# First, create a DataLoaders object from the training and validation DataLoaders\n",
"dls = DataLoaders(dl, valid_dl)"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"def batch_accuracy(xb, yb):\n",
" preds = xb.sigmoid()\n",
" correct = (preds>0.5) == yb\n",
" return correct.float().mean()\n",
" \n",
"def mnist_loss(predictions, targets):\n",
" predictions = predictions.sigmoid()\n",
" return torch.where(targets==1, 1-predictions, predictions).mean()\n",
"\n",
"\"\"\"\n",
"\n",
"learn = Learner(dls, nn.Linear(28*28,1), opt_func=SGD,\n",
" loss_func=mnist_loss, metrics=batch_accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch train_loss valid_loss batch_accuracy time \n",
"0 0.637205 0.503335 0.495584 00:00 \n",
"1 0.484020 0.211543 0.814524 00:00 \n",
"2 0.179104 0.168032 0.850343 00:00 \n",
"3 0.079153 0.103017 0.913150 00:00 \n",
"4 0.042389 0.076197 0.934249 00:00 \n",
"5 0.028032 0.061356 0.947988 00:00 \n",
"6 0.022124 0.052038 0.956330 00:00 \n",
"7 0.019487 0.045821 0.962218 00:00 \n",
"8 0.018139 0.041443 0.965653 00:00 \n",
"9 0.017323 0.038212 0.967125 00:00 \n",
"10 0.016745 0.035727 0.970069 00:00 \n",
"11 0.016288 0.033746 0.971541 00:00 \n",
"12 0.015904 0.032123 0.973503 00:00 \n",
"13 0.015576 0.030765 0.974975 00:00 \n",
"14 0.015292 0.029612 0.975957 00:00 \n"
]
}
],
"source": [
"learn.fit(15, lr=lr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adding a Nonlinearity"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"# Below is a very simple neural network: a linear layer, an activation function, then a second linear layer.\n",
"# The activation used here, `.max(tensor(0.0))`, is called the rectified linear unit (ReLU):\n",
"# it replaces every negative number with zero.\n",
"def simple_net(xb):\n",
"    \"\"\"Two-layer network built from the global parameters `w1`, `b1`, `w2`, `b2`.\"\"\"\n",
"    hidden = (xb @ w1 + b1).max(tensor(0.0))  # first linear layer followed by ReLU\n",
"    return hidden @ w2 + b2  # second linear layer"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"# let's randomly initialize the weights and biases\n",
"w1 = init_params((28*28,30))\n",
"b1 = init_params(30)\n",
"w2 = init_params((30,1))\n",
"b2 = init_params(1)"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAiTklEQVR4nO3deXxU9b3/8ddHVlmUJRHZF0EBRSCkoKJ1q4pUxa6C2KrVUhFqrd2oXrU/7GbtdceF2/KwLbs7tSjiSq2FEkJYZQkgS0AS9p2Q5PP7Yw73jjEhQ5jJmZm8n4/HPDJzlpn3HMJnTr7nzOeYuyMiIunrpLADiIhIYqnQi4ikORV6EZE0p0IvIpLmVOhFRNJc3bADVCQjI8M7deoUdgwRkZSxYMGCbe6eWdG8pCz0nTp1IicnJ+wYIiIpw8zWVzZPQzciImlOhV5EJM2p0IuIpDkVehGRNKdCLyKS5qos9GbW3szeN7PlZrbMzH5UwTJmZk+aWb6ZLTazrKh5N5vZ6uB2c7zfgIiIHFssp1eWAD9x91wzawosMLPZ7r48apmrgW7BbQDwLDDAzFoADwLZgAfrznD3nXF9FyIiUqkq9+jdfYu75wb39wKfAG3LLTYE+KtHzAWamVlr4CpgtrvvCIr7bGBQXN+BiEga+M+6Hfzpn2tJROv44xqjN7NOQF9gXrlZbYGNUY83BdMqm17Rc48wsxwzyykqKjqeWCIiKa1w7yFGTc5l0rwNHDxSGvfnj7nQm1kT4GXgbnffE+8g7j7e3bPdPTszs8Jv8YqIpJ2S0jJ+OHkhew8d4dmbsmhUP/4NC2Iq9GZWj0iRn+Tur1SwSAHQPupxu2BaZdNFRAT449urmLduB7+5vhfdTz8lIa8Ry1k3BvwZ+MTdH61ksRnAd4Ozb84Ddrv7FmAWcKWZNTez5sCVwTQRkVpv9vKtPPfhGob178A3+rVL2OvE8jfCQOA7wBIzywum3Qt0AHD354CZwGAgHzgA3BrM22FmDwHzg/XGuvuOuKUXEUlR67fv557peZzT9hQevLZnQl+rykLv7h8BVsUyDoyqZN4EYEK10omIpKFDR0oZOTEXA54d3o+G9eok9PWSsk2xiEg6e/D1ZSzfsocJt2TTvkWjhL+eWiCIiNSg6TkbmZazkdGXduWy7q1q5DVV6EVEasiyzbu5/7WlDOzakh9fcWaNva4KvYhIDdh98Ah3TsqleaP6PDG0L3VOOuahz7jSGL2ISIK5Oz99cREFOw8y7QfnkdGkQY2+vvboRUQS7Pk5a5m9fCv3Du5Bv44tavz1VehFRBJo7trt/OGtFXy1V2tuHdgplAwq9CIiCVK45xCjJy+kU0ZjHv7muUQaDdQ8jdGLiCRASWkZo6csZP/hEiZ/fwBNGoRXblXoRUQS4JFZK/nPuh08dkNvzmzVNNQsGroREYmzWcs+4/k5axk+oANf65u4ZmWxUqEXEYmjT7ft56fTF3Fuu1N5IMHNymKlQi8iEieHjpQyclIudeoYzwzPokHdxDYri5XG6EVE4uT+15ay4rM9TLjlS7RrnvhmZbHSHr2ISBxMm7+BFxds4oeXduXSs04LO87nqNCLiJygpQW7uf/1ZVzULYMffaXmmpXFqsqhGzObAFwDFLr7ORXM/xkwPOr5egCZwdWlPgX2AqVAibtnxyu4iEgyONqsrGXj+jx+Q58abVYWq1j26F8ABlU2090fcfc+7t4H+CXwYbnLBV4azFeRF5G0Ulbm/GT6IrbsPsi44Vm0rOFmZbGqstC7+xwg1uu8DgOmnFAiEZEU8fyctbzzyVbuG9yDrA7Nw45TqbiN0ZtZIyJ7/i9HTXbgbTNbYGYjqlh/hJnlmFlOUVFRvGKJiCTEx2u28cisFVxzbmtuvqBT2HGOKZ4HY68F/lVu2OZCd88CrgZGmdmXK1vZ3ce7e7a7Z2dmZsYxlohIfG3dc4i7piykc0ZjHv5GeM3KYhXPQj+Ucs
M27l4Q/CwEXgX6x/H1RERq3JHSMkZPzuVAcSnP3dSPxiE2K4tVXAq9mZ0KXAy8HjWtsZk1PXofuBJYGo/XExEJyx/eWsH8T3fyu6/3olvIzcpiFcvplVOAS4AMM9sEPAjUA3D354LFvga87e77o1ZtBbwa/ElTF5js7m/FL7qISM16a+kW/uef6/jOeR0Z0qdt2HFiVmWhd/dhMSzzApHTMKOnrQV6VzeYiEgyWVu0j5++uJje7ZvxX9f0CDvOcdE3Y0VEqnCwuJQ7J+VSL8malcUq+Y8iiIiEyN2577UlrNy6lxdu7U/bZieHHem4aY9eROQYps7fyCu5Bdx1WTcuPjM1T/1WoRcRqcTSgt08OCPSrOyuy7uFHafaVOhFRCqw+8AR7pi4gIzG9XliaN+kbFYWK43Ri4iUU1bm3DM9j617DjH9B+fTonH9sCOdEO3Ri4iU8+yHa3h3RSH/9dWe9E3iZmWxUqEXEYnyr/xt/PfbK7m2dxu+e37HsOPEhQq9iEjgs92RZmVdMpvw+6/3SvpmZbHSGL2ICP/XrOzgkVKmDs9KiWZlsUqfdyIicgJ+/+YKctbv5MlhfVOmWVmsNHQjIrXezCVb+PNH67jlgk5c17tN2HHiToVeRGq1NUX7+PlLi+nboRn3Dk6tZmWxUqEXkVrrQHEJIycuoH7dkxh3Yxb166ZnSdQYvYjUSu7Ofa8uZXXhPv76vf60ScFmZbFKz48vEZEqTJq3gVcXFnD35WdyUbfUbFYWqyoLvZlNMLNCM6vwMoBmdomZ7TazvOD2QNS8QWa20szyzWxMPIOLiFTX4k27GPv35Vx8ZiY/vKxr2HESLpY9+heAQVUs80937xPcxgKYWR1gHHA10BMYZmY9TySsiMiJ2rm/mJETc8ls2oDHb+jDSSncrCxWVRZ6d58D7KjGc/cH8t19rbsXA1OBIdV4HhGRuCgrc348PY/CvYcYNzyL5inerCxW8RqjP9/MFpnZm2Z2djCtLbAxaplNwbQKmdkIM8sxs5yioqI4xRIR+T9Pv5/PByuLeOCanvRp3yzsODUmHoU+F+jo7r2Bp4DXqvMk7j7e3bPdPTszM70PjIhIzfto9TYee2cV1/dpw03npUezslidcKF39z3uvi+4PxOoZ2YZQAHQPmrRdsE0EZEatXnXQe6aupCumU34bRo1K4vVCRd6Mzvdgq1mZv2D59wOzAe6mVlnM6sPDAVmnOjriYgcj+KSMkZNzuXwkVKe+04/GtWvfV8fqvIdm9kU4BIgw8w2AQ8C9QDc/Tngm8BIMysBDgJD3d2BEjMbDcwC6gAT3H1ZQt6FiEglfjvzExZu2MW4G7M4I7NJ2HFCUWWhd/dhVcx/Gni6knkzgZnViyYicmLeWLyZFz7+lFsHduKr57YOO05o9M1YEUlL+YX7+MVLi8nq0IxfXp2ezcpipUIvImln/+FIs7IG9eowbnj6NiuLVe07KiEiac3duffVJeQX7eNv3xtA61PTt1lZrGr3x5yIpJ2Jc9fzet5m7vnKmVzYLSPsOElBhV5E0kbexl2MfWM5l5yVyahL079ZWaxU6EUkLezcX8yoSbmc1rRhrWlWFiuN0YtIyisrc+6elkfR3sO8NPJ8mjWqHc3KYqU9ehFJeU+9l8+Hq4p44NqenNuuWdhxko4KvYiktDmrinj83VV8vW9bhg/oEHacpKRCLyIpq2DXQX40dSFnntaU33yt9jUri5UKvYikpOKSMkZNyuVIqfPsTVmcXL9O2JGSlg7GikhK+s0/lpO3cRfPDM+iSy1tVhYr7dGLSMqZsWgzf/n3em6/sDODe9XeZmWxUqEXkZSyeutexry8mOyOzfnF1d3DjpMSVOhFJGXsP1zCyEm5NKofaVZWr45KWCyq3EpmNsHMCs1saSXzh5vZYjNbYmYfm1nvqHmfBtPzzCwnnsFFpHZxd8a8soS1Rft4cmhfWp3SMOxIKSOWj8MXgEHHmL8OuNjdewEPAePLzb
/U3fu4e3b1IoqIwF//vZ6/L9rMT648iwu6qlnZ8YjlClNzzKzTMeZ/HPVwLpGLgIuIxE3uhp38+h/Lubz7aYy8+Iyw46SceA9w3Qa8GfXYgbfNbIGZjTjWimY2wsxyzCynqKgozrFEJFXt2F/M6Em5tDqlIY9+W83KqiNu59Gb2aVECv2FUZMvdPcCMzsNmG1mK9x9TkXru/t4gmGf7Oxsj1cuEUldpWXOj6YuZNv+Yl4ZeQGnNqoXdqSUFJc9ejM7F/gTMMTdtx+d7u4Fwc9C4FWgfzxeT0RqhyffXc0/V2/j/113Nue0PTXsOCnrhAu9mXUAXgG+4+6roqY3NrOmR+8DVwIVnrkjIlLeBysLefK91Xwjqx1Dv9Q+7DgprcqhGzObAlwCZJjZJuBBoB6Auz8HPAC0BJ4JGgqVBGfYtAJeDabVBSa7+1sJeA8ikmY27TzA3dPyOKtVU359/TlqVnaCYjnrZlgV828Hbq9g+lqg9xfXEBGp3OGSUkZNXkhpqfPsTf3UrCwO1NRMRJLKr9/4hEUbd/HcTVl0zmgcdpy0oO8Pi0jSeD2vgL/NXc/3L+rMoHPUrCxeVOhFJCms2rqXMS8v4UudmvPzQWpWFk8q9CISun2HSxg5cQGNG9Tl6RvVrCzetDVFJFTuzi9eXsy6bft5apialSWCCr2IhOqFjz/lH4u38LOrunP+GS3DjpOWVOhFJDQL1u/kN//4hK/0aMUdF3cJO07aUqEXkVBs33eY0ZNzadPsZP772731pagE0nn0IlLjIs3K8th+tFnZyWpWlkjaoxeRGvfEO6v4KH8bDw1Rs7KaoEIvIjXq/RWFPPlePt/q144bvtQh7Di1ggq9iNSYjTsizcp6tD6Fh64/J+w4tYYKvYjUiEizslzKypxnh2fRsJ6aldUUHYwVkRox9u/LWbxpN89/px+d1KysRmmPXkQS7tWFm5g0bwM/+HIXrjr79LDj1Doq9CKSUCs/28u9ryylf+cW/Oyqs8KOUyvFVOjNbIKZFZpZhZcCtIgnzSzfzBabWVbUvJvNbHVwuzlewUUk+e09dISRExfQpGFdnh7Wl7pqVhaKWLf6C8CgY8y/GugW3EYAzwKYWQsilx4cQOTC4A+aWfPqhhWR1HG0Wdn6HQd4elhfTlOzstDEVOjdfQ6w4xiLDAH+6hFzgWZm1hq4Cpjt7jvcfScwm2N/YIhImpjwr0+ZueQzfn7VWQzoomZlYYrX31FtgY1RjzcF0yqb/gVmNsLMcswsp6ioKE6xRCQMC9bv4HczP+HKnq0Y8WU1Kwtb0gyYuft4d8929+zMzMyw44hINW3bd5g7J+XStvnJPPItNStLBvEq9AVA+6jH7YJplU0XkTQUaVa2kF0HjvDs8H5qVpYk4lXoZwDfDc6+OQ/Y7e5bgFnAlWbWPDgIe2UwTUTS0GOzV/Gv/O08dP059GxzSthxJBDTN2PNbApwCZBhZpuInElTD8DdnwNmAoOBfOAAcGswb4eZPQTMD55qrLsf66CuiKSo91Zs5en38xn6pfZ8O7t91StIjYmp0Lv7sCrmOzCqknkTgAnHH01EUsXGHQe4e2oeZ7c5hV9dd3bYcaScpDkYKyKp6dCRUkZOWgDAs8P7qVlZElJTMxE5IWPfWM7Sgj38z3ez6dCyUdhxpALaoxeRansldxOT523gjovP4IqercKOI5VQoReRalnx2R7ufXUJAzq34KdXnhl2HDkGFXoROW57Dh1h5MRcTmlYj6duVLOyZKcxehE5Lu7Oz19czIYdB5jy/fM4ramalSU7fQyLyHH580freGvZZ4wZ1J3+nVuEHUdioEIvIjGb/+kOfvfmCgadfTq3X9Q57DgSIxV6EYlJ0d7DjJqUS/vmJ/OHb52rZmUpRGP0IlKlktIy7pqykD2HjvCX7/XnlIZqVpZKVOhFpEqPzl7Fv9du54/f6k2P1mpWlmo0dCMixzR7+Vae+WANw/q355v92oUdR6pBhV5EKr
Vh+wHumZ7HOW1P4cFr1awsVanQi0iFjjYrM9SsLNVpjF5EKvSrGctYtnkPf745m/Yt1KwslWmPXkS+4MWcjUydv5E7LzmDy3uoWVmqi6nQm9kgM1tpZvlmNqaC+Y+ZWV5wW2Vmu6LmlUbNmxHH7CKSAMs37+G/XlvK+V1acs8ValaWDqocujGzOsA44ApgEzDfzGa4+/Kjy7j7j6OW/yHQN+opDrp7n7glFpGE2XPoCHdOWkCzRvV4cpialaWLWP4V+wP57r7W3YuBqcCQYyw/DJgSj3AiUnPcnZ+9uIhNOw8y7sYsMps2CDuSxEkshb4tsDHq8aZg2heYWUegM/Be1OSGZpZjZnPN7PrKXsTMRgTL5RQVFcUQS0Ti6X/+uZZZy7Yy5uruZHdSs7J0Eu+/y4YCL7l7adS0ju6eDdwIPG5mZ1S0oruPd/dsd8/OzMyMcywROZZ5a7fz8FsrGdzrdG67UM3K0k0shb4AaB/1uF0wrSJDKTds4+4Fwc+1wAd8fvxeREJWuPcQo6cspGOLRjz8DTUrS0exFPr5QDcz62xm9YkU8y+cPWNm3YHmwL+jpjU3swbB/QxgILC8/LoiEo6S0jJ+OHkhew8d4ZmbsmiqZmVpqcqzbty9xMxGA7OAOsAEd19mZmOBHHc/WvSHAlPd3aNW7wE8b2ZlRD5Ufh99to6IhOuPb69i3rodPPrt3nQ/Xc3K0lVM34x195nAzHLTHij3+FcVrPcx0OsE8olIgry97DOe+3ANNw7owNez1KwsnekkWZFaaP32/fzkxUX0ansqD1zTM+w4kmAq9CK1zKEjpYycmMtJZjwzPEvNymoBNTUTqWUeeH0py7fsYcItalZWW2iPXqQWmT5/I9NzNjH60q5c1l3NymoLFXqRWmLZ5t3c//pSBnZtyY/VrKxWUaEXqQV2HzzCnZNyad6oPk8M7Uudk/SlqNpEY/Qiac7d+emLiyjYeZBpPziPjCZqVlbbaI9eJM09P2cts5dv5d7BPejXUc3KaiMVepE0Nnftdv7w1gq+em5rbh3YKew4EhIVepE0VbjnEKMnL6RTRmM1K6vlNEYvkoZKSssYPWUh+w+XMOn2ATRpoP/qtZn+9UXS0COzVvKfdTt4/IY+nHV607DjSMg0dCOSZt5a+hnPz1nLTed14Pq+FV4MTmoZFXqRNPLptv387MVF9G53KverWZkEVOhF0sTB4lLumLiAOnWMccOzaFBXzcokIqZCb2aDzGylmeWb2ZgK5t9iZkVmlhfcbo+ad7OZrQ5uN8czvIhEuDv3v76UlVv38tgNfWjXXM3K5P9UeTDWzOoA44ArgE3AfDObUcGVoqa5++hy67YAHgSyAQcWBOvujEt6EQFg2vyNvLRgE3dd1pVLzzot7DiSZGLZo+8P5Lv7WncvBqYCQ2J8/quA2e6+Iyjus4FB1YsqIhVZWrCbB2Ys46JuGfzoK2pWJl8US6FvC2yMerwpmFbeN8xssZm9ZGbtj3NdzGyEmeWYWU5RUVEMsURk94Ej3DFxAS0b1+fxG/qoWZlUKF4HY/8OdHL3c4nstf/leJ/A3ce7e7a7Z2dmZsYplkj6Kitz7pmex9Y9hxg3PIuWalYmlYil0BcA7aMetwum/S933+7uh4OHfwL6xbquiFTPsx+u4d0Vhdw3uAdZHZqHHUeSWCyFfj7Qzcw6m1l9YCgwI3oBM2sd9fA64JPg/izgSjNrbmbNgSuDaSJyAj5es43/fnsl1/Zuw80XdAo7jiS5Ks+6cfcSMxtNpEDXASa4+zIzGwvkuPsM4C4zuw4oAXYAtwTr7jCzh4h8WACMdfcdCXgfIrXGZ7sPcdeUhXTOaMzvvt5LzcqkSubuYWf4guzsbM/JyQk7hkjSOVJaxrDxc1m+ZQ+vjxpIt1bqYyMRZrbA3bMrmqemZiIp5OE3V5CzfidPDO2jIi8xUwsEkRTx5pIt/OmjdXz3/I4M6aNmZRI7FXqRFLC2aB8/e2kxvds347
6v9gg7jqQYFXqRJHewuJSRE3OpV8d4Rs3KpBo0Ri+SxNyd+15bwqrCvbxwa3/aNjs57EiSgrRHL5LEpvxnI6/kFnDXZd24+Ex9Y1yqR4VeJEkt3rSLXwXNyu66vFvYcSSFqdCLJKFdB4oZOTGXjCb1eWJoXzUrkxOiMXqRJFNW5vx4Wh6Few/x4h0X0KJx/bAjSYrTHr1Iknnmg3zeX1nE/df0pE/7ZmHHkTSgQi+SRP6Vv41HZ6/iut5t+M55HcOOI2lChV4kSRxtVtYls4malUlcaYxeJAkcKS1j1ORcDh4pZdpNWTRuoP+aEj/6bRJJAr+buYIF63fy1LC+dD1NzcokvjR0IxKyNxZvZsK/1nHLBZ24tnebsONIGlKhFwlRfuE+fvHSYvp2aMa9g9WsTBIjpkJvZoPMbKWZ5ZvZmArm32Nmy81ssZm9a2Ydo+aVmllecJtRfl2R2upAcQl3TlpAg3p1GHdjFvXrar9LEqPKMXozqwOMA64ANgHzzWyGuy+PWmwhkO3uB8xsJPAH4IZg3kF37xPf2CKpzd2595UlrC7cx1+/1582alYmCRTLLkR/IN/d17p7MTAVGBK9gLu/7+4HgodzgXbxjSmSXibO28BreZu5+/IzuaibmpVJYsVS6NsCG6MebwqmVeY24M2oxw3NLMfM5prZ9ZWtZGYjguVyioqKYoglkpoWbdzFQ39fziVnZfLDy7qGHUdqgbieXmlmNwHZwMVRkzu6e4GZdQHeM7Ml7r6m/LruPh4YD5GLg8czl0iy2Lm/mDsn5ZLZtAGPfbsPJ6lZmdSAWPboC4D2UY/bBdM+x8y+AtwHXOfuh49Od/eC4Oda4AOg7wnkFUlZZWXOj6fnUbT3MM8Mz6K5mpVJDYml0M8HuplZZzOrDwwFPnf2jJn1BZ4nUuQLo6Y3N7MGwf0MYCAQfRBXpNZ4+v18PlhZxAPX9qS3mpVJDapy6MbdS8xsNDALqANMcPdlZjYWyHH3GcAjQBPgxaA/xwZ3vw7oATxvZmVEPlR+X+5sHZFa4Z+ri3jsnVV8rW9bhg/oEHYcqWXMPfmGw7Ozsz0nJyfsGCJxsXnXQa556iMymtTntVEDaVRfnUck/sxsgbtnVzRP39AQSaDikkizsuKSMp69qZ+KvIRCv3UiCfTbmZ+wcMMuxt2YxRmZTcKOI7WU9uhFEmTGos288PGnfG9gZ756buuw40gtpkIvkgD5hXsZ8/Ji+nVszi8Hdw87jtRyKvQicbb/cAkjJ+ZyctCsrF4d/TeTcGmMXiSO3J1fvrKENUX7+NttAzj91IZhRxLRHr1IPP1t7npmLNrMPVecycCuGWHHEQFU6EXiZuGGnTz0xnIu634ad16iZmWSPFToReJgx/5iRk3KpdUpDXn0273VrEySisboRU5QaZlz97Q8tu0r5uWRF9CskZqVSXJRoRc5QU+9t5o5q4r47dd60avdqWHHEfkCDd2InIAPVxXxxLur+XpWW4b1b1/1CiIhUKEXqabNuw5y99SFnNWqKb+5vhdB51aRpKNCL1INxSVl3DkplyOlzjPDszi5fp2wI4lUSmP0ItXwm38sJ2/jLp67KYsualYmSU579CLH6fW8Av7y7/XcfmFnBp2jZmWS/GIq9GY2yMxWmlm+mY2pYH4DM5sWzJ9nZp2i5v0ymL7SzK6KY3aRGvfW0i388pUlfKlTc35xtZqVSWqocujGzOoA44ArgE3AfDObUe6SgLcBO929q5kNBR4GbjCznkSuMXs20AZ4x8zOdPfSeL8RkUQq3HuIB19fxptLP+PsNqfwtJqVSQqJZYy+P5Dv7msBzGwqMITPX+R7CPCr4P5LwNMWOQVhCDDV3Q8D68wsP3i+f8cn/udd+9RHHDqizxCJvy27D1FcWsbPB53F9y/qoiIvKSWWQt8W2Bj1eBMwoLJlgouJ7wZaBtPnllu3bUUvYmYjgBEAHTpU7+LJZ2Q2pri0rFrrihxLn/bN+MHFZ9D1NB
14ldSTNGfduPt4YDxELg5ened4fGjfuGYSEUkHsfz9WQBEf+WvXTCtwmXMrC5wKrA9xnVFRCSBYin084FuZtbZzOoTObg6o9wyM4Cbg/vfBN5zdw+mDw3OyukMdAP+E5/oIiISiyqHboIx99HALKAOMMHdl5nZWCDH3WcAfwb+Fhxs3UHkw4BguelEDtyWAKN0xo2ISM2yyI53csnOzvacnJywY4iIpAwzW+Du2RXN0zliIiJpToVeRCTNqdCLiKQ5FXoRkTSXlAdjzawIWF/N1TOAbXGMEy/KdXyU6/go1/FJx1wd3T2zohlJWehPhJnlVHbkOUzKdXyU6/go1/Gpbbk0dCMikuZU6EVE0lw6FvrxYQeohHIdH+U6Psp1fGpVrrQboxcRkc9Lxz16ERGJokIvIpLmUr7Qm9kjZrbCzBab2atm1qyS5Y55gfME5PqWmS0zszIzq/R0KTP71MyWmFmemSW8k9tx5Krp7dXCzGab2ergZ/NKlisNtlWemZVvlx3PPMd8/0Hr7WnB/Hlm1ilRWY4z1y1mVhS1jW6vgUwTzKzQzJZWMt/M7Mkg82Izy0p0phhzXWJmu6O21QM1lKu9mb1vZsuD/4s/qmCZ+G4zd0/pG3AlUDe4/zDwcAXL1AHWAF2A+sAioGeCc/UAzgI+ALKPsdynQEYNbq8qc4W0vf4AjAnuj6no3zGYt68GtlGV7x+4E3guuD8UmJYkuW4Bnq6p36fgNb8MZAFLK5k/GHgTMOA8YF6S5LoEeKMmt1Xwuq2BrOB+U2BVBf+Ocd1mKb9H7+5vu3tJ8HAukatYlfe/Fzh392Lg6AXOE5nrE3dfmcjXqI4Yc9X49gqe/y/B/b8A1yf49Y4llvcfnfcl4HIzsyTIVePcfQ6R61BUZgjwV4+YCzQzs9ZJkCsU7r7F3XOD+3uBT/jitbTjus1SvtCX8z0in4LlVXSB8wovUh4CB942swXBBdKTQRjbq5W7bwnufwa0qmS5hmaWY2Zzzez6BGWJ5f3/7zLBjsZuoGWC8hxPLoBvBH/uv2Rm7SuYX9OS+f/f+Wa2yMzeNLOza/rFgyG/vsC8crPius2S5uLgx2Jm7wCnVzDrPnd/PVjmPiJXsZqUTLlicKG7F5jZacBsM1sR7ImEnSvujpUr+oG7u5lVdt5vx2B7dQHeM7Ml7r4m3llT2N+BKe5+2Mx+QOSvjstCzpSscon8Pu0zs8HAa0Qud1ojzKwJ8DJwt7vvSeRrpUShd/evHGu+md0CXANc7sEAVzkJuUh5VblifI6C4Gehmb1K5M/zEyr0cchV49vLzLaaWWt33xL8iVpYyXMc3V5rzewDIntD8S70sbz/o8tsMrO6wKnA9jjnOO5c7h6d4U9Ejn2ELSG/Tycquri6+0wze8bMMtw94c3OzKwekSI/yd1fqWCRuG6zlB+6MbNBwM+B69z9QCWLxXKB8xpnZo3NrOnR+0QOLFd4hkANC2N7RV9g/mbgC395mFlzM2sQ3M8ABhK5HnG8xfL+o/N+E3ivkp2MGs1Vbhz3OiLjv2GbAXw3OJPkPGB31DBdaMzs9KPHVcysP5F6mOgPa4LX/DPwibs/Wsli8d1mNX3EOd43IJ/IWFZecDt6JkQbYGbUcoOJHN1eQ2QII9G5vkZkXO0wsBWYVT4XkbMnFgW3ZcmSK6Tt1RJ4F1gNvAO0CKZnA38K7l8ALAm21xLgtgTm+cL7B8YS2aEAaAi8GPz+/QfokuhtFGOu3wW/S4uA94HuNZBpCrAFOBL8bt0G3AHcEcw3YFyQeQnHOAuthnONjtpWc4ELaijXhUSOzS2OqluDE7nN1AJBRCTNpfzQjYiIHJsKvYhImlOhFxFJcyr0IiJpToVeRCTNqdCLiKQ5FXoRkTT3/wExpUYJ/pjBsQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# the ReLU function in PyTorch is available as F.relu\n",
"plot_function(F.relu)"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"# The simple_net network written earlier can be replaced with the code below in PyTorch\n",
"simple_net = nn.Sequential(\n",
" nn.Linear(28*28,30),\n",
" nn.ReLU(),\n",
" nn.Linear(30,1)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, simple_net, opt_func=SGD,\n",
" loss_func=mnist_loss, metrics=batch_accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch train_loss valid_loss batch_accuracy time \n",
"0 0.315269 0.413630 0.504416 00:00 \n",
"1 0.146809 0.228842 0.804711 00:00 \n",
"2 0.080703 0.113013 0.920020 00:00 \n",
"3 0.052696 0.075948 0.944553 00:00 \n",
"4 0.039812 0.059318 0.959274 00:00 \n",
"5 0.033291 0.050055 0.965162 00:00 \n",
"6 0.029574 0.044236 0.966634 00:00 \n",
"7 0.027171 0.040259 0.967125 00:00 \n",
"8 0.025446 0.037358 0.969087 00:00 \n",
"9 0.024111 0.035135 0.971541 00:00 \n",
"10 0.023031 0.033363 0.972522 00:00 \n",
"11 0.022128 0.031908 0.973994 00:00 \n",
"12 0.021358 0.030682 0.974975 00:00 \n",
"13 0.020692 0.029630 0.975957 00:00 \n",
"14 0.020108 0.028711 0.976448 00:00 \n",
"15 0.019590 0.027900 0.977429 00:00 \n",
"16 0.019126 0.027178 0.978410 00:00 \n",
"17 0.018708 0.026529 0.978410 00:00 \n",
"18 0.018327 0.025945 0.978901 00:00 \n",
"19 0.017978 0.025415 0.978901 00:00 \n",
"20 0.017657 0.024932 0.979882 00:00 \n",
"21 0.017360 0.024491 0.980373 00:00 \n",
"22 0.017084 0.024086 0.981354 00:00 \n",
"23 0.016826 0.023714 0.981354 00:00 \n",
"24 0.016584 0.023371 0.981354 00:00 \n",
"25 0.016357 0.023055 0.981845 00:00 \n",
"26 0.016143 0.022762 0.981845 00:00 \n",
"27 0.015940 0.022491 0.981845 00:00 \n",
"28 0.015748 0.022238 0.981845 00:00 \n",
"29 0.015566 0.022003 0.982826 00:00 \n",
"30 0.015393 0.021784 0.982336 00:00 \n",
"31 0.015227 0.021578 0.982336 00:00 \n",
"32 0.015069 0.021386 0.982336 00:00 \n",
"33 0.014918 0.021206 0.983317 00:00 \n",
"34 0.014773 0.021036 0.983317 00:00 \n",
"35 0.014634 0.020876 0.983317 00:00 \n",
"36 0.014500 0.020725 0.983317 00:00 \n",
"37 0.014371 0.020582 0.983317 00:00 \n",
"38 0.014247 0.020446 0.983317 00:00 \n",
"39 0.014127 0.020318 0.983317 00:00 \n"
]
}
],
"source": [
"learn.fit(40, 0.1)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD5CAYAAAA3Os7hAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAX+0lEQVR4nO3de3BcZ3nH8e+jXa1ky5YvsWInthM7xBCcDCFETSgwQKEJSQoJUGCSSaekkxLKJNCWQglDJ6SZMrQdrjMNl1BS7gkhvbnUTBJIGIYOF8vkaodEinOxjSMpsa3rrrSXp3/sWWm9Xklrea2jc87vM6PZs2ePdx+dsX569Z73vK+5OyIiEn0tYRcgIiLNoUAXEYkJBbqISEwo0EVEYkKBLiISEwp0EZGYSM91gJndDrwVGHD3c+q8bsAXgcuAceAad//NXO+7Zs0a37Rp0zEXLCKSZDt37nzB3bvqvTZnoAPfAP4F+NYMr18KbAm+LgS+HDzOatOmTfT09DTw8SIiUmFmz8702pxdLu7+M+DgLIdcAXzLy34JrDSzU469TBEROR7N6ENfD+yter4v2HcUM7vOzHrMrGdwcLAJHy0iIhULelHU3W9z92537+7qqtsFJCIi89SMQN8PbKx6viHYJyIiC6gZgb4N+FMrezUw5O4HmvC+IiJyDBoZtngH8EZgjZntAz4JtAK4+1eA7ZSHLPZRHrb4ZyeqWBERmdmcge7uV83xugPXN60iERGZl0bGoYtIgrk72XyR0YkCo7kCYxPl7fHJAi0tRluqhdZ0C5lUC5l0C62pFtqCR7P5f+5kocRIrsDYZKHqs8vbYxNFiqVS877JBfbml6/l3I0rm/6+CnSRGZRKzni+yPhkgclCiclCiXzRy9vF0tRjLl9kbKISNkVGJ/JToTeaK1A4juBxZ+qz8sVS1bZP7TtRHKa+t9IiXAfneH5ZhO3kznYFusSPuzNRKE2F32gQjGOThXLrbKIcKBOF4izvAUX3qtAtB9/EVPAVmW1hrpJDNj/9+ZUW4NhkYdZ/N5NUi9GRSbGsLU1HW5pM+vjGHrQGLd+lmTQrg5ZwpUXcmrITGmxt6fL3say9/L0sa0uxrK2VjrYUHZk0RXfyhdIRv3Qmqn75HY/WlLE8+NyOtjTL26a3OzIp0ilNRVVLgS5zmvqTOzcddiNBK3RsosBIEMIT+RKTxeJU63GiEq7BD3jlT+exynvk8oxNFik2qfmXbjEy6ek/+zMN/ulvZizNpFi5NMOGVUvLYVUVIEvb0rQFoTr13lNdDDYVeh1taZa3p2lLt2BRbj5KZCnQE6JQLLeCD45Ncmh8kkNjeQ6OT3JobHLqcSibD1qo5aCu9FmOTTb+J3eqxWhNWRB2KTKp6ZBdmikH3trl7VPhVwnPZcFXR81jeTtFWzo1ayinzGhpUYhKsinQIypfLHHgcI69h8Z57uA4ew+O87vDWYarui2qH3P5mftaM6kWVndkWLGk/Kf0iiWtrF/ZTkcmXRW8R//JvawmiNtbU6QUqiKhUaAvQsWS8+LoBM8P5+gfLj8ODOd4fijHvkNZnjs4zoGh7BGt5nSLsW5FexDKadZ1th8RwuVgbmV1RyurlmZY3ZGZelyaSamLQCQGFOhNlssXOTQ+We7aqO7WCLo6RnKFI0ctVD1OFEocHs8zODpxVL9yi0HX8jbWr1xC96ZVnLZ6PRtXLWXD6iVsXLWUU1a06yKRSMIp0I+Tu/NE/wj37urn3t3P89j+4RmPXbGkleXt6aoLatOPSzPl/eesb2VdZztrO9tY29nO2s521q1o56SOjAJbRGalQJ+HYsnZ+ewh7t31PPfu7ue5g+MAnHfaSj705i2s62w/smujI8PKJa0KZBE5oRTox2BgOMdn732S+x7v5+DYJJlUC6858yT+4g0v4Q9ffjInd7aHXaKIJJgCvUG/fvog13/vN4zk8rzl7HVcvHUdb3hZF8vadApFZHFQGs
3B3fn6z5/m0z/6LaevXsp3rr2Ql61bHnZZIiJHUaDPYnSiwMfufoT/ffQAbzl7LZ9597ksb28NuywRkboU6DPoGxjh/d/eydMvjHHjpWfx/tefobHaIrKoKdDr2P7oAT76g4dpb03xnWsv5DVnrgm7JBGROSnQq7g7//ij3/LVn+3hvNNW8qWrX8UpK5aEXZaISEMU6FW+/vOn+erP9nD1hafxybedfdzTnoqILCQFeqDnmYP8449+y8Vb1/IPbz9H/eUiEjlqggKDIxNc/73fsGHVEj7znnMV5iISSYkP9EKxxIfueJDD43m+dPX5dGpYoohEVOK7XD5335P8Ys+LfObd57L11M6wyxERmbdEt9B/vLufL/30Ka66YCPvOn9D2OWIiByXxAb6cy+O8+G7HuKc9Z188m1nh12OiMhxS2Sg5/JFPvDdnQB8+erzaW9NhVyRiMjxS2Qf+s3bdrHrd8Pcfk03G1cvDbscEZGmSFwL/Qc9e7lzx16u/4OX8Kaz1oZdjohI0yQu0L/w417OP30VH77oZWGXIiLSVIkK9JFcnv2Hs7zprJNJtejmIRGJl0QF+lODYwBsOXlZyJWIiDRfogK9t38EgC1rteKQiMRPogK9b2CUTLqFjas0Ja6IxE9DgW5ml5jZE2bWZ2Y31nn9dDP7iZk9YmY/NbNFedtl78AoZ6zpIJ1K1O8xEUmIOZPNzFLArcClwFbgKjPbWnPYZ4BvufsrgFuATze70GboHRhRd4uIxFYjTdULgD533+Puk8CdwBU1x2wF7g+2H6jzeujGJwvsO5TVBVERia1GAn09sLfq+b5gX7WHgXcG2+8AlpvZScdfXvPsGRzDXSNcRCS+mtWZ/BHgDWb2IPAGYD9QrD3IzK4zsx4z6xkcHGzSRzemd6AywkWBLiLx1Eig7wc2Vj3fEOyb4u6/c/d3uvt5wCeCfYdr38jdb3P3bnfv7urqmn/V89DbP0q6xTj9pI4F/VwRkYXSSKDvALaY2WYzywBXAtuqDzCzNWZWea+PA7c3t8zj1zswyuY1HbRqhIuIxNSc6ebuBeAG4B7gceAud99lZreY2eXBYW8EnjCzJ4G1wKdOUL3z1jcwqu4WEYm1hqbPdfftwPaafTdVbd8N3N3c0ponly/y7ItjvO3cU8MuRUTkhElE/8PTL4xRcjhTI1xEJMYSEei9A6OAhiyKSLwlItD7+kdoMdi8RiNcRCS+EhHovQOjnH5Sh9YOFZFYS0ygq/9cROIu9oE+WSjxzAtj6j8XkdiLfaA/++IYhZJrDLqIxF7sA316hIumzRWReIt/oPePYgYv6VILXUTiLf6BPjDChlVLWJLRCBcRibfYB3rfwKi6W0QkEWId6IViiT2DGuEiIskQ60DfeyjLZLGkMegikgixDvTe/soqRepyEZH4i3egB0MW1UIXkSSIdaD3DYxy6op2lrU1NO27iEikxTrQewdGOFPdLSKSELEN9FLJgyGL6m4RkWSIbaDvP5wlly8p0EUkMWIb6L0DlREuCnQRSYb4Bnp/MMKlS33oIpIM8Q30gVFOXt7GiqWtYZciIrIgYh3o6m4RkSSJZaC7O339I5qUS0QSJZaBfmAox9hkUXeIikiixDLQp1cpUqCLSHLEM9A1KZeIJFAsA71vYJSTOjKs7siEXYqIyIKJZaD3Doyq/1xEEid2ge7u9PaPaMiiiCRO7AJ9cGSC4VyBM7sU6CKSLLEL9KkRLrogKiIJE79Ar4xwUR+6iCRMQ4FuZpeY2RNm1mdmN9Z5/TQze8DMHjSzR8zssuaX2pjegVE629N0LW8LqwQRkVDMGehmlgJuBS4FtgJXmdnWmsP+DrjL3c8DrgS+1OxCG9U/PMGpK5dgZmGVICISikZa6BcAfe6+x90ngTuBK2qOcaAz2F4B/K55JR6b4VyeFUs0w6KIJE8jgb4e2Fv1fF+wr9rNwJ+Y2T5gO/DBem9kZteZWY+Z9QwODs6j3LkNZ/N0KtBFJI
GadVH0KuAb7r4BuAz4tpkd9d7ufpu7d7t7d1dXV5M++kgjuQKd7Qp0EUmeRgJ9P7Cx6vmGYF+1a4G7ANz9F0A7sKYZBR6rcgs9HcZHi4iEqpFA3wFsMbPNZpahfNFzW80xzwFvBjCzl1MO9BPTpzKLYskZmVALXUSSac5Ad/cCcANwD/A45dEsu8zsFjO7PDjsb4D3mdnDwB3ANe7uJ6romYzk8gDqQxeRRGqob8Ldt1O+2Fm976aq7d3Aa5tb2rEbzhYA6GxXl4uIJE+s7hQdDlroGrYoIkkUr0DPqstFRJIrXoFe6UPXRVERSaB4BXqlD13DFkUkgeIV6BrlIiIJFq9Az+Yxg2UZtdBFJHliFehD2TzL29K0tGimRRFJnlgF+nCuoO4WEUmseAV6VlPnikhyxSvQc3kNWRSRxIpXoGcLGrIoIokVr0BXC11EEixega7VikQkwWIT6IViibHJolroIpJYsQn04Zxu+xeRZItPoGc1MZeIJFt8Al3zuIhIwsUn0IOZFnVjkYgkVXwCfaqFrj50EUmm+AS6+tBFJOHiE+jqQxeRhItPoGcLtBh0ZFJhlyIiEorYBPpQcJeomeZCF5Fkik2gax4XEUm6+AR6Nq8RLiKSaPEJ9FxBY9BFJNHiE+hZdbmISLLFJ9DVhy4iCRefQNdqRSKScLEI9MlCiWxec6GLSLLFItBHdJeoiEg8An0oq4m5REQaCnQzu8TMnjCzPjO7sc7rnzezh4KvJ83scNMrncXUakXqchGRBJuzSWtmKeBW4CJgH7DDzLa5++7KMe7+11XHfxA47wTUOqOpmRbV5SIiCdZIC/0CoM/d97j7JHAncMUsx18F3NGM4hpVmWlRNxaJSJI1Eujrgb1Vz/cF+45iZqcDm4H7Z3j9OjPrMbOewcHBY611RpXVitTlIiJJ1uyLolcCd7t7sd6L7n6bu3e7e3dXV1fTPlSrFYmINBbo+4GNVc83BPvquZIF7m6Bch96usVY0qq50EUkuRoJ9B3AFjPbbGYZyqG9rfYgMzsLWAX8orklzm04p7nQRUTmDHR3LwA3APcAjwN3ufsuM7vFzC6vOvRK4E539xNT6syGsgU629XdIiLJ1lAKuvt2YHvNvptqnt/cvLKOzXCwWpGISJLF4k5RzbQoIhKXQNdqRSIiMQl0rVYkIhKTQNdqRSIi0Q/0XL7IRKGki6IikniRD/SRqZkW1YcuIskW+UAf1uIWIiJADAJ9anEL9aGLSMJFPtCHtVqRiAgQh0DXakUiIkAcAj2rxS1ERCAOga6LoiIiQBwCPVsgk2qhLR35b0VE5LhEPgXLc6GnNRe6iCRe9ANdt/2LiAAxCPShbJ7l6j8XEYl+oA/ntFqRiAjEINBHtFqRiAgQg0DXakUiImWRDnR3ZzirxS1ERCDigT5RKDFZLGkeFxERIh7ow5ppUURkSrQDXbf9i4hMiXSgD2W1WpGISEWkA316LnS10EVEoh3oOfWhi4hURDvQtVqRiMiUaAe6VisSEZkS7UDP5mlLt9Demgq7FBGR0EU70HOax0VEpCLagZ7VTIsiIhXRDnS10EVEpjQU6GZ2iZk9YWZ9ZnbjDMe8x8x2m9kuM/tec8usb0irFYmITJmzv8LMUsCtwEXAPmCHmW1z991Vx2wBPg681t0PmdnJJ6rgasPZPKef1LEQHyUisug10kK/AOhz9z3uPgncCVxRc8z7gFvd/RCAuw80t8z6tFqRiMi0RgJ9PbC36vm+YF+1lwIvNbP/M7Nfmtkl9d7IzK4zsx4z6xkcHJxfxYHyXOjqQxcRqWjWRdE0sAV4I3AV8DUzW1l7kLvf5u7d7t7d1dV1XB+YzRcplFyLW4iIBBoJ9P3AxqrnG4J91fYB29w97+5PA09SDvgTZjiru0RFRKo1Eug7gC1mttnMMsCVwLaaY/6LcuscM1tDuQtmT/PKPNr0XOjqQxcRgQYC3d0LwA3APcDjwF3uvsvMbjGzy4PD7gFeNL
PdwAPAR939xRNVNGi1IhGRWg01b919O7C9Zt9NVdsOfDj4WhBDmgtdROQIkb1TdHoudHW5iIhAlAO9clFULXQRESDSgV5uoS9XC11EBIhyoOfytLe20JbWXOgiIhDlQM8WdFORiEiV6AZ6TjMtiohUi3agq4UuIjIluoGu1YpERI4Q2UAf0kyLIiJHiGygqw9dRORIkQz06bnQ1eUiIlIRyUAfmyxSck3MJSJSLZKBPqyJuUREjhLNQA8m5tKNRSIi06IZ6FqtSETkKBENdK1WJCJSK5KBPqTVikREjhLJQJ9eT1SBLiJSEc1AD/rQNRe6iMi0aAZ6Ls/STIrWVCTLFxE5ISKZiMNZ3fYvIlIrmoGey2sMuohIjWgGeragIYsiIjWiGeiaaVFE5CjRDXR1uYiIHCGSgT40ntdqRSIiNSIX6KWSMzJRUAtdRKRG5AJ9dLKAay50EZGjRC7QNTGXiEh9EQx0TZ0rIlJP9AJdi1uIiNQVvUDX8nMiInVFL9Bz6nIREamnoUA3s0vM7Akz6zOzG+u8fo2ZDZrZQ8HXnze/1LIhXRQVEalrzlQ0sxRwK3ARsA/YYWbb3H13zaHfd/cbTkCNR9i4aglvOXsty9oU6CIi1RpJxQuAPnffA2BmdwJXALWBviAuPnsdF5+9LoyPFhFZ1BrpclkP7K16vi/YV+uPzewRM7vbzDbWeyMzu87MesysZ3BwcB7liojITJp1UfR/gE3u/grgPuCb9Q5y99vcvdvdu7u6upr00SIiAo0F+n6gusW9Idg3xd1fdPeJ4Om/Auc3pzwREWlUI4G+A9hiZpvNLANcCWyrPsDMTql6ejnwePNKFBGRRsx5UdTdC2Z2A3APkAJud/ddZnYL0OPu24APmdnlQAE4CFxzAmsWEZE6zN1D+eDu7m7v6ekJ5bNFRKLKzHa6e3e91yJ3p6iIiNSnQBcRiYnQulzMbBB4dp7/fA3wQhPLaSbVNj+qbX5U2/xEubbT3b3uuO/QAv14mFnPTH1IYVNt86Pa5ke1zU9ca1OXi4hITCjQRURiIqqBflvYBcxCtc2Papsf1TY/sawtkn3oIiJytKi20EVEpIYCXUQkJiIX6HMthxcmM3vGzB4NluELdV4DM7vdzAbM7LGqfavN7D4z6w0eVy2i2m42s/1VyxheFlJtG83sATPbbWa7zOwvg/2hn7tZagv93JlZu5n92sweDmr7+2D/ZjP7VfDz+v1ggr/FUts3zOzpqvP2yoWurarGlJk9aGY/DJ7P77y5e2S+KE8O9hRwBpABHga2hl1XVX3PAGvCriOo5fXAq4DHqvb9M3BjsH0j8E+LqLabgY8sgvN2CvCqYHs58CSwdTGcu1lqC/3cAQYsC7ZbgV8BrwbuAq4M9n8F+MAiqu0bwLvC/j8X1PVh4HvAD4Pn8zpvUWuhTy2H5+6TQGU5PKnh7j+jPPNltSuYXnzkm8DbF7KmihlqWxTc/YC7/ybYHqE8FfR6FsG5m6W20HnZaPC0Nfhy4E3A3cH+sM7bTLUtCma2AfgjymtJYGbGPM9b1AK90eXwwuLAvWa208yuC7uYOta6+4Fg+3lgbZjF1HFDsIzh7WF1B1Uzs03AeZRbdIvq3NXUBovg3AXdBg8BA5RXLnsKOOzuheCQ0H5ea2tz98p5+1Rw3j5vZm1h1AZ8AfhboBQ8P4l5nreoBfpi9zp3fxVwKXC9mb0+7IJm4uW/5RZNKwX4MvAS4JXAAeCzYRZjZsuAfwf+yt2Hq18L+9zVqW1RnDt3L7r7KymvanYBcFYYddRTW5uZnQN8nHKNvwesBj620HWZ2VuBAXff2Yz3i1qgz7kcXpjcfX/wOAD8J+X/1ItJf2V1qeBxIOR6prh7f/BDVwK+RojnzsxaKQfmd939P4Ldi+Lc1attMZ27oJ7DwAPA7wMrzayykE7oP69VtV0SdGG5l5fP/DfCOW+vBS43s2codyG/Cfgi8zxvUQv0OZfDC4uZdZjZ8so2cD
Hw2Oz/asFtA94bbL8X+O8QazmCHbmM4TsI6dwF/ZdfBx53989VvRT6uZuptsVw7sysy8xWBttLgIso9/E/ALwrOCys81avtt9W/YI2yn3UC37e3P3j7r7B3TdRzrP73f1q5nvewr66O4+rwZdRvrr/FPCJsOupqusMyqNuHgZ2hV0bcAflP7/zlPvgrqXcN/cToBf4MbB6EdX2beBR4BHK4XlKSLW9jnJ3yiPAQ8HXZYvh3M1SW+jnDngF8GBQw2PATcH+M4BfA33AD4C2RVTb/cF5ewz4DsFImLC+gDcyPcplXudNt/6LiMRE1LpcRERkBgp0EZGYUKCLiMSEAl1EJCYU6CIiMaFAFxGJCQW6iEhM/D+Gdvkl1Yu87gAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(L(learn.recorder.values).itemgot(2));"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.983316957950592"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.recorder.values[-1][2]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Going Deeper"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch train_loss valid_loss accuracy time \n",
"█\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ec2-user/SageMaker/.env/fastai/lib/python3.6/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 0.112760 0.009624 0.997056 00:15 \n"
]
}
],
"source": [
"dls = ImageDataLoaders.from_folder(path)\n",
"learn = cnn_learner(dls, resnet18, pretrained=False,\n",
" loss_func=F.cross_entropy, metrics=accuracy)\n",
"learn.fit_one_cycle(1, 0.1)"
]
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "fastai",
"language": "python",
"name": "fastai"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment