catastrohpic-forgetting.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/rohitdavas/98aafeba03fc4402d51cecbd904744cb/catastrohpic-forgetting.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torchvision import datasets, transforms\n",
"from torch.utils.data import DataLoader\n",
"import numpy as np "
],
"metadata": {
"id": "5aS3W8gY562l"
},
"execution_count": null,
"outputs": []
},
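{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional, a minimal sketch: fixing the RNG seeds makes the 10% subsample and the training runs below repeatable. The recorded outputs in this notebook were produced without a fixed seed, so exact numbers may vary from run to run."
]
},
{
"cell_type": "code",
"source": [
"# optional reproducibility: seed numpy (used below for the 10% subsample)\n",
"# and torch (weight init, shuffling)\n",
"seed = 0\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.manual_seed_all(seed)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},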
{
"cell_type": "code",
"source": [
"# conver the labels to one hot on the 10 classes.\n",
"def one_hot_encoding(labels):\n",
" one_hot = np.zeros((len(labels), 10))\n",
" for i, label in enumerate(labels):\n",
" one_hot[i][label] = 1\n",
" return one_hot \n",
"\n",
"# Divide dataset into two parts based on labels\n",
"def split_half_labels(mnist_dataset):\n",
" x_0to4 = []\n",
" y_0to4 = []\n",
" x_5to9 = []\n",
" y_5to9 = []\n",
"\n",
" for data, label in mnist_dataset:\n",
" if label < 5: \n",
" x_0to4.append(data.numpy())\n",
" y_0to4.append(label)\n",
" elif label >= 5:\n",
" x_5to9.append(data.numpy())\n",
" y_5to9.append(label)\n",
"\n",
" x_0to4, y_0to4, x_5to9, y_5to9 = [np.array(x) for x in (x_0to4, y_0to4, x_5to9, y_5to9)]\n",
"\n",
" # convert to one hot encoding. \n",
" y_0to4 = one_hot_encoding(y_0to4)\n",
" y_5to9 = one_hot_encoding(y_5to9)\n",
"\n",
" print(f\"created a split dataset of shapes\")\n",
" print(f\"\"\"\n",
" x_0to4 : {x_0to4.shape}\n",
" y_0to4 : {y_0to4.shape}\n",
" x_5to9 : {x_5to9.shape}\n",
" y_5to9 : {y_5to9.shape}\n",
" \"\"\")\n",
"\n",
" return x_0to4, y_0to4, x_5to9, y_5to9 \n",
"\n",
"# a torch dataloader function from numpy dataset. \n",
"def dataloaderFromNumpy(x, y, batch_size, shuffle):\n",
" x, y = [torch.from_numpy(a) for a in (x, y)]\n",
" tensor_dataset = torch.utils.data.TensorDataset(x, y)\n",
" dataloader = DataLoader(tensor_dataset, batch_size = batch_size, shuffle=shuffle)\n",
"\n",
" print(f\"created a dataloader from x, y and batch size = {batch_size}\")\n",
" print(f\"\\tlen of dataloader = {len(dataloader)}\")\n",
" batch_x, batch_y = next(iter(dataloader))\n",
" print(f\"\\t sample batch shapes : {(batch_x.shape, batch_y.shape)}\")\n",
"\n",
" return dataloader\n",
" "
],
"metadata": {
"id": "1Df3HatjFhP5"
},
"execution_count": null,
"outputs": []
},
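{
"cell_type": "markdown",
"metadata": {},
"source": [
"Aside, an equivalent vectorized sketch: `np.eye(10)[labels]` builds the same one-hot matrix without the Python loop (it assumes `labels` is a sequence of integers in 0-9)."
]
},
{
"cell_type": "code",
"source": [
"# vectorized one-hot: row i of the 10x10 identity picks out class labels[i]\n",
"def one_hot_encoding_vec(labels):\n",
"    return np.eye(10)[np.asarray(labels)]\n",
"\n",
"# quick check against the loop-based version above\n",
"assert np.array_equal(one_hot_encoding_vec([0, 3, 9]), one_hot_encoding([0, 3, 9]))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},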
{
"cell_type": "code",
"source": [
"# Define data transforms\n",
"transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
"])\n",
"\n",
"# Load MNIST dataset\n",
"mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transform)\n",
"mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transform)\n",
"\n",
"\n",
"# ------------\n",
"# datasets\n",
"# ------------\n",
"\n",
"# the idea is to divide the same dataset into two. \n",
"# the first one have only the labels from 0to4\n",
"# the second one have only the labels from 5to9\n",
"# the third one have the labels from 0to9 : but with a catch that only 10% of the data from first one\n",
"# is added the data for the second one.\n",
"\n",
"# why ? It is to prove that we can reduce the catastrophic forgetting drastically if we add a fraction \n",
"# of the old dataset to while finetuning the model on the new dataset. \n",
"# model_1 : trained on only the 0to4\n",
"# model_2 : finetuned version of model_1 on the dataset 2, showing drastic reduction in performance on dataset 1\n",
"# model_3 : finetuned version of model_1, on the dataset 2 + a fraction of dataset_1 , showing a comparable performance on dataset 1 still. \n",
"# thus showing the reduction in the catastrophic forgetting. \n",
"\n",
"# dataset 1 AND 2 \n",
"\n",
"# train \n",
"train_x_0to4, train_y_0to4, train_x_5to9, train_y_5to9 = split_half_labels(mnist_train)\n",
"# test \n",
"test_x_0to4, test_y_0to4, test_x_5to9, test_y_5to9 = split_half_labels(mnist_test)\n",
"\n",
"# dataset 3 \n",
"n = len(train_x_0to4)\n",
"fraction_to_mix = 0.1 \n",
"# take 10 % of the dataset from 0to4 # and stack it with the 5to9 \n",
"# better way would be to choose the indices from uniform sampling.\n",
"# np random choice, by default assumes uniform distribution.\n",
"indices_to_work_with = np.random.choice(n, size=int(fraction_to_mix*n), replace=False)\n",
"print(f\"no of images to be added from the dataset 1 = {len(indices_to_work_with)}\")\n",
"\n",
"train_x_mix = np.vstack([train_x_0to4[indices_to_work_with], train_x_5to9])\n",
"train_y_mix = np.vstack([train_y_0to4[indices_to_work_with], train_y_5to9])\n",
"\n",
"print(f\"shape of the train_x_mix : {train_x_mix.shape}, train_y_mix = {train_y_mix.shape}\")\n"
],
"metadata": {
"id": "Nuvopy-b9rfb",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d7f924b4-9f63-48fa-e810-45e5206306f6"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"created a split dataset of shapes\n",
"\n",
" x_0to4 : (30596, 1, 28, 28)\n",
" y_0to4 : (30596, 10)\n",
" x_5to9 : (29404, 1, 28, 28)\n",
" y_5to9 : (29404, 10)\n",
" \n",
"created a split dataset of shapes\n",
"\n",
" x_0to4 : (5139, 1, 28, 28)\n",
" y_0to4 : (5139, 10)\n",
" x_5to9 : (4861, 1, 28, 28)\n",
" y_5to9 : (4861, 10)\n",
" \n",
"no of images to be added from the dataset 1 = 3059\n",
"shape of the train_x_mix : (32463, 1, 28, 28), train_y_mix = (32463, 10)\n"
]
}
]
},
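{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the mixed set (a sketch): since the labels are one-hot, column sums give the number of examples per digit; digits 0-4 should each appear at roughly a tenth of the rate of digits 5-9."
]
},
{
"cell_type": "code",
"source": [
"# per-digit example counts in the mixed training set\n",
"# (one-hot labels: summing over rows gives counts per class)\n",
"class_counts = train_y_mix.sum(axis=0).astype(int)\n",
"for digit, count in enumerate(class_counts):\n",
"    print(f\"digit {digit}: {count} examples\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},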
{
"cell_type": "code",
"source": [
"# create torch dataloader for easy training. \n",
"l_train_0to4 = dataloaderFromNumpy(train_x_0to4, train_y_0to4, 256, True)\n",
"l_test_0to4 = dataloaderFromNumpy(test_x_0to4, test_y_0to4, 256, False)\n",
"\n",
"l_train_5to9 = dataloaderFromNumpy(train_x_5to9, train_y_5to9, 256, True)\n",
"l_test_5to9 = dataloaderFromNumpy(test_x_5to9, test_y_5to9, 256, False)\n",
"\n",
"l_train_mix = dataloaderFromNumpy(train_x_mix, train_y_mix, 256, True)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cmvwmCgs0jO-",
"outputId": "9cb07d35-db26-4b4d-e3aa-a488f4361dfa"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"created a dataloader from x, y and batch size = 256\n",
"\tlen of dataloader = 120\n",
"\t sample batch shapes : (torch.Size([256, 1, 28, 28]), torch.Size([256, 10]))\n",
"created a dataloader from x, y and batch size = 256\n",
"\tlen of dataloader = 21\n",
"\t sample batch shapes : (torch.Size([256, 1, 28, 28]), torch.Size([256, 10]))\n",
"created a dataloader from x, y and batch size = 256\n",
"\tlen of dataloader = 115\n",
"\t sample batch shapes : (torch.Size([256, 1, 28, 28]), torch.Size([256, 10]))\n",
"created a dataloader from x, y and batch size = 256\n",
"\tlen of dataloader = 19\n",
"\t sample batch shapes : (torch.Size([256, 1, 28, 28]), torch.Size([256, 10]))\n",
"created a dataloader from x, y and batch size = 256\n",
"\tlen of dataloader = 127\n",
"\t sample batch shapes : (torch.Size([256, 1, 28, 28]), torch.Size([256, 10]))\n"
]
}
]
},
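{
"cell_type": "markdown",
"metadata": {},
"source": [
"An alternative way to build the rehearsal loader, a sketch using `torch.utils.data.ConcatDataset` and `Subset`: it yields the same mix of batches without materialising a stacked copy of the arrays (it reuses `indices_to_work_with` from the cell above)."
]
},
{
"cell_type": "code",
"source": [
"from torch.utils.data import TensorDataset, ConcatDataset, Subset\n",
"\n",
"# wrap the two halves as tensor datasets\n",
"ds_0to4 = TensorDataset(torch.from_numpy(train_x_0to4), torch.from_numpy(train_y_0to4))\n",
"ds_5to9 = TensorDataset(torch.from_numpy(train_x_5to9), torch.from_numpy(train_y_5to9))\n",
"\n",
"# reuse the same 10% rehearsal indices, then concatenate without copying\n",
"rehearsal = Subset(ds_0to4, indices_to_work_with.tolist())\n",
"l_train_mix_alt = DataLoader(ConcatDataset([rehearsal, ds_5to9]), batch_size=256, shuffle=True)\n",
"print(f\"len of alternative mixed loader = {len(l_train_mix_alt)}\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},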
{
"cell_type": "code",
"source": [
"use_cuda = torch.cuda.is_available()\n",
"device = torch.device(\"cuda:0\" if use_cuda else \"cpu\")\n",
"\n",
"# Define neural network architecture\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
"\n",
" self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 5, padding = \"same\")\n",
" self.batchnorm1 = nn.BatchNorm2d(num_features = 32)\n",
" self.relu = nn.ReLU()\n",
" self.avgpool1 = nn.AvgPool2d(kernel_size = 2, stride= 2, padding= 0)\n",
" \n",
" self.conv2 = nn.Conv2d(in_channels= 32, out_channels= 32, kernel_size= 5, padding = \"same\")\n",
" self.batchnorm2 = nn.BatchNorm2d(num_features= 32)\n",
" self.avgpool2 = nn.AvgPool2d(kernel_size = 2, stride= 2, padding= 0)\n",
"\n",
" self.conv3 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 5, padding = \"same\")\n",
" self.batchnorm3 = nn.BatchNorm2d(num_features = 64)\n",
" self.avgpool3 = nn.AvgPool2d(kernel_size = 2, stride= 2, padding= 0)\n",
"\n",
" self.fc1 = nn.Linear(in_features= 64*3*3, out_features= 64)\n",
" self.batchnorm4 = nn.BatchNorm1d(num_features= 64)\n",
" self.fc2 = nn.Linear(in_features= 64, out_features= 10) \n",
" \n",
" \n",
" def forward(self, x):\n",
"\n",
" x = self.avgpool1(self.relu(self.batchnorm1(self.conv1(x))))\n",
" x = self.avgpool2(self.relu(self.batchnorm2(self.conv2(x))))\n",
" x = self.avgpool3(self.relu(self.batchnorm3(self.conv3(x))))\n",
" x = torch.flatten(x, 1)\n",
" x = self.relu(self.batchnorm4(self.fc1(x)))\n",
" x = self.fc2(x)\n",
" x = nn.Softmax(dim=1)(x)\n",
" return x\n",
"\n",
"# Define training function for the neural network\n",
"criterion = nn.CrossEntropyLoss()\n",
"def train(model, dataloader, optimizer):\n",
" model.train()\n",
" for batch_idx, (data, target) in enumerate(dataloader):\n",
" data, target = data.to(device), target.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = criterion(output, target)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
"# Define evaluation function for the neural network\n",
"def evaluate(model, dataloader):\n",
" model.eval()\n",
" correct = 0\n",
" with torch.no_grad():\n",
" for data, target in dataloader:\n",
" data, target = data.to(device), target.to(device)\n",
" output = model(data)\n",
" pred = output.argmax(dim=1, keepdim=True)\n",
" target = target.argmax(dim=1, keepdim=True)\n",
" # print(pred, target)\n",
" correct += pred.eq(target.view_as(pred)).sum().item()\n",
" accuracy = 100. * correct / len(dataloader.dataset)\n",
" return accuracy"
],
"metadata": {
"id": "-EUcIQNe98pK"
},
"execution_count": null,
"outputs": []
},
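{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick smoke test of the architecture (a sketch): count the trainable parameters and check that a dummy batch comes out with shape `(N, 10)`."
]
},
{
"cell_type": "code",
"source": [
"# instantiate a throwaway copy of the network and probe it\n",
"_net = Net().to(device)\n",
"n_params = sum(p.numel() for p in _net.parameters())\n",
"print(f\"trainable parameters: {n_params}\")\n",
"\n",
"# dummy batch of two 28x28 images -> expect logits of shape (2, 10)\n",
"_net.eval()  # eval mode so BatchNorm uses running stats on the tiny batch\n",
"with torch.no_grad():\n",
"    out = _net(torch.randn(2, 1, 28, 28, device=device))\n",
"print(f\"output shape for a (2, 1, 28, 28) batch: {tuple(out.shape)}\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},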
{
"cell_type": "code",
"source": [
"# Train the neural network on the first dataset : 0to4 \n",
"model_0to4 = Net().to(device)\n",
"optimizer = optim.Adam(model_0to4.parameters())\n",
"epochs = 5\n",
"for epoch in range(epochs):\n",
" train(model_0to4, l_train_0to4, optimizer)\n",
" accuracy = evaluate(model_0to4, l_test_0to4)\n",
" print('Epoch {} | Testing Accuracy: {:.2f}%'.format(epoch+1, accuracy))\n",
"\n",
"# save model to the disk \n",
"torch.save(model_0to4.state_dict(), \"model_0to4.pt\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HncOXfFp8Bvr",
"outputId": "e08fb160-1ab6-468d-a6c6-7ed2d5c24fa1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1 | Testing Accuracy: 99.73%\n",
"Epoch 2 | Testing Accuracy: 99.86%\n",
"Epoch 3 | Testing Accuracy: 99.79%\n",
"Epoch 4 | Testing Accuracy: 99.82%\n",
"Epoch 5 | Testing Accuracy: 99.79%\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Fine-tune the trained model on the second dataset 5to9\n",
"optimizer = optim.Adam(model_0to4.parameters(), lr=0.0001)\n",
"for epoch in range(epochs):\n",
" train(model_0to4, l_train_5to9, optimizer)\n",
" accuracy_0to4 = evaluate(model_0to4, l_test_0to4)\n",
" accuracy_5to9 = evaluate(model_0to4, l_test_5to9)\n",
" print('Epoch {} | 0-4 Accuracy: {:.2f}% | 5-9 Accuracy: {:.2f}%'.format(epoch+1, accuracy_0to4, accuracy_5to9))\n",
" \n",
"print(f\"castastrophic forgetting = {accuracy - accuracy_0to4}\") "
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_UllSuva-BZ0",
"outputId": "af8cfc23-e67d-4b42-f0fa-905cab28a639"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1 | 0-4 Accuracy: 75.23% | 5-9 Accuracy: 52.15%\n",
"Epoch 2 | 0-4 Accuracy: 26.25% | 5-9 Accuracy: 79.16%\n",
"Epoch 3 | 0-4 Accuracy: 7.86% | 5-9 Accuracy: 98.97%\n",
"Epoch 4 | 0-4 Accuracy: 2.04% | 5-9 Accuracy: 99.26%\n",
"Epoch 5 | 0-4 Accuracy: 0.76% | 5-9 Accuracy: 99.40%\n",
"castastrophic forgetting = 99.02704806382565\n"
]
}
]
},
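{
"cell_type": "markdown",
"metadata": {},
"source": [
"A diagnostic sketch (not part of the original experiment): per-digit accuracy on the 0-4 test set shows how uniformly the old classes collapse after finetuning on 5-9 alone."
]
},
{
"cell_type": "code",
"source": [
"# per-digit accuracy: how evenly do the old classes collapse?\n",
"def per_class_accuracy(model, dataloader, n_classes=10):\n",
"    model.eval()\n",
"    correct = torch.zeros(n_classes)\n",
"    total = torch.zeros(n_classes)\n",
"    with torch.no_grad():\n",
"        for data, target in dataloader:\n",
"            data, target = data.to(device), target.to(device)\n",
"            pred = model(data).argmax(dim=1).cpu()\n",
"            label = target.argmax(dim=1).cpu()\n",
"            for c in range(n_classes):\n",
"                mask = label == c\n",
"                total[c] += mask.sum()\n",
"                correct[c] += (pred[mask] == c).sum()\n",
"    # report only the classes that occur in this loader\n",
"    return {c: float(100 * correct[c] / total[c]) for c in range(n_classes) if total[c] > 0}\n",
"\n",
"print(per_class_accuracy(model_0to4, l_test_0to4))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},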
{
"cell_type": "code",
"source": [
"# load the parameters of the model trained only on the 0to4 labels. \n",
"model_0to4.load_state_dict(torch.load(\"model_0to4.pt\"))\n",
"\n",
"# the idea is to train the model on the mix dataset, and show that the catastrophic forgetting is reduced. \n",
"optimizer = optim.Adam(model_0to4.parameters(), lr=0.0001)\n",
"for epoch in range(epochs):\n",
" train(model_0to4, l_train_mix, optimizer)\n",
" accuracy_0to4 = evaluate(model_0to4, l_test_0to4)\n",
" accuracy_5to9 = evaluate(model_0to4, l_test_5to9)\n",
" print('Epoch {} | 0-4 Accuracy: {:.2f}% | 5-9 Accuracy: {:.2f}%'.format(epoch+1, accuracy_0to4, accuracy_5to9))\n",
"\n",
"print(f\"castastrophic forgetting = {accuracy - accuracy_0to4}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8ZnFqPewKuiV",
"outputId": "18b3e794-c0f4-4aaf-fa02-c1ff3e3b34d3"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1 | 0-4 Accuracy: 99.14% | 5-9 Accuracy: 67.54%\n",
"Epoch 2 | 0-4 Accuracy: 98.68% | 5-9 Accuracy: 78.98%\n",
"Epoch 3 | 0-4 Accuracy: 98.37% | 5-9 Accuracy: 98.40%\n",
"Epoch 4 | 0-4 Accuracy: 98.17% | 5-9 Accuracy: 99.01%\n",
"Epoch 5 | 0-4 Accuracy: 98.23% | 5-9 Accuracy: 99.16%\n",
"castastrophic forgetting = 1.5567230978789581\n"
]
}
]
}
]
}