Last active
May 3, 2023 21:40
-
-
Save rohitdavas/98aafeba03fc4402d51cecbd904744cb to your computer and use it in GitHub Desktop.
catastrohpic-forgetting.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU", | |
"gpuClass": "standard" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/rohitdavas/98aafeba03fc4402d51cecbd904744cb/catastrohpic-forgetting.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.optim as optim\n", | |
"from torchvision import datasets, transforms\n", | |
"from torch.utils.data import DataLoader\n", | |
"import numpy as np " | |
], | |
"metadata": { | |
"id": "5aS3W8gY562l" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Convert integer class labels to one-hot rows (10 classes unless told otherwise).\n", | |
"def one_hot_encoding(labels, num_classes=10):\n", | |
"    \"\"\"Return a float array of shape (len(labels), num_classes) with a single 1 per row.\n", | |
"\n", | |
"    Args:\n", | |
"        labels: sequence of integer class indices, each smaller than num_classes.\n", | |
"        num_classes: number of columns in the encoding; defaults to 10.\n", | |
"\n", | |
"    Returns:\n", | |
"        np.ndarray (numpy's default float dtype) that is zero everywhere except\n", | |
"        for a 1 at [i, labels[i]] in every row i.\n", | |
"\n", | |
"    Raises:\n", | |
"        IndexError: if a label does not fit into num_classes columns.\n", | |
"    \"\"\"\n", | |
"    # Force an integer dtype so that an empty input is still a valid index array.\n", | |
"    labels = np.asarray(labels, dtype=np.intp).reshape(-1)\n", | |
"    one_hot = np.zeros((labels.shape[0], num_classes))\n", | |
"    # One vectorized assignment replaces the per-row Python loop.\n", | |
"    one_hot[np.arange(labels.shape[0]), labels] = 1\n", | |
"    return one_hot\n", | |
"\n", | |
"# Divide a labelled dataset into two parts based on the label value.\n", | |
"def split_half_labels(mnist_dataset, split_label=5):\n", | |
"    \"\"\"Split (data, label) pairs into a low-label part and a high-label part.\n", | |
"\n", | |
"    Args:\n", | |
"        mnist_dataset: iterable of (data, label) pairs; data must expose .numpy()\n", | |
"            and label must be comparable with split_label.\n", | |
"        split_label: first label that goes into the second part. The default of 5\n", | |
"            gives the 0to4 / 5to9 split that the returned names refer to.\n", | |
"\n", | |
"    Returns:\n", | |
"        Tuple (x_0to4, y_0to4, x_5to9, y_5to9) of numpy arrays. The y arrays are\n", | |
"        one-hot encoded with one_hot_encoding.\n", | |
"    \"\"\"\n", | |
"    x_0to4, y_0to4 = [], []\n", | |
"    x_5to9, y_5to9 = [], []\n", | |
"\n", | |
"    for data, label in mnist_dataset:\n", | |
"        # A label is either below the boundary or not, so a plain else covers the rest.\n", | |
"        if label < split_label:\n", | |
"            x_0to4.append(data.numpy())\n", | |
"            y_0to4.append(label)\n", | |
"        else:\n", | |
"            x_5to9.append(data.numpy())\n", | |
"            y_5to9.append(label)\n", | |
"\n", | |
"    x_0to4, y_0to4, x_5to9, y_5to9 = [np.array(x) for x in (x_0to4, y_0to4, x_5to9, y_5to9)]\n", | |
"\n", | |
"    # convert the integer labels to one hot encoding.\n", | |
"    y_0to4 = one_hot_encoding(y_0to4)\n", | |
"    y_5to9 = one_hot_encoding(y_5to9)\n", | |
"\n", | |
"    print(\"created a split dataset of shapes\")\n", | |
"    print(f\"\"\"\n", | |
"    x_0to4 : {x_0to4.shape}\n", | |
"    y_0to4 : {y_0to4.shape}\n", | |
"    x_5to9 : {x_5to9.shape}\n", | |
"    y_5to9 : {y_5to9.shape}\n", | |
"    \"\"\")\n", | |
"\n", | |
"    return x_0to4, y_0to4, x_5to9, y_5to9\n", | |
"\n", | |
"# Wrap a pair of numpy arrays in a torch DataLoader.\n", | |
"def dataloaderFromNumpy(x, y, batch_size, shuffle):\n", | |
"    \"\"\"Build a DataLoader over (x, y) and print a short summary of it.\n", | |
"\n", | |
"    Args:\n", | |
"        x: numpy array of inputs; the first axis indexes the samples.\n", | |
"        y: numpy array of targets with the same first-axis length as x.\n", | |
"        batch_size: number of samples per batch.\n", | |
"        shuffle: forwarded to DataLoader; whether to reshuffle every epoch.\n", | |
"\n", | |
"    Returns:\n", | |
"        The DataLoader wrapping a TensorDataset of the two arrays.\n", | |
"    \"\"\"\n", | |
"    features = torch.from_numpy(x)\n", | |
"    targets = torch.from_numpy(y)\n", | |
"    loader = DataLoader(\n", | |
"        torch.utils.data.TensorDataset(features, targets),\n", | |
"        batch_size=batch_size,\n", | |
"        shuffle=shuffle,\n", | |
"    )\n", | |
"\n", | |
"    print(f\"created a dataloader from x, y and batch size = {batch_size}\")\n", | |
"    print(f\"\\tlen of dataloader = {len(loader)}\")\n", | |
"    # Pull one batch purely to report its shapes.\n", | |
"    first_x, first_y = next(iter(loader))\n", | |
"    print(f\"\\t sample batch shapes : {(first_x.shape, first_y.shape)}\")\n", | |
"\n", | |
"    return loader\n", | |
" " | |
], | |
"metadata": { | |
"id": "1Df3HatjFhP5" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Data transforms: convert to a tensor, then normalise with mean 0.1307 and std 0.3081\n", | |
"# (presumably the MNIST training-set statistics; confirm before reusing elsewhere).\n", | |
"to_tensor = transforms.ToTensor()\n", | |
"normalize = transforms.Normalize((0.1307,), (0.3081,))\n", | |
"transform = transforms.Compose([to_tensor, normalize])\n", | |
"\n", | |
"# Load (and download if needed) the MNIST train and test splits into ./data\n", | |
"mnist_root = './data'\n", | |
"mnist_train = datasets.MNIST(mnist_root, train=True, download=True, transform=transform)\n", | |
"mnist_test = datasets.MNIST(mnist_root, train=False, download=True, transform=transform)\n", | |
"\n", | |
"\n", | |
"# ------------\n", | |
"# datasets\n", | |
"# ------------\n", | |
"\n", | |
"# The idea is to carve three training sets out of the same MNIST data:\n", | |
"#   dataset 1 : only the labels 0 to 4\n", | |
"#   dataset 2 : only the labels 5 to 9\n", | |
"#   dataset 3 : dataset 2 plus a 10% sample of dataset 1\n", | |
"#\n", | |
"# Why? To show that catastrophic forgetting can be reduced drastically when a fraction\n", | |
"# of the old dataset is mixed in while fine-tuning the model on the new dataset.\n", | |
"#   model_1 : trained on 0to4 only\n", | |
"#   model_2 : model_1 fine-tuned on dataset 2, showing a drastic drop in performance on dataset 1\n", | |
"#   model_3 : model_1 fine-tuned on dataset 3, still showing comparable performance on dataset 1\n", | |
"# which demonstrates the reduction in catastrophic forgetting.\n", | |
"\n", | |
"# dataset 1 AND 2\n", | |
"\n", | |
"# train\n", | |
"train_x_0to4, train_y_0to4, train_x_5to9, train_y_5to9 = split_half_labels(mnist_train)\n", | |
"# test\n", | |
"test_x_0to4, test_y_0to4, test_x_5to9, test_y_5to9 = split_half_labels(mnist_test)\n", | |
"\n", | |
"# dataset 3\n", | |
"n = len(train_x_0to4)\n", | |
"fraction_to_mix = 0.1\n", | |
"# Fixed seed so that the same 10% subset is drawn on every Restart & Run All.\n", | |
"mix_seed = 0\n", | |
"mix_rng = np.random.default_rng(mix_seed)\n", | |
"# Sample row indices of dataset 1 uniformly and without replacement.\n", | |
"indices_to_work_with = mix_rng.choice(n, size=int(fraction_to_mix*n), replace=False)\n", | |
"print(f\"no of images to be added from the dataset 1 = {len(indices_to_work_with)}\")\n", | |
"\n", | |
"# Stack the sampled 0to4 rows on top of the full 5to9 set (first axis = samples).\n", | |
"train_x_mix = np.vstack([train_x_0to4[indices_to_work_with], train_x_5to9])\n", | |
"train_y_mix = np.vstack([train_y_0to4[indices_to_work_with], train_y_5to9])\n", | |
"\n", | |
"print(f\"shape of the train_x_mix : {train_x_mix.shape}, train_y_mix = {train_y_mix.shape}\")\n" | |
], | |
"metadata": { | |
"id": "Nuvopy-b9rfb", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "d7f924b4-9f63-48fa-e810-45e5206306f6" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"created a split dataset of shapes\n", | |
"\n", | |
" x_0to4 : (30596, 1, 28, 28)\n", | |
" y_0to4 : (30596, 10)\n", | |
" x_5to9 : (29404, 1, 28, 28)\n", | |
" y_5to9 : (29404, 10)\n", | |
" \n", | |
"created a split dataset of shapes\n", | |
"\n", | |
" x_0to4 : (5139, 1, 28, 28)\n", | |
" y_0to4 : (5139, 10)\n", | |
" x_5to9 : (4861, 1, 28, 28)\n", | |
" y_5to9 : (4861, 10)\n", | |
" \n", | |
"no of images to be added from the dataset 1 = 3059\n", | |
"shape of the train_x_mix : (32463, 1, 28, 28), train_y_mix = (32463, 10)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Build the torch dataloaders: shuffled for training, in fixed order for testing.\n", | |
"batch_size = 256\n", | |
"\n", | |
"l_train_0to4 = dataloaderFromNumpy(train_x_0to4, train_y_0to4, batch_size=batch_size, shuffle=True)\n", | |
"l_test_0to4 = dataloaderFromNumpy(test_x_0to4, test_y_0to4, batch_size=batch_size, shuffle=False)\n", | |
"\n", | |
"l_train_5to9 = dataloaderFromNumpy(train_x_5to9, train_y_5to9, batch_size=batch_size, shuffle=True)\n", | |
"l_test_5to9 = dataloaderFromNumpy(test_x_5to9, test_y_5to9, batch_size=batch_size, shuffle=False)\n", | |
"\n", | |
"l_train_mix = dataloaderFromNumpy(train_x_mix, train_y_mix, batch_size=batch_size, shuffle=True)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "cmvwmCgs0jO-", | |
"outputId": "9cb07d35-db26-4b4d-e3aa-a488f4361dfa" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"created a dataloader from x, y and batch size = 256\n", | |
"\tlen of dataloader = 120\n", | |
"\t sample batch shapes : (torch.Size([256, 1, 28, 28]), torch.Size([256, 10]))\n", | |
"created a dataloader from x, y and batch size = 256\n", | |
"\tlen of dataloader = 21\n", | |
"\t sample batch shapes : (torch.Size([256, 1, 28, 28]), torch.Size([256, 10]))\n", | |
"created a dataloader from x, y and batch size = 256\n", | |
"\tlen of dataloader = 115\n", | |
"\t sample batch shapes : (torch.Size([256, 1, 28, 28]), torch.Size([256, 10]))\n", | |
"created a dataloader from x, y and batch size = 256\n", | |
"\tlen of dataloader = 19\n", | |
"\t sample batch shapes : (torch.Size([256, 1, 28, 28]), torch.Size([256, 10]))\n", | |
"created a dataloader from x, y and batch size = 256\n", | |
"\tlen of dataloader = 127\n", | |
"\t sample batch shapes : (torch.Size([256, 1, 28, 28]), torch.Size([256, 10]))\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n", | |
"use_cuda = torch.cuda.is_available()\n", | |
"if use_cuda:\n", | |
"    device = torch.device(\"cuda:0\")\n", | |
"else:\n", | |
"    device = torch.device(\"cpu\")\n", | |
"\n", | |
"# Define neural network architecture\n", | |
"class Net(nn.Module):\n", | |
"    \"\"\"Small CNN: three conv -> batch-norm -> ReLU -> avg-pool stages, then two linear layers.\n", | |
"\n", | |
"    forward() returns raw, unnormalised class scores (logits) for 10 classes.\n", | |
"    Apply torch.softmax(logits, dim=1) at the call site when probabilities are needed.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"\n", | |
"        # The ReLU module is stateless, so one instance is shared by all stages.\n", | |
"        self.relu = nn.ReLU()\n", | |
"\n", | |
"        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=\"same\")\n", | |
"        self.batchnorm1 = nn.BatchNorm2d(num_features=32)\n", | |
"        self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)\n", | |
"\n", | |
"        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=\"same\")\n", | |
"        self.batchnorm2 = nn.BatchNorm2d(num_features=32)\n", | |
"        self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)\n", | |
"\n", | |
"        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=\"same\")\n", | |
"        self.batchnorm3 = nn.BatchNorm2d(num_features=64)\n", | |
"        self.avgpool3 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)\n", | |
"\n", | |
"        # Three stride-2 pools shrink a 28x28 input to 3x3 with 64 channels.\n", | |
"        self.fc1 = nn.Linear(in_features=64*3*3, out_features=64)\n", | |
"        self.batchnorm4 = nn.BatchNorm1d(num_features=64)\n", | |
"        self.fc2 = nn.Linear(in_features=64, out_features=10)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Map a (batch, 1, 28, 28) tensor to a (batch, 10) tensor of logits.\"\"\"\n", | |
"        x = self.avgpool1(self.relu(self.batchnorm1(self.conv1(x))))\n", | |
"        x = self.avgpool2(self.relu(self.batchnorm2(self.conv2(x))))\n", | |
"        x = self.avgpool3(self.relu(self.batchnorm3(self.conv3(x))))\n", | |
"        x = torch.flatten(x, 1)\n", | |
"        x = self.relu(self.batchnorm4(self.fc1(x)))\n", | |
"        # No softmax here: nn.CrossEntropyLoss expects unnormalised logits and applies\n", | |
"        # log-softmax itself, so returning probabilities would normalise twice.\n", | |
"        # argmax over the logits picks the same class as argmax over the softmax.\n", | |
"        return self.fc2(x)\n", | |
"\n", | |
"# Loss and training function for the neural network\n", | |
"criterion = nn.CrossEntropyLoss()\n", | |
"def train(model, dataloader, optimizer):\n", | |
"    \"\"\"Run one pass over dataloader, taking one optimizer step per batch.\n", | |
"\n", | |
"    Args:\n", | |
"        model: the nn.Module to optimise; it is switched to training mode.\n", | |
"        dataloader: iterable yielding (data, target) tensor pairs.\n", | |
"        optimizer: optimizer that updates the model's parameters.\n", | |
"    \"\"\"\n", | |
"    model.train()\n", | |
"    for inputs, labels in dataloader:\n", | |
"        # Move the batch to the globally selected device before the forward pass.\n", | |
"        inputs = inputs.to(device)\n", | |
"        labels = labels.to(device)\n", | |
"        optimizer.zero_grad()\n", | |
"        batch_loss = criterion(model(inputs), labels)\n", | |
"        batch_loss.backward()\n", | |
"        optimizer.step()\n", | |
"\n", | |
"# Define evaluation function for the neural network\n", | |
"def evaluate(model, dataloader):\n", | |
"    \"\"\"Return the accuracy of model on dataloader as a percentage.\n", | |
"\n", | |
"    Both the model output and the target are reduced with argmax over dim 1,\n", | |
"    so each target row is expected to hold per-class scores (e.g. one-hot).\n", | |
"\n", | |
"    Args:\n", | |
"        model: the nn.Module to score; it is switched to evaluation mode.\n", | |
"        dataloader: DataLoader yielding (data, target) tensor pairs.\n", | |
"\n", | |
"    Returns:\n", | |
"        100 * (number of matching argmax rows) / len(dataloader.dataset).\n", | |
"    \"\"\"\n", | |
"    model.eval()\n", | |
"    num_correct = 0\n", | |
"    with torch.no_grad():\n", | |
"        for inputs, labels in dataloader:\n", | |
"            inputs, labels = inputs.to(device), labels.to(device)\n", | |
"            predicted = model(inputs).argmax(dim=1)\n", | |
"            expected = labels.argmax(dim=1)\n", | |
"            num_correct += (predicted == expected).sum().item()\n", | |
"    return 100. * num_correct / len(dataloader.dataset)" | |
], | |
"metadata": { | |
"id": "-EUcIQNe98pK" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Stage 1: train a fresh network on the first dataset only (labels 0to4).\n", | |
"epochs = 5\n", | |
"model_0to4 = Net().to(device)\n", | |
"optimizer = optim.Adam(model_0to4.parameters())\n", | |
"epoch_report = 'Epoch {} | Testing Accuracy: {:.2f}%'\n", | |
"\n", | |
"for epoch in range(epochs):\n", | |
"    train(model_0to4, l_train_0to4, optimizer)\n", | |
"    # `accuracy` is read again by the fine-tuning cells as the 0to4 baseline.\n", | |
"    accuracy = evaluate(model_0to4, l_test_0to4)\n", | |
"    print(epoch_report.format(epoch+1, accuracy))\n", | |
"\n", | |
"# Save the trained weights to disk so a later cell can restore them.\n", | |
"trained_weights = model_0to4.state_dict()\n", | |
"torch.save(trained_weights, \"model_0to4.pt\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "HncOXfFp8Bvr", | |
"outputId": "e08fb160-1ab6-468d-a6c6-7ed2d5c24fa1" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Epoch 1 | Testing Accuracy: 99.73%\n", | |
"Epoch 2 | Testing Accuracy: 99.86%\n", | |
"Epoch 3 | Testing Accuracy: 99.79%\n", | |
"Epoch 4 | Testing Accuracy: 99.82%\n", | |
"Epoch 5 | Testing Accuracy: 99.79%\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Stage 2: fine-tune the 0to4 model on 5to9 only, scoring both test splits after each epoch.\n", | |
"finetune_lr = 0.0001\n", | |
"optimizer = optim.Adam(model_0to4.parameters(), lr=finetune_lr)\n", | |
"epoch_report = 'Epoch {} | 0-4 Accuracy: {:.2f}% | 5-9 Accuracy: {:.2f}%'\n", | |
"\n", | |
"for epoch in range(epochs):\n", | |
"    train(model_0to4, l_train_5to9, optimizer)\n", | |
"    accuracy_0to4 = evaluate(model_0to4, l_test_0to4)\n", | |
"    accuracy_5to9 = evaluate(model_0to4, l_test_5to9)\n", | |
"    print(epoch_report.format(epoch+1, accuracy_0to4, accuracy_5to9))\n", | |
"\n", | |
"# `accuracy` was set by the 0to4 training cell; the gap to the current 0to4 score is the forgetting.\n", | |
"print(f\"castastrophic forgetting = {accuracy - accuracy_0to4}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_UllSuva-BZ0", | |
"outputId": "af8cfc23-e67d-4b42-f0fa-905cab28a639" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Epoch 1 | 0-4 Accuracy: 75.23% | 5-9 Accuracy: 52.15%\n", | |
"Epoch 2 | 0-4 Accuracy: 26.25% | 5-9 Accuracy: 79.16%\n", | |
"Epoch 3 | 0-4 Accuracy: 7.86% | 5-9 Accuracy: 98.97%\n", | |
"Epoch 4 | 0-4 Accuracy: 2.04% | 5-9 Accuracy: 99.26%\n", | |
"Epoch 5 | 0-4 Accuracy: 0.76% | 5-9 Accuracy: 99.40%\n", | |
"castastrophic forgetting = 99.02704806382565\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Restore the parameters of the model trained only on the 0to4 labels.\n", | |
"# map_location keeps the load working when the checkpoint was written on a CUDA\n", | |
"# device and the current runtime only has a CPU (or the other way round).\n", | |
"model_0to4.load_state_dict(torch.load(\"model_0to4.pt\", map_location=device))\n", | |
"\n", | |
"# The idea is to fine-tune on the mix dataset and show that the catastrophic forgetting is reduced.\n", | |
"optimizer = optim.Adam(model_0to4.parameters(), lr=0.0001)\n", | |
"for epoch in range(epochs):\n", | |
"    train(model_0to4, l_train_mix, optimizer)\n", | |
"    accuracy_0to4 = evaluate(model_0to4, l_test_0to4)\n", | |
"    accuracy_5to9 = evaluate(model_0to4, l_test_5to9)\n", | |
"    print('Epoch {} | 0-4 Accuracy: {:.2f}% | 5-9 Accuracy: {:.2f}%'.format(epoch+1, accuracy_0to4, accuracy_5to9))\n", | |
"\n", | |
"# `accuracy` was set by the 0to4 training cell; the gap to the current 0to4 score is the forgetting.\n", | |
"print(f\"castastrophic forgetting = {accuracy - accuracy_0to4}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "8ZnFqPewKuiV", | |
"outputId": "18b3e794-c0f4-4aaf-fa02-c1ff3e3b34d3" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Epoch 1 | 0-4 Accuracy: 99.14% | 5-9 Accuracy: 67.54%\n", | |
"Epoch 2 | 0-4 Accuracy: 98.68% | 5-9 Accuracy: 78.98%\n", | |
"Epoch 3 | 0-4 Accuracy: 98.37% | 5-9 Accuracy: 98.40%\n", | |
"Epoch 4 | 0-4 Accuracy: 98.17% | 5-9 Accuracy: 99.01%\n", | |
"Epoch 5 | 0-4 Accuracy: 98.23% | 5-9 Accuracy: 99.16%\n", | |
"castastrophic forgetting = 1.5567230978789581\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment