Skip to content

Instantly share code, notes, and snippets.

View monatis's full-sized avatar

M. Yusuf Sarıgöz monatis

View GitHub Profile
@monatis
monatis / d6d57b09-ae18-4976-aaad-903d718c9a75.py
Created March 24, 2022 11:43
triplet_loss_-_advanced_intro3
class BatchAllTtripletLoss(nn.Module):
    """Triplet loss that draws on every valid triplet in a batch.

    Args:
        margin: The margin term of the triplet-loss equation (default 1.0).
    """

    # NOTE(review): the class name contains a doubled "T" ("Ttriplet"); it is
    # left unchanged so that existing references to it keep working.

    def __init__(self, margin=1.):
        # Explicit two-argument form of super(); equivalent to the bare call.
        super(BatchAllTtripletLoss, self).__init__()
        # Only the margin is stored here. Presumably it is read by the loss
        # computation, which is not visible in this snippet — confirm.
        self.margin = margin
@monatis
monatis / cf07badb-e3ba-47c9-bbf5-fbf12d78a558.py
Created March 24, 2022 11:43
triplet_loss_-_advanced_intro2
def get_triplet_mask(labels):
"""compute a mask for valid triplets
Args:
labels: Batch of integer labels. shape: (batch_size,)
Returns:
Mask tensor to indicate which triplets are actually valid. Shape: (batch_size, batch_size, batch_size)
A triplet is valid if:
`labels[i] == labels[j] and labels[i] != labels[k]`
@monatis
monatis / 11dcd0f7-677f-403f-bdd1-7eb3a1571455.py
Created March 24, 2022 11:43
triplet_loss_-_advanced_intro1
def euclidean_distance_matrix(x):
"""Efficient computation of Euclidean distance matrix
Args:
x: Input tensor of shape (batch_size, embedding_dim)
Returns:
Distance matrix of shape (batch_size, batch_size)
"""
# step 1 - compute the dot product
@monatis
monatis / 8fb1a5d4-ef43-4431-809b-048b7d20313d.py
Created March 24, 2022 11:43
triplet_loss_-_advanced_intro0
import torch
import torch.nn as nn
import torch.nn.functional as F
# Tiny positive constant reserved for numerical-stability tricks; the exact
# magnitude is an arbitrary choice. NOTE(review): its call sites are not
# visible here — confirm where it is applied.
eps: float = 1e-8