Last active
January 7, 2022 23:55
-
-
Save nbertagnolli/35eb960d08c566523b4da599f6099b41 to your computer and use it in GitHub Desktop.
A notebook describing how to implement dropout in PyTorch from scratch. The associated article is at https://medium.com/p/67f08a87ccff/edit.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "5e5497d1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from torch import nn\n", | |
"from torchvision import datasets\n", | |
"from torchvision.transforms import ToTensor\n", | |
"import torchvision.transforms as T\n", | |
"import numpy as np\n", | |
"from sklearn.metrics import classification_report\n", | |
"import torch.optim as optim" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "d694e909", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[1, 2, 0],\n", | |
" [4, 0, 0]])" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Drop out\n", | |
"p = .5\n", | |
"A = torch.tensor([[1, 2, 3], [4, 5, 6]])\n", | |
"A.mul(torch.empty(A.size()).uniform_(0, 1) >= p)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "8a64df23", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[1, 0, 0],\n", | |
" [4, 0, 0]])" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Drop Connect\n", | |
"A = torch.tensor([[1, 2, 3], [4, 5, 6]])\n", | |
"A.mul(torch.empty(A.size()[1]).uniform_(0, 1) >= p)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "e0ed0474", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Dropout(torch.nn.Module):\n", | |
" \n", | |
" def __init__(self, p: float=0.5):\n", | |
" super(Dropout, self).__init__()\n", | |
" self.p = p\n", | |
" if self.p < 0 or self.p > 1:\n", | |
" raise ValueError(\"p must be a probability\")\n", | |
" \n", | |
" def forward(self, x):\n", | |
" if self.training:\n", | |
" x = x.mul(torch.empty(x.size()[1]).uniform_(0, 1) >= self.p)\n", | |
" return x\n", | |
" \n", | |
"class TrueDropout(torch.nn.Module):\n", | |
" \n", | |
" def __init__(self, p: float=0.5):\n", | |
" super(TrueDropout, self).__init__()\n", | |
" self.p = p\n", | |
" if self.p < 0 or self.p > 1:\n", | |
" raise ValueError(\"p must be a probability\")\n", | |
" \n", | |
" def forward(self, x):\n", | |
" if self.training:\n", | |
" x = x.mul(torch.empty(x.size()[1]).uniform_(0, 1) >= self.p) * (1 / (1 - self.p))\n", | |
" return x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "2579c923", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class MNISTModel(torch.nn.Module):\n", | |
" \n", | |
" def __init__(self, dropout_layer):\n", | |
" super(MNISTModel, self).__init__()\n", | |
" self.layer_1 = nn.Linear(28 * 28, 512)\n", | |
" self.layer_2 = nn.Linear(512, 512)\n", | |
" self.layer_3 = nn.Linear(512, 10)\n", | |
" self.dropout = dropout_layer\n", | |
" \n", | |
" def forward(self, x):\n", | |
" x = x.view(-1, 28 * 28)\n", | |
" x = self.layer_1(x)\n", | |
" x = self.layer_2(x)\n", | |
" x = self.dropout(x)\n", | |
" output = self.layer_3(x)\n", | |
" return x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "f1292ceb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Create basic transforms to flatten the MNIST input\n", | |
"transforms = T.Compose([T.ToTensor(), T.Lambda(torch.flatten)])\n", | |
"\n", | |
"# Create the datasets and data loaders for training\n", | |
"mnist_data = datasets.MNIST(root=\"data\", train=True, transform = transforms, download = True)\n", | |
"mnist_data_test = datasets.MNIST(root=\"data\", train=False, transform=transforms, download = True)\n", | |
"\n", | |
"data_loader = torch.utils.data.DataLoader(mnist_data,\n", | |
" batch_size=16,\n", | |
" shuffle=True,\n", | |
" num_workers=4)\n", | |
"\n", | |
"test_loader = torch.utils.data.DataLoader(mnist_data,\n", | |
" batch_size=16,\n", | |
" shuffle=False,\n", | |
" num_workers=4)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "8e3d7720", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# You can play with different dropout implementations by passing them to our \n", | |
"# MNISTModel. For example, to use PyTorch's dropout do \n", | |
"# model = MNISTModel(nn.Dropout(.5))\n", | |
"model = MNISTModel(Dropout(.5))\n", | |
"criterion = nn.CrossEntropyLoss()\n", | |
"optimizer = optim.Adam(model.parameters(), lr=0.001)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "5a86dfb6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[1, 2000] loss: 3.246\n", | |
"[2, 2000] loss: 3.117\n", | |
"Finished Training\n" | |
] | |
} | |
], | |
"source": [ | |
"for epoch in range(2): # loop over the dataset multiple times\n", | |
"\n", | |
" running_loss = 0.0\n", | |
" for i, data in enumerate(data_loader, 0):\n", | |
" # get the inputs; data is a list of [inputs, labels]\n", | |
" inputs, labels = data\n", | |
"\n", | |
" # zero the parameter gradients\n", | |
" optimizer.zero_grad()\n", | |
"\n", | |
" # forward + backward + optimize\n", | |
" outputs = model(inputs)\n", | |
" loss = criterion(outputs, labels)\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" # print statistics\n", | |
" running_loss += loss.item()\n", | |
" if i % 2000 == 1999: # print every 2000 mini-batches\n", | |
" print('[%d, %5d] loss: %.3f' %\n", | |
" (epoch + 1, i + 1, running_loss / 2000))\n", | |
" running_loss = 0.0\n", | |
"\n", | |
"print('Finished Training')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "935d83f3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model.eval()\n", | |
"labels = []\n", | |
"preds = []\n", | |
"with torch.no_grad():\n", | |
" for i, data in enumerate(test_loader, 0):\n", | |
" preds.append(torch.argmax(model(data[0]), axis=1))\n", | |
" labels.append(data[1])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "591f4628", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"labels = torch.concat(labels)\n", | |
"preds = torch.concat(preds)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "5b110369", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" precision recall f1-score support\n", | |
"\n", | |
" 0 0.96 0.97 0.96 5923\n", | |
" 1 0.90 0.98 0.94 6742\n", | |
" 2 0.95 0.83 0.89 5958\n", | |
" 3 0.91 0.87 0.89 6131\n", | |
" 4 0.95 0.88 0.92 5842\n", | |
" 5 0.89 0.85 0.87 5421\n", | |
" 6 0.93 0.96 0.95 5918\n", | |
" 7 0.88 0.95 0.91 6265\n", | |
" 8 0.84 0.87 0.86 5851\n", | |
" 9 0.87 0.88 0.87 5949\n", | |
"\n", | |
" accuracy 0.91 60000\n", | |
" macro avg 0.91 0.90 0.90 60000\n", | |
"weighted avg 0.91 0.91 0.91 60000\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# Performance on the test set.\n", | |
"print(classification_report(labels, preds, output_dict=False))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9df67a91", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment