Skip to content

Instantly share code, notes, and snippets.

@rafaelvareto
Last active January 16, 2024 01:21
Show Gist options
  • Save rafaelvareto/b92d59d5dab1d0bfcf1e495fdcc4eb26 to your computer and use it in GitHub Desktop.
Save rafaelvareto/b92d59d5dab1d0bfcf1e495fdcc4eb26 to your computer and use it in GitHub Desktop.
openloss_tutorial.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/rafaelvareto/b92d59d5dab1d0bfcf1e495fdcc4eb26/openloss_tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZX7zSI18QxEv"
},
"source": [
"# Open-set evaluation on E-MNIST and K-MNIST with PyTorch\n",
"\n",
"This notebook is a simple example of how the Maximal-Entropy Loss (MEL) outperforms the conventional Cross-Entropy Loss (CEL) in open-set classification tasks. As thoroughly described in two peer-reviewed papers ([journal](https://doi.org/10.1016/j.imavis.2023.104862), [conference](https://doi.org/10.1109/SIBGRAPI59091.2023.10347168)), MEL increases the entropy for negative samples and attaches a penalty to known target classes in pursuit of gallery specialization.\n",
"\n",
"> Cross-Entropy Loss requires an additional class that encompasses the negative samples during training (the 27th category, after the 26 letters), whereas Maximal-Entropy Loss expects negative samples to carry negative target ids (label < 0).\n",
"\n",
"> The following code demonstrates how MEL can be used with the PyTorch framework.\n",
"* Letters from the E-MNIST dataset are employed as the known target classes (gallery set).\n",
"* Digits from the MNIST split of E-MNIST are used as negative training samples.\n",
"* Japanese cursive (Kuzushiji) characters from K-MNIST provide the unknown samples that only appear during evaluation."
]
},
{
"cell_type": "code",
"source": [
"#@title Installing Dependencies\n",
"\n",
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"# Consider pinning versions (e.g. openloss==<version>) for reproducible runs.\n",
"%pip install -q bob.measure openloss"
],
"metadata": {
"id": "cGwxg9bYpdmB"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bGU6NwlsXFSt"
},
"outputs": [],
"source": [
"#@title Import Dependencies\n",
"\n",
"# Every import the notebook needs lives here, so a fresh kernel can run each later cell.\n",
"import bob.measure\n",
"import bob.measure.plot  # detection_identification_curve, used by the plotting cell\n",
"import openloss\n",
"import torch\n",
"import torchvision\n",
"from matplotlib import pyplot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_bNfVLRUYqZA"
},
"outputs": [],
"source": [
"#@title Define Hyperparameters\n",
"\n",
"input_size = 784  # 28 * 28 pixels: each image is flattened into one vector\n",
"hidden_size = 1024\n",
"\n",
"batch_size = 100\n",
"learn_rate = 3e-4\n",
"num_epochs = 15\n",
"\n",
"# Seed torch's RNG (all devices) so weight initialisation and DataLoader shuffling\n",
"# are reproducible (up to non-deterministic GPU kernels).\n",
"random_seed = 0\n",
"torch.manual_seed(random_seed);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lCsBCXMwbpH5"
},
"outputs": [],
"source": [
"#@title Downloading E-MNIST & K-MNIST data\n",
"\n",
"to_tensor = torchvision.transforms.ToTensor()\n",
"\n",
"def relabel(constant):\n",
"    \"\"\"Return a target_transform that replaces every original target with `constant`.\"\"\"\n",
"    return lambda _: constant\n",
"\n",
"# Training data: E-MNIST letters form the known (gallery) classes.\n",
"# NOTE(review): torchvision lists 27 classes for the 'letters' split ('N/A' + a-z,\n",
"# targets 1-26), so len(train_data.classes) == 27 and target 0 never occurs - confirm.\n",
"train_data = torchvision.datasets.EMNIST(\n",
"    root='./data', split='letters', train=True, transform=to_tensor, download=True\n",
")\n",
"\n",
"# All negative/unknown samples share one constant target:\n",
"# CEL uses an extra class index right after the known ones, MEL uses a negative id.\n",
"cel_negative_label = len(train_data.classes)\n",
"mel_negative_label = -1\n",
"\n",
"train_neg_data_cel = torchvision.datasets.EMNIST(\n",
"    root='./data', split='mnist', train=True, transform=to_tensor, download=True,\n",
"    target_transform=relabel(cel_negative_label)\n",
")\n",
"train_neg_data_mel = torchvision.datasets.EMNIST(\n",
"    root='./data', split='mnist', train=True, transform=to_tensor, download=True,\n",
"    target_transform=relabel(mel_negative_label)\n",
")\n",
"\n",
"# Evaluation data: known E-MNIST letters plus K-MNIST samples never seen in training.\n",
"probe_data = torchvision.datasets.EMNIST(\n",
"    root='./data', split='letters', train=False, transform=to_tensor, download=True\n",
")\n",
"probe_unk_data_cel = torchvision.datasets.KMNIST(\n",
"    root='./data', train=False, transform=to_tensor, download=True,\n",
"    target_transform=relabel(cel_negative_label)\n",
")\n",
"probe_unk_data_mel = torchvision.datasets.KMNIST(\n",
"    root='./data', train=False, transform=to_tensor, download=True,\n",
"    target_transform=relabel(mel_negative_label)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rfDPBdnYgfGp"
},
"outputs": [],
"source": [
"#@title Concatenating & Loading datasets\n",
"\n",
"def make_loader(parts, shuffle):\n",
"    \"\"\"Chain the datasets in `parts` into one ConcatDataset and batch it with a DataLoader.\"\"\"\n",
"    combined = torch.utils.data.ConcatDataset(parts)\n",
"    return torch.utils.data.DataLoader(dataset=combined, batch_size=batch_size, shuffle=shuffle)\n",
"\n",
"# Training loaders: known letters + negative digits, reshuffled every epoch.\n",
"train_loader_cel = make_loader([train_data, train_neg_data_cel], shuffle=True)\n",
"train_loader_mel = make_loader([train_data, train_neg_data_mel], shuffle=True)\n",
"\n",
"# Probe loaders: known letters + unknown K-MNIST samples, kept in a fixed order.\n",
"probe_loader_cel = make_loader([probe_data, probe_unk_data_cel], shuffle=False)\n",
"probe_loader_mel = make_loader([probe_data, probe_unk_data_mel], shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fL-YXTvghaz_"
},
"outputs": [],
"source": [
"#@title Define Model class\n",
"\n",
"class NeuralNetwork(torch.nn.Module):\n",
"    \"\"\"Two-layer perceptron (Linear -> ReLU -> Linear) returning raw, unnormalised logits.\"\"\"\n",
"\n",
"    def __init__(self, input_size, hidden_size, num_classes):\n",
"        super().__init__()\n",
"        # Attribute names form the state_dict keys, so they are kept unchanged.\n",
"        self.fc01 = torch.nn.Linear(input_size, hidden_size)\n",
"        self.relu = torch.nn.ReLU()\n",
"        self.fc02 = torch.nn.Linear(hidden_size, num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        hidden = self.relu(self.fc01(x))\n",
"        return self.fc02(hidden)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-3EPEqbjjfAT"
},
"outputs": [],
"source": [
"#@title Build the model\n",
"\n",
"if torch.cuda.is_available():\n",
"    device = 'cuda:0'\n",
"else:\n",
"    device = 'cpu'\n",
"\n",
"# Known classes of the E-MNIST 'letters' split.\n",
"num_classes = len(train_data.classes)\n",
"# Known + negative training samples; only used for training progress messages.\n",
"num_samples = len(train_data) + len(train_neg_data_cel)\n",
"\n",
"# CEL needs one extra output for the negative class; MEL marks negatives with label < 0 instead.\n",
"model_cel = NeuralNetwork(input_size, hidden_size, num_classes + 1).to(device)\n",
"model_mel = NeuralNetwork(input_size, hidden_size, num_classes).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ePLIwvAFj2zH"
},
"outputs": [],
"source": [
"#@title Define Cost functions & Optimizers\n",
"\n",
"# Baseline: plain cross-entropy, with negatives mapped to the extra output class.\n",
"criterion_cel = torch.nn.CrossEntropyLoss(reduction='mean')\n",
"optimizer_cel = torch.optim.Adam(model_cel.parameters(), lr=learn_rate)\n",
"\n",
"# Maximal-Entropy Loss (openloss); negatives are the samples labelled < 0.\n",
"# NOTE(review): margin presumably controls the penalty on known target classes - confirm in openloss.\n",
"criterion_mel = openloss.MaximalEntropyLoss(num_classes=num_classes, margin=0.5, reduction='mean')\n",
"optimizer_mel = torch.optim.Adam(model_mel.parameters(), lr=learn_rate)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "u75Xa5VckuTH",
"outputId": "ddf7e03d-44f3-40ae-e314-404f1ceb4afa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training network with Cross-Entropy Loss\n",
"Epoch [1/15]\t[300/1848 = 1.3191] [600/1848 = 1.0328] [900/1848 = 0.8079] [1200/1848 = 0.6967] [1500/1848 = 0.6887] [1800/1848 = 0.5803] \n",
"Epoch [2/15]\t[300/1848 = 0.6316] [600/1848 = 0.6248] [900/1848 = 0.6778] [1200/1848 = 0.4207] [1500/1848 = 0.5554] [1800/1848 = 0.4496] \n",
"Epoch [3/15]\t[300/1848 = 0.4082] [600/1848 = 0.3918] [900/1848 = 0.4528] [1200/1848 = 0.3786] [1500/1848 = 0.3951] [1800/1848 = 0.4657] \n",
"Epoch [4/15]\t[300/1848 = 0.3174] [600/1848 = 0.5494] [900/1848 = 0.5841] [1200/1848 = 0.4429] [1500/1848 = 0.4614] [1800/1848 = 0.3229] \n",
"Epoch [5/15]\t[300/1848 = 0.3562] [600/1848 = 0.4165] [900/1848 = 0.4437] [1200/1848 = 0.4330] [1500/1848 = 0.3143] [1800/1848 = 0.5487] \n",
"Epoch [6/15]\t[300/1848 = 0.2810] [600/1848 = 0.2796] [900/1848 = 0.4146] [1200/1848 = 0.2920] [1500/1848 = 0.3870] [1800/1848 = 0.3592] \n",
"Epoch [7/15]\t[300/1848 = 0.3491] [600/1848 = 0.4434] [900/1848 = 0.2464] [1200/1848 = 0.3436] [1500/1848 = 0.2543] [1800/1848 = 0.2565] \n",
"Epoch [8/15]\t[300/1848 = 0.3977] [600/1848 = 0.2892] [900/1848 = 0.3723] [1200/1848 = 0.2756] [1500/1848 = 0.3353] [1800/1848 = 0.3320] \n",
"Epoch [9/15]\t[300/1848 = 0.3092] [600/1848 = 0.3115] [900/1848 = 0.3045] [1200/1848 = 0.2261] [1500/1848 = 0.2911] [1800/1848 = 0.2054] \n",
"Epoch [10/15]\t[300/1848 = 0.1486] [600/1848 = 0.4474] [900/1848 = 0.1663] [1200/1848 = 0.2507] [1500/1848 = 0.2484] [1800/1848 = 0.5103] \n",
"Epoch [11/15]\t[300/1848 = 0.2425] [600/1848 = 0.2604] [900/1848 = 0.2603] [1200/1848 = 0.2777] [1500/1848 = 0.3113] [1800/1848 = 0.2458] \n",
"Epoch [12/15]\t[300/1848 = 0.2184] [600/1848 = 0.1910] [900/1848 = 0.1943] [1200/1848 = 0.3508] [1500/1848 = 0.2591] [1800/1848 = 0.2616] \n",
"Epoch [13/15]\t[300/1848 = 0.2858] [600/1848 = 0.1947] [900/1848 = 0.1665] [1200/1848 = 0.2220] [1500/1848 = 0.2763] [1800/1848 = 0.1566] \n",
"Epoch [14/15]\t[300/1848 = 0.2674] [600/1848 = 0.2125] [900/1848 = 0.1449] [1200/1848 = 0.2455] [1500/1848 = 0.2159] [1800/1848 = 0.2552] \n",
"Epoch [15/15]\t[300/1848 = 0.2064] [600/1848 = 0.2750] [900/1848 = 0.3080] [1200/1848 = 0.2190] [1500/1848 = 0.2955] [1800/1848 = 0.1682] \n",
"Training network with Maximal-Entropy Loss\n",
"Epoch [1/15]\t[300/1848 = 2.2881] [600/1848 = 2.1260] [900/1848 = 2.0720] [1200/1848 = 1.9765] [1500/1848 = 1.8796] [1800/1848 = 1.7293] \n",
"Epoch [2/15]\t[300/1848 = 1.9387] [600/1848 = 1.7789] [900/1848 = 1.6595] [1200/1848 = 1.4682] [1500/1848 = 1.6029] [1800/1848 = 1.9263] \n",
"Epoch [3/15]\t[300/1848 = 1.6872] [600/1848 = 1.8431] [900/1848 = 1.3831] [1200/1848 = 2.2339] [1500/1848 = 1.7945] [1800/1848 = 1.4387] \n",
"Epoch [4/15]\t[300/1848 = 1.3946] [600/1848 = 1.9433] [900/1848 = 1.4671] [1200/1848 = 1.7856] [1500/1848 = 1.5913] [1800/1848 = 1.4531] \n",
"Epoch [5/15]\t[300/1848 = 1.1340] [600/1848 = 1.3471] [900/1848 = 1.5600] [1200/1848 = 1.3287] [1500/1848 = 1.2284] [1800/1848 = 1.4017] \n",
"Epoch [6/15]\t[300/1848 = 1.4419] [600/1848 = 1.3205] [900/1848 = 1.3468] [1200/1848 = 1.6891] [1500/1848 = 1.3360] [1800/1848 = 1.5187] \n",
"Epoch [7/15]\t[300/1848 = 1.3123] [600/1848 = 1.3725] [900/1848 = 1.4318] [1200/1848 = 1.3849] [1500/1848 = 1.6194] [1800/1848 = 1.4615] \n",
"Epoch [8/15]\t[300/1848 = 1.1954] [600/1848 = 1.4519] [900/1848 = 1.6195] [1200/1848 = 1.3001] [1500/1848 = 1.3851] [1800/1848 = 1.3715] \n",
"Epoch [9/15]\t[300/1848 = 1.5820] [600/1848 = 1.2637] [900/1848 = 1.3042] [1200/1848 = 1.5714] [1500/1848 = 1.2808] [1800/1848 = 1.5628] \n",
"Epoch [10/15]\t[300/1848 = 1.6933] [600/1848 = 1.3179] [900/1848 = 1.3250] [1200/1848 = 1.3822] [1500/1848 = 1.3070] [1800/1848 = 1.5055] \n",
"Epoch [11/15]\t[300/1848 = 1.1019] [600/1848 = 1.6998] [900/1848 = 1.4937] [1200/1848 = 1.3926] [1500/1848 = 1.4391] [1800/1848 = 1.4541] \n",
"Epoch [12/15]\t[300/1848 = 1.6115] [600/1848 = 1.4365] [900/1848 = 1.3045] [1200/1848 = 1.4203] [1500/1848 = 1.5660] [1800/1848 = 1.3486] \n",
"Epoch [13/15]\t[300/1848 = 1.3794] [600/1848 = 1.3641] [900/1848 = 1.4275] [1200/1848 = 1.5101] [1500/1848 = 1.5873] [1800/1848 = 1.4084] \n",
"Epoch [14/15]\t[300/1848 = 1.2511] [600/1848 = 1.3520] [900/1848 = 1.3668] [1200/1848 = 1.1975] [1500/1848 = 1.2604] [1800/1848 = 1.5201] \n",
"Epoch [15/15]\t[300/1848 = 1.2493] [600/1848 = 1.5097] [900/1848 = 1.3566] [1200/1848 = 1.4217] [1500/1848 = 1.3549] [1800/1848 = 1.2613] \n"
]
}
],
"source": [
"#@title Training both networks\n",
"\n",
"def train_network(loader, model, criterion, optimizer, num_epochs, num_samples, device, log_every=300):\n",
"    \"\"\"Train `model` in place for `num_epochs` passes over `loader` and return it.\n",
"\n",
"    Every `log_every` batches the loss of the current mini-batch is printed.\n",
"    `num_samples` is accepted for backward compatibility only: the progress total is\n",
"    len(loader), so the function no longer reads the global `batch_size`.\n",
"    \"\"\"\n",
"    model.train()\n",
"    num_batches = len(loader)\n",
"    for epoch_id in range(num_epochs):\n",
"        print(f'Epoch [{epoch_id+1}/{num_epochs}]', end='\\t')\n",
"        for batch_id, (images, labels) in enumerate(loader, start=1):\n",
"            # Flatten all but the batch dimension: the first Linear layer expects one vector per sample.\n",
"            images = images.flatten(start_dim=1).to(device)\n",
"            labels = labels.to(device)\n",
"\n",
"            optimizer.zero_grad()\n",
"            loss = criterion(model(images), labels)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            if batch_id % log_every == 0:\n",
"                print(f'[{batch_id}/{num_batches} = {loss.item():.4f}]', end=' ')\n",
"        print()\n",
"    return model\n",
"\n",
"print('Training network with Cross-Entropy Loss')\n",
"model_cel = train_network(train_loader_cel, model_cel, criterion_cel, optimizer_cel, num_epochs, num_samples, device)\n",
"print('Training network with Maximal-Entropy Loss')\n",
"model_mel = train_network(train_loader_mel, model_mel, criterion_mel, optimizer_mel, num_epochs, num_samples, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l6JRaN9ZIjdw"
},
"outputs": [],
"source": [
"#@title Scoring the probe set (known letters + unknown K-MNIST samples)\n",
"\n",
"def evaluate_network(loader, model, neg_class=-1, device='cpu'):\n",
"    \"\"\"Collect per-sample (negative scores, positive score) pairs for open-set evaluation.\n",
"\n",
"    For every probe sample, `neg` holds the logits of all classes except the sample's\n",
"    label, and `pos` holds the logit of its true class, or None when the label equals\n",
"    `neg_class` (an unknown sample). This list is what the plotting cell passes to\n",
"    bob.measure.plot.detection_identification_curve.\n",
"    The model's train/eval mode is restored before returning.\n",
"    \"\"\"\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    neg_pos_scores = []\n",
"    try:\n",
"        # Scoring only: skip building the autograd graph.\n",
"        with torch.no_grad():\n",
"            for images, labels in loader:\n",
"                # One device->host copy per batch instead of one .item() sync per logit.\n",
"                logits = model(images.flatten(start_dim=1).to(device)).cpu().tolist()\n",
"                for scores, label in zip(logits, labels.tolist()):\n",
"                    neg = [value for idx, value in enumerate(scores) if idx != label]\n",
"                    pos = scores[label] if label != neg_class else None\n",
"                    neg_pos_scores.append((neg, pos))\n",
"    finally:\n",
"        model.train(was_training)\n",
"    return neg_pos_scores\n",
"\n",
"neg_pos_scores_cel = evaluate_network(probe_loader_cel, model_cel, len(train_data.classes), device=device)\n",
"neg_pos_scores_mel = evaluate_network(probe_loader_mel, model_mel, device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 477
},
"id": "Z8pZqkpQXYL0",
"outputId": "5f551f69-4dd9-4ef2-a9f4-d4736e91d14d"
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title Plotting the Detection and Identification Curves\n",
"\n",
"from matplotlib import pyplot\n",
"\n",
"# Set the font size before any figure element exists so it applies to all text uniformly.\n",
"pyplot.rcParams.update({'font.size': 12})\n",
"\n",
"# Both curves are drawn on the same (current) axes.\n",
"bob.measure.plot.detection_identification_curve(neg_pos_scores_cel, label='CEL-based')\n",
"bob.measure.plot.detection_identification_curve(neg_pos_scores_mel, label='MEL-based')\n",
"pyplot.title('Detection and Identification Curve (Open ROC)')\n",
"pyplot.xlabel('False Positive Identification Rate')\n",
"pyplot.ylabel('True Positive Identification Rate')\n",
"pyplot.legend(loc='best')\n",
"pyplot.xlim([1e-1, 1.00])\n",
"pyplot.xticks([1e-1, 5e-1, 1.0])\n",
"pyplot.grid(True)\n",
"pyplot.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ttwFa4dLetrC"
},
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": [],
"toc_visible": true,
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment