Skip to content

Instantly share code, notes, and snippets.

@tomekkorbak
Created November 15, 2020 19:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tomekkorbak/bdea3fb841fcd390b58f2643eaaf365b to your computer and use it in GitHub Desktop.
Save tomekkorbak/bdea3fb841fcd390b58f2643eaaf365b to your computer and use it in GitHub Desktop.
triplet_quadruplet.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "triplet_quadruplet.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyPi2p8uUids1lNFFAZtB5qo",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/tomekkorbak/bdea3fb841fcd390b58f2643eaaf365b/triplet_quadruplet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "MYKjW3HfoEtX"
},
"source": [
"import numpy as np\n",
"import torch\n",
"import torchvision\n",
"import torchvision.models as models\n",
"\n",
"from torch import nn\n",
"from torch.optim import Optimizer, SGD\n",
"from torch.utils.data import DataLoader\n",
"from torchvision.datasets import CIFAR10\n",
"from torchvision.transforms.functional import to_tensor"
],
"execution_count": 22,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "xGLw3DFJtRTp"
},
"source": [
"# Triplet loss"
]
},
{
"cell_type": "code",
"metadata": {
"id": "O6qixWM3mwR6"
},
"source": [
"def get_distance_matrix(\n",
"    embeddings: torch.Tensor,  # [B, E]\n",
"    ):\n",
"    \"\"\"Return the [B, B] matrix of pairwise Euclidean distances between rows.\n",
"\n",
"    Relies on the expansion ||a - b||^2 = ||a||^2 - 2 * <a, b> + ||b||^2, with\n",
"    the squared norms read off the diagonal of the Gram matrix X @ X.T.\n",
"    https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065\n",
"    \"\"\"\n",
"    gram = embeddings @ embeddings.T  # [B, B]\n",
"    sq_norms = gram.diagonal()  # [B]\n",
"    sq_dists = sq_norms.unsqueeze(0) - 2.0 * gram + sq_norms.unsqueeze(1)  # [B, B]\n",
"    # relu removes small negative values left by round-off; the epsilon keeps\n",
"    # the argument of sqrt strictly positive.\n",
"    return (nn.functional.relu(sq_dists) + 1e-16).sqrt()  # [B, B]\n",
"\n",
"\n",
"def get_positive_mask(\n",
"    labels: torch.Tensor,  # [B]\n",
"    device: torch.device\n",
"    ):\n",
"    \"\"\"Boolean [B, B] mask that is True at (i, j) when i != j and both samples\n",
"    share the same label.\n",
"    \"\"\"\n",
"    same_label = labels.view(-1, 1) == labels.view(1, -1)  # [B, B]\n",
"    off_diagonal = ~torch.eye(labels.size(0), dtype=torch.bool).to(device=device)  # [B, B]\n",
"    return off_diagonal & same_label  # [B, B]\n",
"\n",
"\n",
"def get_negative_mask(\n",
"    labels: torch.Tensor,  # [B]\n",
"    device: torch.device\n",
"    ):\n",
"    \"\"\"Boolean [B, B] mask that is True at (i, j) when i != j and the two\n",
"    samples carry different labels.\n",
"    \"\"\"\n",
"    same_label = labels.view(-1, 1) == labels.view(1, -1)  # [B, B]\n",
"    same_index = torch.eye(labels.size(0), dtype=torch.bool).to(device=device)  # [B, B]\n",
"    # De Morgan: not-same-label AND not-same-index == not (same-label OR same-index)\n",
"    return ~(same_label | same_index)  # [B, B]\n",
"\n",
"\n",
"def get_triplet_mask(\n",
"    labels: torch.Tensor,  # [B]\n",
"    device: torch.device\n",
"    ):\n",
"    \"\"\"Boolean [B, B, B] mask over triplets (anchor i, positive j, negative k).\n",
"\n",
"    An entry is True when i, j, k are pairwise distinct,\n",
"    labels[i] == labels[j] and labels[i] != labels[k].\n",
"    \"\"\"\n",
"    batch = labels.size(0)\n",
"    off_diagonal = ~torch.eye(batch, dtype=torch.bool).to(device=device)  # [B, B]\n",
"    same_label = labels.view(-1, 1) == labels.view(1, -1)  # [B, B]\n",
"\n",
"    # Broadcast each pairwise [B, B] matrix onto the two axes it constrains.\n",
"    i_ne_j = off_diagonal.unsqueeze(2)  # [B, B, 1]\n",
"    i_ne_k = off_diagonal.unsqueeze(1)  # [B, 1, B]\n",
"    j_ne_k = off_diagonal.unsqueeze(0)  # [1, B, B]\n",
"    positive_pair = same_label.unsqueeze(2)  # [B, B, 1]\n",
"    negative_pair = ~same_label.unsqueeze(1)  # [B, 1, B]\n",
"\n",
"    return i_ne_j & i_ne_k & j_ne_k & positive_pair & negative_pair  # [B, B, B]\n",
"\n",
"\n",
"def test_get_distance_matrix(device_for_tests):\n",
"    \"\"\"Sanity-check get_distance_matrix: zero diagonal, symmetry and ordering.\"\"\"\n",
"    points = torch.tensor(\n",
"        [[1.0, 1.0],\n",
"         [7.0, 7.0],\n",
"         [1.0, 1.0]],\n",
"        dtype=torch.float32,\n",
"    ).to(device=device_for_tests)\n",
"    dists = get_distance_matrix(points)\n",
"    expected_diagonal = torch.zeros(3, device=device_for_tests)\n",
"    assert torch.allclose(dists.diagonal(), expected_diagonal)\n",
"    assert torch.allclose(dists, dists.T)\n",
"    # Rows 0 and 2 are the same point, so they must be closer than rows 0 and 1.\n",
"    assert dists[0, 2] < dists[0, 1]\n",
"\n",
"\n",
"def test_get_positive_mask(device_for_tests):\n",
"    \"\"\"Check get_positive_mask on a batch whose only same-label pair is (0, 3).\"\"\"\n",
"    # Keep the labels on the device under test: the helper combines them with a\n",
"    # mask built on that device, so CPU-only labels fail on any other device.\n",
"    labels = torch.LongTensor([1, 2, 3, 1]).to(device=device_for_tests)\n",
"    pos_mask = get_positive_mask(labels, device_for_tests)\n",
"    assert pos_mask[0, 3]\n",
"    assert not pos_mask[0, 1]\n",
"    assert not pos_mask[0, 0] and not pos_mask[1, 1]\n",
"\n",
"\n",
"def test_get_negative_mask(device_for_tests):\n",
"    \"\"\"Check get_negative_mask on a batch where samples 0 and 3 share a label.\"\"\"\n",
"    # Keep the labels on the device under test: the helper combines them with a\n",
"    # mask built on that device, so CPU-only labels fail on any other device.\n",
"    labels = torch.LongTensor([1, 2, 3, 1]).to(device=device_for_tests)\n",
"    neg_mask = get_negative_mask(labels, device_for_tests)\n",
"    assert not neg_mask[0, 3]\n",
"    assert neg_mask[0, 1]\n",
"    assert not neg_mask[0, 0] and not neg_mask[1, 1]\n",
"\n",
"\n",
"def test_get_triplet_mask(device_for_tests):\n",
"    \"\"\"Check get_triplet_mask on valid triplets and on repeated-index triplets.\"\"\"\n",
"    # Keep the labels on the device under test: the helper combines them with a\n",
"    # mask built on that device, so CPU-only labels fail on any other device.\n",
"    labels = torch.LongTensor([1, 2, 3, 1, 3]).to(device=device_for_tests)\n",
"    mask = get_triplet_mask(labels, device_for_tests)\n",
"    # Valid (anchor, positive, negative) combinations.\n",
"    assert mask[0, 3, 2]\n",
"    assert mask[2, 4, 1]\n",
"    assert mask[4, 2, 0]\n",
"    # Triplets that reuse an index must be rejected.\n",
"    assert not mask[0, 0, 0]\n",
"    assert not mask[0, 3, 3]\n",
"    assert not mask[0, 0, 4]\n",
"\n",
"# Run every helper test once on the CPU.\n",
"device_for_tests = torch.device('cpu')\n",
"for run_test in (\n",
"    test_get_distance_matrix,\n",
"    test_get_positive_mask,\n",
"    test_get_negative_mask,\n",
"    test_get_triplet_mask,\n",
"):\n",
"    run_test(device_for_tests)"
],
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "uFeuHubVnKkM"
},
"source": [
"class TripletLossModel(nn.Module):\n",
"    \"\"\"ResNet backbone plus a linear embedding head trained with a triplet loss.\n",
"\n",
"    The margin is adaptive: mean negative-pair distance minus mean positive-pair\n",
"    distance, clamped at zero and recomputed per batch without gradients.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, resnet: nn.Module):\n",
"        \"\"\"Wrap `resnet`, replacing its `fc` layer by an identity.\n",
"\n",
"        The backbone must then emit 512 features per image, because they are\n",
"        projected to 128-d embeddings by a linear layer.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.resnet = resnet\n",
"        self.resnet.fc = nn.Identity()\n",
"        self.embeddings = nn.Linear(512, 128)\n",
"\n",
"    def forward(\n",
"        self,\n",
"        inputs: torch.Tensor,  # [B, C, H, W]\n",
"        labels: torch.Tensor  # [B]\n",
"        ):\n",
"        \"\"\"Embed a batch and compute the adaptive-margin triplet loss.\n",
"\n",
"        Returns:\n",
"            A `(loss, logs)` pair: `loss` is a scalar tensor to backpropagate,\n",
"            `logs` maps metric names to plain Python numbers.\n",
"        \"\"\"\n",
"        B = labels.size(0)\n",
"        # Build the masks on the labels' own device instead of relying on a\n",
"        # notebook-level `device` global.\n",
"        device = labels.device\n",
"        embeddings = self.embeddings(self.resnet(inputs))  # [B, E]\n",
"        distance_matrix = get_distance_matrix(embeddings)  # [B, B]\n",
"        with torch.no_grad():\n",
"            mask_pos = get_positive_mask(labels, device)  # [B, B]\n",
"            mask_neg = get_negative_mask(labels, device)  # [B, B]\n",
"            triplet_mask = get_triplet_mask(labels, device)  # [B, B, B]\n",
"            unmasked_triplets = torch.sum(triplet_mask)  # scalar\n",
"            mu_pos = torch.mean(distance_matrix[mask_pos])  # scalar\n",
"            mu_neg = torch.mean(distance_matrix[mask_neg])  # scalar\n",
"            margin = torch.clamp(mu_neg - mu_pos, min=0)  # scalar\n",
"\n",
"        distance_i_j = distance_matrix.view(B, B, 1)  # anchor-positive, [B, B, 1]\n",
"        distance_i_k = distance_matrix.view(B, 1, B)  # anchor-negative, [B, 1, B]\n",
"        gap = (distance_i_k - distance_i_j)[triplet_mask]  # [valid_triplets]\n",
"        # Keep only triplets whose negative is not yet `margin` farther than the positive.\n",
"        hardest_triplets = gap < margin  # [valid_triplets]\n",
"        # Hinge on margin + d(a, p) - d(a, n): minimising it pulls positives in\n",
"        # and pushes negatives away. Hinging on d(a, n) - d(a, p) does the opposite.\n",
"        triplet_loss = nn.functional.relu(margin - gap[hardest_triplets])  # [valid_triplets_after_mask]\n",
"\n",
"        if triplet_loss.numel() > 0:\n",
"            loss = triplet_loss.mean()\n",
"        else:\n",
"            # No triplet selected (e.g. a batch without positive pairs): return a\n",
"            # graph-attached zero so backward() works instead of propagating NaN.\n",
"            loss = embeddings.sum() * 0.0\n",
"        logs = {\n",
"            'positive_pairs': torch.sum(mask_pos).item(),\n",
"            'negative_pairs': torch.sum(mask_neg).item(),\n",
"            'mu_neg': mu_neg.item(),\n",
"            'mu_pos': mu_pos.item(),\n",
"            'valid_triplets': unmasked_triplets.item(),\n",
"            'valid_triplets_after_mask': triplet_loss.size(0),\n",
"            'triplet_loss': loss.detach().item()\n",
"        }\n",
"        return loss, logs\n"
],
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CbFpmU3JnLID"
},
"source": [
"# Fall back to the CPU so the cell still runs on machines without a GPU.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"# resnet18() defaults to randomly initialised weights; calling it without the\n",
"# `pretrained` keyword avoids the deprecation in newer torchvision releases.\n",
"resnet18 = models.resnet18()\n",
"model = TripletLossModel(resnet=resnet18)\n",
"model = model.to(device)\n",
"opt = SGD(model.parameters(), lr=0.001)\n",
"ds_train = CIFAR10('.', transform=to_tensor, download=True)\n",
"dataloader_train = DataLoader(ds_train, batch_size=64, shuffle=True)\n",
"for epoch in range(10):\n",
"    for batch_idx, (input_, labels) in enumerate(dataloader_train):\n",
"        input_, labels = input_.to(device), labels.to(device)\n",
"        loss, logs = model(input_, labels)\n",
"        loss.backward()\n",
"        opt.step()\n",
"        opt.zero_grad()\n",
"        # Report the loss and the per-batch diagnostics every 50 batches.\n",
"        if batch_idx % 50 == 0:\n",
"            print(f'Batch {batch_idx}, loss = {loss.cpu().detach().item():.4f}')\n",
"            for metric_name, value in logs.items():\n",
"                print(4*' ' + f'{metric_name} = {value:.4f}')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ZhHPfi9bthKp"
},
"source": [
"def get_quadruplet_mask(\n",
"    labels: torch.Tensor,  # [B]\n",
"    device: torch.device\n",
"    ):\n",
"    \"\"\"Compute a 4d mask for quadruplets (i, j, k, l) such that:\n",
"    labels[i] == labels[j] and\n",
"    labels[j] != labels[k] and\n",
"    labels[k] != labels[l] and\n",
"    i, j, k, l are pairwise distinct.\n",
"\n",
"    All six index pairs are checked. The label conditions already imply\n",
"    i != k, j != k and k != l, but they do not stop l from coinciding with\n",
"    i or j, so those pairs are excluded explicitly.\n",
"    \"\"\"\n",
"    B = labels.size(0)\n",
"\n",
"    def on_axes(pair_matrix, first, second):\n",
"        \"\"\"View a [B, B] matrix as a 4d tensor living on axes `first` < `second`.\"\"\"\n",
"        shape = [1, 1, 1, 1]\n",
"        shape[first] = B\n",
"        shape[second] = B\n",
"        return pair_matrix.view(*shape)\n",
"\n",
"    # Make sure that i, j, k, l are pairwise distinct (axes: i=0, j=1, k=2, l=3)\n",
"    indices_equal = torch.eye(B, dtype=torch.bool).to(device=device)  # [B, B]\n",
"    indices_not_equal = ~indices_equal  # [B, B]\n",
"    distinct_indices = (\n",
"        on_axes(indices_not_equal, 0, 1)    # i != j, [B, B, 1, 1]\n",
"        & on_axes(indices_not_equal, 0, 2)  # i != k, [B, 1, B, 1]\n",
"        & on_axes(indices_not_equal, 0, 3)  # i != l, [B, 1, 1, B]\n",
"        & on_axes(indices_not_equal, 1, 2)  # j != k, [1, B, B, 1]\n",
"        & on_axes(indices_not_equal, 1, 3)  # j != l, [1, B, 1, B]\n",
"        & on_axes(indices_not_equal, 2, 3)  # k != l, [1, 1, B, B]\n",
"    )  # [B, B, B, B]\n",
"\n",
"    # Make sure that labels[i] == labels[j]\n",
"    # and labels[j] != labels[k]\n",
"    # and labels[k] != labels[l]\n",
"    labels_equal = labels.view(1, B) == labels.view(B, 1)  # [B, B]\n",
"    i_equal_j = on_axes(labels_equal, 0, 1)  # [B, B, 1, 1]\n",
"    j_equal_k = on_axes(labels_equal, 1, 2)  # [1, B, B, 1]\n",
"    k_equal_l = on_axes(labels_equal, 2, 3)  # [1, 1, B, B]\n",
"\n",
"    return (i_equal_j & ~j_equal_k & ~k_equal_l) & distinct_indices  # [B, B, B, B]\n",
" \n",
"def test_get_quadruplet_mask(device_for_tests):\n",
"    \"\"\"Check get_quadruplet_mask on valid and invalid (i, j, k, l) combinations.\"\"\"\n",
"    # Keep the labels on the device under test: the helper combines them with a\n",
"    # mask built on that device, so CPU-only labels fail on any other device.\n",
"    labels = torch.LongTensor([1, 2, 3, 1, 3]).to(device=device_for_tests)\n",
"    mask = get_quadruplet_mask(labels, device_for_tests)\n",
"    # Same-label pair (i, j) followed by a different-label pair (k, l).\n",
"    assert mask[0, 3, 1, 2]\n",
"    assert mask[2, 4, 0, 1]\n",
"    assert mask[4, 2, 1, 0]\n",
"    # Repeated indices or a same-label (k, l) pair must be rejected.\n",
"    assert not mask[0, 0, 0, 0]\n",
"    assert not mask[0, 0, 1, 2]\n",
"    assert not mask[0, 3, 4, 4]\n",
"    assert not mask[0, 3, 2, 4]\n",
"\n",
"test_get_quadruplet_mask(device_for_tests='cpu')\n"
],
"execution_count": 26,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EuuD6THPtk5l"
},
"source": [
"class QuadrupletLossModel(nn.Module):\n",
"    \"\"\"ResNet backbone plus a linear embedding head trained with a quadruplet loss.\n",
"\n",
"    The loss is a triplet term plus an auxiliary term that compares a positive\n",
"    pair (i, j) with a negative pair (k, l). Both use an adaptive margin: mean\n",
"    negative-pair distance minus mean positive-pair distance, clamped at zero\n",
"    and recomputed per batch without gradients; the auxiliary term uses half of it.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, resnet: nn.Module):\n",
"        \"\"\"Wrap `resnet`, replacing its `fc` layer by an identity.\n",
"\n",
"        The backbone must then emit 512 features per image, because they are\n",
"        projected to 128-d embeddings by a linear layer.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.resnet = resnet\n",
"        self.resnet.fc = nn.Identity()\n",
"        self.embeddings = nn.Linear(512, 128)\n",
"\n",
"    def forward(\n",
"        self,\n",
"        inputs: torch.Tensor,  # [B, C, H, W]\n",
"        labels: torch.Tensor  # [B]\n",
"        ):\n",
"        \"\"\"Embed a batch and compute the adaptive-margin quadruplet loss.\n",
"\n",
"        Returns:\n",
"            A `(loss, logs)` pair: `loss` is a scalar tensor to backpropagate,\n",
"            `logs` maps metric names to plain Python numbers.\n",
"        \"\"\"\n",
"        B = labels.size(0)\n",
"        # Build the masks on the labels' own device instead of relying on a\n",
"        # notebook-level `device` global.\n",
"        device = labels.device\n",
"        embeddings = self.embeddings(self.resnet(inputs))  # [B, E]\n",
"        distance_matrix = get_distance_matrix(embeddings)  # [B, B]\n",
"        with torch.no_grad():\n",
"            mask_pos = get_positive_mask(labels, device)  # [B, B]\n",
"            mask_neg = get_negative_mask(labels, device)  # [B, B]\n",
"            triplet_mask = get_triplet_mask(labels, device)  # [B, B, B]\n",
"            quadruplet_mask = get_quadruplet_mask(labels, device)  # [B, B, B, B]\n",
"            unmasked_triplets = torch.sum(triplet_mask)  # scalar\n",
"            unmasked_quadruplets = torch.sum(quadruplet_mask)  # scalar\n",
"            mu_pos = torch.mean(distance_matrix[mask_pos])  # scalar\n",
"            mu_neg = torch.mean(distance_matrix[mask_neg])  # scalar\n",
"            margin = torch.clamp(mu_neg - mu_pos, min=0)  # scalar\n",
"            aux_margin = margin / 2  # scalar\n",
"\n",
"        # A graph-attached zero used when a term selects nothing, so backward()\n",
"        # still works instead of propagating the NaN of an empty mean.\n",
"        zero = embeddings.sum() * 0.0\n",
"\n",
"        distance_i_j = distance_matrix.view(B, B, 1)  # anchor-positive, [B, B, 1]\n",
"        distance_i_k = distance_matrix.view(B, 1, B)  # anchor-negative, [B, 1, B]\n",
"        gap = (distance_i_k - distance_i_j)[triplet_mask]  # [valid_triplets]\n",
"        # Keep only triplets whose negative is not yet `margin` farther than the positive.\n",
"        hardest_triplets = gap < margin  # [valid_triplets]\n",
"        # Hinge on margin + d(i, j) - d(i, k): minimising it pulls positives in\n",
"        # and pushes negatives away. Hinging on d(i, k) - d(i, j) does the opposite.\n",
"        triplet_loss = nn.functional.relu(margin - gap[hardest_triplets])  # [valid_triplets_after_mask]\n",
"        triplet_term = triplet_loss.mean() if triplet_loss.numel() > 0 else zero\n",
"\n",
"        distance_i_j = distance_matrix.view(B, B, 1, 1)  # positive pair, [B, B, 1, 1]\n",
"        distance_k_l = distance_matrix.view(1, 1, B, B)  # negative pair, [1, 1, B, B]\n",
"        aux_gap = (distance_k_l - distance_i_j)[quadruplet_mask]  # [valid_quadruplets]\n",
"        hardest_quadruplets = aux_gap < aux_margin  # [valid_quadruplets]\n",
"        # Same hinge direction as the triplet term, with half the margin.\n",
"        auxilary_loss = nn.functional.relu(aux_margin - aux_gap[hardest_quadruplets])  # [valid_quadruplets_after_mask]\n",
"        auxilary_term = auxilary_loss.mean() if auxilary_loss.numel() > 0 else zero\n",
"\n",
"        quadruplet_loss = triplet_term + auxilary_term\n",
"        logs = {\n",
"            'positive_pairs': torch.sum(mask_pos).item(),\n",
"            'negative_pairs': torch.sum(mask_neg).item(),\n",
"            'mu_neg': mu_neg.item(),\n",
"            'mu_pos': mu_pos.item(),\n",
"            'valid_triplets': unmasked_triplets.item(),\n",
"            'valid_triplets_after_mask': triplet_loss.size(0),\n",
"            'valid_quadruplets': unmasked_quadruplets.item(),\n",
"            'valid_quadruplets_after_mask': auxilary_loss.size(0),\n",
"            'auxilary_loss': auxilary_term.detach().item(),\n",
"            'triplet_loss': triplet_term.detach().item()\n",
"        }\n",
"        return quadruplet_loss, logs\n"
],
"execution_count": 27,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "09HlB0hGuaEi"
},
"source": [
"# Fall back to the CPU so the cell still runs on machines without a GPU.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"# resnet18() defaults to randomly initialised weights; calling it without the\n",
"# `pretrained` keyword avoids the deprecation in newer torchvision releases.\n",
"resnet18 = models.resnet18()\n",
"model = QuadrupletLossModel(resnet=resnet18)\n",
"model = model.to(device)\n",
"opt = SGD(model.parameters(), lr=0.001)\n",
"ds_train = CIFAR10('.', transform=to_tensor, download=True)\n",
"dataloader_train = DataLoader(ds_train, batch_size=64, shuffle=True)\n",
"for epoch in range(10):\n",
"    for batch_idx, (input_, labels) in enumerate(dataloader_train):\n",
"        input_, labels = input_.to(device), labels.to(device)\n",
"        loss, logs = model(input_, labels)\n",
"        loss.backward()\n",
"        opt.step()\n",
"        opt.zero_grad()\n",
"        # Report the loss and the per-batch diagnostics every 50 batches.\n",
"        if batch_idx % 50 == 0:\n",
"            print(f'Batch {batch_idx}, loss = {loss.cpu().detach().item():.4f}')\n",
"            for metric_name, value in logs.items():\n",
"                print(4*' ' + f'{metric_name} = {value:.4f}')"
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment