Skip to content

Instantly share code, notes, and snippets.

@Astroneko404
Created September 24, 2022 02:42
Show Gist options
  • Save Astroneko404/993a1e8c71dd1073644a4a3a078ba936 to your computer and use it in GitHub Desktop.
Save Astroneko404/993a1e8c71dd1073644a4a3a078ba936 to your computer and use it in GitHub Desktop.
CNN from scratch in PyTorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torchvision\n",
"import torchvision.transforms as transforms"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mps\n"
]
}
],
"source": [
"# Device setting for mac\n",
"# Note:mps seems buggy with Jupyter kernel\n",
"device = torch.device(\"mps\") if torch.backends.mps.is_available() else torch.device(\"cpu\")\n",
"# device = torch.device(\"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Preparation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 64\n",
"n_classes = 10\n",
"learning_rate = 0.001\n",
"n_epochs = 6\n",
"\n",
"image_transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,))\n",
"])\n",
"\n",
"train_data = torchvision.datasets.MNIST(\n",
" root=\"./data\",\n",
" train=True,\n",
" transform=image_transform,\n",
" download=True\n",
")\n",
"\n",
"test_data = torchvision.datasets.MNIST(\n",
" root=\"./data\",\n",
" train=False,\n",
" transform=image_transform,\n",
" download=True\n",
")\n",
"\n",
"train_loader = torch.utils.data.DataLoader(\n",
" dataset=train_data,\n",
" batch_size=batch_size,\n",
" shuffle=True\n",
")\n",
"\n",
"test_loader = torch.utils.data.DataLoader(\n",
" dataset=test_data,\n",
" batch_size=batch_size,\n",
" shuffle=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check the Shape of Data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([64, 1, 28, 28])\n",
"torch.Size([64])\n"
]
}
],
"source": [
"# Peek at one batch to confirm tensor shapes: images are (batch, 1, 28, 28),\n",
"# labels are (batch,).\n",
"dataiter = iter(train_loader)\n",
"# BUG FIX: the iterator `.next()` method was removed in newer PyTorch\n",
"# (and never existed for Python 3 iterators in general) -- use the\n",
"# builtin next() instead.\n",
"image, label = next(dataiter)\n",
"\n",
"print(image.shape)\n",
"print(label.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"image:\n",
"-> 64 images in each batch\n",
"-> Each image is 28 x 28 pixels\n",
"\n",
"label:\n",
"-> 64 labels\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](https://user-images.githubusercontent.com/33112694/191980342-041fdc37-95a4-448a-ac1a-15c88d31dc50.jpeg)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class CNN(nn.Module):\n",
"    \"\"\"Small two-conv-block CNN for 28x28 single-channel images (MNIST).\n",
"\n",
"    Returns raw logits of shape (batch, n_classes); pair it with\n",
"    nn.CrossEntropyLoss, which applies log-softmax internally.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, n_classes) -> None:\n",
"        super(CNN, self).__init__()\n",
"        # Block 1: 1 -> 18 channels, 'same' padding (k=5, p=2), 2x2 pool: 28 -> 14\n",
"        self.conv1 = nn.Conv2d(in_channels=1, out_channels=18, kernel_size=5, stride=1, padding=2)\n",
"        self.batch1 = nn.BatchNorm2d(18)\n",
"        self.relu1 = nn.ReLU()\n",
"        self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
"        # Block 2: 18 -> 6 channels, 2x2 pool: 14 -> 7\n",
"        self.conv2 = nn.Conv2d(in_channels=18, out_channels=6, kernel_size=5, stride=1, padding=2)\n",
"        self.batch2 = nn.BatchNorm2d(6)\n",
"        self.relu2 = nn.ReLU()\n",
"        self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
"\n",
"        # Classifier head: 6 channels * 7 * 7 = 294 flattened features.\n",
"        self.fc1 = nn.Linear(in_features=294, out_features=50)\n",
"        self.relu3 = nn.ReLU()\n",
"        self.fc2 = nn.Linear(in_features=50, out_features=n_classes)\n",
"        # BUG FIX: the final nn.LogSoftmax was removed. The training cell uses\n",
"        # nn.CrossEntropyLoss, which applies log-softmax itself, so the old\n",
"        # model applied it twice. Raw logits also leave the argmax used at\n",
"        # evaluation time unchanged.\n",
"\n",
"    def forward(self, x):\n",
"        x = self.pool1(self.relu1(self.batch1(self.conv1(x))))\n",
"        x = self.pool2(self.relu2(self.batch2(self.conv2(x))))\n",
"\n",
"        x = x.view(x.size(0), -1)\n",
"        # BUG FIX: relu3 now sits between fc1 and fc2. The original applied it\n",
"        # to the flattened max-pool output, which is already non-negative\n",
"        # (ReLU precedes the pool), making it a no-op and leaving the two\n",
"        # linear layers with no nonlinearity between them.\n",
"        x = self.relu3(self.fc1(x))\n",
"        return self.fc2(x)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model = CNN(n_classes).to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(\n",
" model.parameters(),\n",
" lr=learning_rate,\n",
" )\n",
"steps = len(train_loader)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/6], Loss: 0.2178\n",
"Epoch [2/6], Loss: 0.0233\n",
"Epoch [3/6], Loss: 0.0070\n",
"Epoch [4/6], Loss: 0.0097\n",
"Epoch [5/6], Loss: 0.0256\n",
"Epoch [6/6], Loss: 0.0341\n"
]
}
],
"source": [
"for epoch in range(n_epochs):\n",
" for i, (image, label) in enumerate(train_loader):\n",
" image = image.to(device)\n",
" label = label.to(device)\n",
"\n",
" output = model(image)\n",
" loss = criterion(output, label)\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, n_epochs, loss.item()))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Testing"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy for number 0: 99.27401654566943%\n",
"Accuracy for number 1: 99.55502818154851%\n",
"Accuracy for number 2: 98.80832494125545%\n",
"Accuracy for number 3: 98.8908824009134%\n",
"Accuracy for number 4: 99.58918178705923%\n",
"Accuracy for number 5: 99.76019184652279%\n",
"Accuracy for number 6: 99.76343359242988%\n",
"Accuracy for number 7: 99.28172386272945%\n",
"Accuracy for number 8: 98.83780550333276%\n",
"Accuracy for number 9: 98.0164733568667%\n",
"Total accuracy: 99.17666666666666%\n"
]
}
],
"source": [
"# Per-class and overall accuracy on the held-out test set.\n",
"# BUG FIX: the original looped over train_loader, so the \"Testing\" section\n",
"# actually reported TRAINING accuracy. It also never called model.eval(),\n",
"# leaving BatchNorm in training mode (batch statistics) during evaluation.\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    correct = [0 for _ in range(10)]\n",
"    total = [0 for _ in range(10)]\n",
"    for images, labels in test_loader:\n",
"        images = images.to(device)\n",
"        labels = labels.to(device)\n",
"        output = model(images)\n",
"        # Predicted class = argmax over the class dimension.\n",
"        _, predicted = torch.max(output.data, 1)\n",
"        for i in range(labels.size(0)):\n",
"            total[labels[i]] += 1\n",
"            if predicted[i] == labels[i]:\n",
"                correct[predicted[i]] += 1\n",
"\n",
"for i in range(10):\n",
"    print(f\"Accuracy for number {i}: {correct[i] * 100 / total[i]}%\")\n",
"print(f\"Total accuracy: {sum(correct) * 100 / sum(total)}%\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment