Created
September 24, 2022 02:42
-
-
Save Astroneko404/993a1e8c71dd1073644a4a3a078ba936 to your computer and use it in GitHub Desktop.
CNN from scratch in PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torchvision\n", | |
"import torchvision.transforms as transforms" | |
] | |
}, | |
{
 "cell_type": "code",
 "execution_count": 2,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "mps\n"
   ]
  }
 ],
 "source": [
  "# Pick the compute device: Apple's Metal backend (mps) when available,\n",
  "# otherwise fall back to the CPU.\n",
  "# NOTE: mps has been observed to be flaky with the Jupyter kernel;\n",
  "# uncomment the override below to force CPU.\n",
  "if torch.backends.mps.is_available():\n",
  "    device = torch.device(\"mps\")\n",
  "else:\n",
  "    device = torch.device(\"cpu\")\n",
  "# device = torch.device(\"cpu\")\n",
  "print(device)"
 ]
},
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Data Preparation" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Load Data" | |
] | |
}, | |
{
 "cell_type": "code",
 "execution_count": 3,
 "metadata": {},
 "outputs": [],
 "source": [
  "batch_size = 64\n",
  "n_classes = 10\n",
  "learning_rate = 0.001\n",
  "n_epochs = 6\n",
  "\n",
  "# ToTensor maps pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps to [-1, 1].\n",
  "image_transform = transforms.Compose(\n",
  "    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]\n",
  ")\n",
  "\n",
  "\n",
  "def make_mnist(train):\n",
  "    \"\"\"Download (if needed) and return the MNIST train or test split.\"\"\"\n",
  "    return torchvision.datasets.MNIST(\n",
  "        root=\"./data\",\n",
  "        train=train,\n",
  "        transform=image_transform,\n",
  "        download=True,\n",
  "    )\n",
  "\n",
  "\n",
  "def make_loader(dataset):\n",
  "    \"\"\"Wrap a dataset in a shuffling DataLoader with the shared batch size.\"\"\"\n",
  "    return torch.utils.data.DataLoader(\n",
  "        dataset=dataset,\n",
  "        batch_size=batch_size,\n",
  "        shuffle=True,\n",
  "    )\n",
  "\n",
  "\n",
  "train_data = make_mnist(train=True)\n",
  "test_data = make_mnist(train=False)\n",
  "\n",
  "train_loader = make_loader(train_data)\n",
  "test_loader = make_loader(test_data)"
 ]
},
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Check the Shape of Data" | |
] | |
}, | |
{
 "cell_type": "code",
 "execution_count": 4,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "torch.Size([64, 1, 28, 28])\n",
    "torch.Size([64])\n"
   ]
  }
 ],
 "source": [
  "# Peek at a single batch to confirm tensor shapes.\n",
  "# Fix: DataLoader iterators have no `.next()` method (it is not the\n",
  "# Python 3 iterator protocol and was removed from PyTorch); use the\n",
  "# built-in next() instead.\n",
  "dataiter = iter(train_loader)\n",
  "image, label = next(dataiter)\n",
  "\n",
  "print(image.shape)\n",
  "print(label.shape)"
 ]
},
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"```\n", | |
"image:\n", | |
"-> 64 images in each batch\n", | |
"-> Each image is 28 x 28 pixels\n", | |
"\n", | |
"label:\n", | |
"-> 64 labels\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"![](https://user-images.githubusercontent.com/33112694/191980342-041fdc37-95a4-448a-ac1a-15c88d31dc50.jpeg)" | |
] | |
}, | |
{
 "cell_type": "code",
 "execution_count": 5,
 "metadata": {},
 "outputs": [],
 "source": [
  "class CNN(nn.Module):\n",
  "    \"\"\"Small CNN for 28x28 grayscale digits.\n",
  "\n",
  "    Two conv blocks (conv -> batchnorm -> ReLU -> maxpool) followed by a\n",
  "    two-layer fully connected head; forward() returns per-class\n",
  "    log-probabilities (LogSoftmax output), so pair it with nn.NLLLoss.\n",
  "    \"\"\"\n",
  "\n",
  "    def __init__(self, n_classes) -> None:\n",
  "        super(CNN, self).__init__()\n",
  "        # Block 1: 1x28x28 -> 18x28x28 (padding=2 keeps size) -> pooled to 18x14x14\n",
  "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=18, kernel_size=5, stride=1, padding=2)\n",
  "        self.batch1 = nn.BatchNorm2d(18)\n",
  "        self.relu1 = nn.ReLU()\n",
  "        self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
  "\n",
  "        # Block 2: 18x14x14 -> 6x14x14 -> pooled to 6x7x7\n",
  "        self.conv2 = nn.Conv2d(in_channels=18, out_channels=6, kernel_size=5, stride=1, padding=2)\n",
  "        self.batch2 = nn.BatchNorm2d(6)\n",
  "        self.relu2 = nn.ReLU()\n",
  "        self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
  "\n",
  "        # Classifier head: 6 * 7 * 7 = 294 flattened features\n",
  "        self.relu3 = nn.ReLU()\n",
  "        self.fc1 = nn.Linear(in_features=294, out_features=50)\n",
  "        self.fc2 = nn.Linear(in_features=50, out_features=n_classes)\n",
  "\n",
  "        self.softmax = nn.LogSoftmax(dim=1)\n",
  "\n",
  "    def forward(self, x):\n",
  "        x = self.conv1(x)\n",
  "        x = self.batch1(x)\n",
  "        x = self.relu1(x)\n",
  "        x = self.pool1(x)\n",
  "\n",
  "        x = self.conv2(x)\n",
  "        x = self.batch2(x)\n",
  "        x = self.relu2(x)\n",
  "        x = self.pool2(x)\n",
  "\n",
  "        x = x.view(x.size(0), -1)\n",
  "        x = self.fc1(x)\n",
  "        # Fix: relu3 was previously applied BEFORE fc1, where it was a no-op\n",
  "        # (pool2's output is already non-negative after relu2). Applying it\n",
  "        # between the two FC layers gives the head an actual non-linearity.\n",
  "        x = self.relu3(x)\n",
  "        x = self.fc2(x)\n",
  "        return self.softmax(x)"
 ]
},
{
 "cell_type": "code",
 "execution_count": 6,
 "metadata": {},
 "outputs": [],
 "source": [
  "model = CNN(n_classes).to(device)\n",
  "# Fix: the model already ends in LogSoftmax, so the matching loss is\n",
  "# NLLLoss. CrossEntropyLoss applies log-softmax internally, so pairing it\n",
  "# with a LogSoftmax output layer applies log-softmax twice.\n",
  "criterion = nn.NLLLoss()\n",
  "optimizer = torch.optim.Adam(\n",
  "    model.parameters(),\n",
  "    lr=learning_rate,\n",
  ")\n",
  "steps = len(train_loader)"
 ]
},
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Training" | |
] | |
}, | |
{
 "cell_type": "code",
 "execution_count": 7,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Epoch [1/6], Loss: 0.2178\n",
    "Epoch [2/6], Loss: 0.0233\n",
    "Epoch [3/6], Loss: 0.0070\n",
    "Epoch [4/6], Loss: 0.0097\n",
    "Epoch [5/6], Loss: 0.0256\n",
    "Epoch [6/6], Loss: 0.0341\n"
   ]
  }
 ],
 "source": [
  "# Standard training loop: forward pass, loss, backprop, parameter update.\n",
  "for epoch in range(n_epochs):\n",
  "    for images, labels in train_loader:\n",
  "        images = images.to(device)\n",
  "        labels = labels.to(device)\n",
  "\n",
  "        loss = criterion(model(images), labels)\n",
  "\n",
  "        # Clear stale gradients before accumulating new ones.\n",
  "        optimizer.zero_grad()\n",
  "        loss.backward()\n",
  "        optimizer.step()\n",
  "\n",
  "    # Loss of the last batch of the epoch.\n",
  "    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, n_epochs, loss.item()))\n"
 ]
},
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Testing" | |
] | |
}, | |
{
 "cell_type": "code",
 "execution_count": 15,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Accuracy for number 0: 99.27401654566943%\n",
    "Accuracy for number 1: 99.55502818154851%\n",
    "Accuracy for number 2: 98.80832494125545%\n",
    "Accuracy for number 3: 98.8908824009134%\n",
    "Accuracy for number 4: 99.58918178705923%\n",
    "Accuracy for number 5: 99.76019184652279%\n",
    "Accuracy for number 6: 99.76343359242988%\n",
    "Accuracy for number 7: 99.28172386272945%\n",
    "Accuracy for number 8: 98.83780550333276%\n",
    "Accuracy for number 9: 98.0164733568667%\n",
    "Total accuracy: 99.17666666666666%\n"
   ]
  }
 ],
 "source": [
  "# Per-class and overall accuracy on the held-out TEST set.\n",
  "# Fix: this cell previously iterated over train_loader, so the printed\n",
  "# figures (still shown above) are TRAINING accuracy, not generalization —\n",
  "# re-run to get true test accuracy.\n",
  "model.eval()  # BatchNorm must use running stats at inference time\n",
  "with torch.no_grad():\n",
  "    correct = [0 for _ in range(10)]\n",
  "    total = [0 for _ in range(10)]\n",
  "    for images, labels in test_loader:\n",
  "        images = images.to(device)\n",
  "        labels = labels.to(device)\n",
  "        output = model(images)\n",
  "        _, predicted = torch.max(output.data, 1)\n",
  "        for i in range(labels.size(0)):\n",
  "            total[labels[i]] += 1\n",
  "            if predicted[i] == labels[i]:\n",
  "                correct[labels[i]] += 1\n",
  "\n",
  "for i in range(10):\n",
  "    print(f\"Accuracy for number {i}: {correct[i] * 100 / total[i]}%\")\n",
  "print(f\"Total accuracy: {sum(correct) * 100 / sum(total)}%\")"
 ]
},
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.6" | |
}, | |
"vscode": { | |
"interpreter": { | |
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment