Last active November 21, 2023 10:31
Hacky PyTorch Batch-Hard Triplet Loss and PK samplers
import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict
import math
def pdist(v):
    """Return the matrix of pairwise Euclidean (L2) distances between rows of ``v``.

    ``v`` is expected to be a 2-D ``(n, d)`` tensor; the result is ``(n, n)``
    with ``dist[i, j] == ||v[i] - v[j]||_2`` (symmetric, zero diagonal).
    """
    # (n, 1, d) - (n, d) broadcasts to (n, n, d); reduce over the feature axis.
    # Memory is O(n^2 * d), which is acceptable for typical P x K batch sizes.
    dist = torch.norm(v[:, None] - v, dim=2, p=2)
    return dist


class TripletLoss(nn.Module):
    """Batch triplet loss with either sampled or batch-hard pair mining.

    Args:
        margin: A number added to ``d(a, p) - d(a, n)`` before clamping at zero
            (hinge loss), or the string ``'soft'`` for the soft-margin variant
            ``softplus(d(a, p) - d(a, n))``.
        sample: If True, draw one positive per anchor with probability
            proportional to its distance and one negative with probability
            proportional to its inverse distance. If False, use the hardest
            (farthest) positive and hardest (closest) negative per anchor.
    """

    def __init__(self, margin=1.0, sample=True):
        super().__init__()
        self.margin = margin
        self.sample = sample

    def forward(self, inputs, targets):
        """Compute the triplet loss and monitoring metrics for one batch.

        Args:
            inputs: 2-D ``(n, d)`` tensor of embeddings.
            targets: 1-D tensor of ``n`` class labels (compared for equality).

        Returns:
            ``(loss, metrics)``: a scalar loss tensor and an ``OrderedDict`` of
            Python numbers that do not contribute to the gradient.

        In sampling mode every anchor needs at least one positive and one
        negative in the batch; otherwise ``torch.multinomial`` receives an
        all-zero weight row and raises.
        """
        n = inputs.size(0)
        dist = pdist(inputs)

        # mask_pos[i, j] is True when samples i and j share a label.
        mask_pos = targets.expand(n, n).eq(targets.expand(n, n).t())
        mask_neg = ~mask_pos
        # An anchor is never its own positive. Build the identity mask on the
        # same device as the data instead of hard-coding .cuda().
        eye = torch.eye(n, dtype=torch.bool, device=mask_pos.device)
        mask_pos = mask_pos & ~eye

        if self.sample:
            dist_p, dist_n = self._sampled_pairs(dist, mask_pos, mask_neg)
        else:
            dist_p, dist_n = self._hardest_pairs(dist, mask_pos, mask_neg)

        diff = dist_p - dist_n
        if isinstance(self.margin, str) and self.margin == 'soft':
            diff = F.softplus(diff)
        else:
            diff = torch.clamp(diff + self.margin, min=0.)
        loss = diff.mean()

        metrics = self._metrics(dist, mask_pos, dist_p, dist_n, diff)
        return loss, metrics

    @staticmethod
    def _sampled_pairs(dist, mask_pos, mask_neg):
        """Sample one positive and one negative distance per anchor (row)."""
        # Weighted sampling of positives and negatives avoids outliers causing collapse.
        # The weights only select indices; the gradient flows through the gather.
        posw = (dist.detach() + 1e-12) * mask_pos.float()
        posi = torch.multinomial(posw, 1)
        # There is likely a much better way of sampling negatives in proportion to
        # their difficulty, based on distance; this was a quick hack that ended up
        # working better for some datasets than hard negative mining.
        negw = (1 / (dist.detach() + 1e-12)) * mask_neg.float()
        negi = torch.multinomial(negw, 1)
        # Row-wise gather: result[i] = dist[i, idx[i]], shape (n,).
        dist_p = dist.gather(1, posi).squeeze(1)
        dist_n = dist.gather(1, negi).squeeze(1)
        return dist_p, dist_n

    @staticmethod
    def _hardest_pairs(dist, mask_pos, mask_neg):
        """Pick the farthest positive and the closest negative per anchor."""
        # Distances are >= 0, so zeroing non-positives lets max pick the hardest
        # positive. An anchor with no positive gets dist_p = 0.
        dist_p = torch.max(dist * mask_pos.float(), dim=1)[0]
        # Push non-negatives to +inf so argmin selects the closest negative.
        inf = torch.full_like(dist, float('inf'))
        nindex = torch.where(mask_neg, dist, inf).argmin(dim=1)
        dist_n = dist.gather(1, nindex.unsqueeze(1)).squeeze(1)
        return dist_p, dist_n

    def _metrics(self, dist, mask_pos, dist_p, dist_n, diff):
        """Compute monitoring metrics; these have no impact on the loss."""
        metrics = OrderedDict()
        with torch.no_grad():
            # Column 0 of the two smallest distances is normally the anchor
            # itself (distance 0), so column 1 is its nearest other sample.
            _, top_idx = torch.topk(dist, k=2, largest=False)
            top1_is_same = mask_pos.gather(1, top_idx[:, 1:]).squeeze(1)
            metrics['prec'] = top1_is_same.float().mean().item()
            metrics['dist_acc'] = (dist_n > dist_p).float().mean().item()
            if not isinstance(self.margin, str):
                metrics['dist_sm'] = (dist_n > dist_p + self.margin).float().mean().item()
            metrics['nonzero_count'] = torch.nonzero(diff).size(0)
            metrics['dist_p'] = dist_p.mean().item()
            metrics['dist_n'] = dist_n.mean().item()
            metrics['rel_dist'] = ((dist_n - dist_p) / torch.max(dist_p, dist_n)).mean().item()
        return metrics
import itertools

import numpy as np
import torch
from torch.utils.data.sampler import Sampler
# Both samplers are passed a data_source (typically your dataset) that has the following members:
# * label_to_samples - mapping of zero-based integer label ids to the sample indices for that label
# * __len__ - total number of samples (PKSampler uses it to size an epoch)
class PKSampler(Sampler):
    """Yield sample indices in P x K blocks: P random labels, K samples each.

    For every block, ``p`` distinct labels are drawn uniformly without
    replacement from ``range(num_labels)``. Labels are re-drawn independently
    per block, so an epoch may repeat some labels and skip others. For each
    label, ``k`` samples are drawn without replacement when the label has at
    least ``k`` samples, and with replacement otherwise.

    Args:
        data_source: Object exposing ``label_to_samples`` (zero-based label id
            -> sequence of sample indices) and ``__len__``.
        p: Number of distinct labels per block. Must not exceed the number
            of labels; ``np.random.choice`` raises ``ValueError`` otherwise.
        k: Number of samples per label.
    """

    def __init__(self, data_source, p=64, k=16):
        self.p = p
        self.k = k
        self.data_source = data_source

    def __iter__(self):
        num_labels = len(self.data_source.label_to_samples)
        pk_count = len(self) // (self.p * self.k)
        for _ in range(pk_count):
            # An int argument samples from np.arange(num_labels).
            labels = np.random.choice(num_labels, self.p, replace=False)
            for label in labels:
                indices = self.data_source.label_to_samples[label]
                # Labels with fewer than k samples fall back to drawing with replacement.
                replace = len(indices) < self.k
                yield from np.random.choice(indices, self.k, replace=replace)

    def __len__(self):
        pk = self.p * self.k
        # Round the dataset size up to a whole number of P x K blocks.
        return ((len(self.data_source) - 1) // pk + 1) * pk
def grouper(iterable, n):
    """Yield ``ceil(len(iterable) / n)`` lists of exactly ``n`` items.

    Items are taken in order, cycling back to the start to fill the final
    group. For example, ``grouper([1, 2, 3, 4, 5], 2)`` yields ``[1, 2]``,
    ``[3, 4]``, ``[5, 1]``. ``iterable`` must support ``len()``. An empty input
    yields nothing; if ``n > len(iterable)``, items repeat within the group.
    """
    # cycle() pads the last group with leading items instead of leaving it short.
    it = itertools.cycle(iterable)
    for _ in range((len(iterable) - 1) // n + 1):
        yield list(itertools.islice(it, n))


# full label coverage per 'epoch'
class PKSampler2(Sampler):
    """P x K sampler that visits every label at least once per pass.

    Labels are shuffled and split into groups of ``p`` via ``grouper``; the
    last group wraps around to the start of the shuffled order so it is full.
    For each label in a group, ``k`` samples are drawn without replacement when
    the label has at least ``k`` samples, and with replacement otherwise.

    Args:
        data_source: Object exposing ``label_to_samples`` (zero-based label id
            -> sequence of sample indices).
        p: Number of labels per block.
        k: Number of samples per label.
    """

    def __init__(self, data_source, p=64, k=16):
        self.p = p
        self.k = k
        self.data_source = data_source

    def __iter__(self):
        # Shuffled label ids 0..num_labels-1.
        rand_labels = np.random.permutation(len(self.data_source.label_to_samples))
        for labels in grouper(rand_labels, self.p):
            for label in labels:
                indices = self.data_source.label_to_samples[label]
                # Labels with fewer than k samples fall back to drawing with replacement.
                replace = len(indices) < self.k
                yield from np.random.choice(indices, self.k, replace=replace)

    def __len__(self):
        # Same label source as __iter__, so the length matches what is yielded.
        num_labels = len(self.data_source.label_to_samples)
        return ((num_labels - 1) // self.p + 1) * self.p * self.k
Hello! Thanks a lot.
How would you handle semi-hard positives in the triplet loss? I tried to get my head around the indices but got lost...

