Skip to content

Instantly share code, notes, and snippets.

@bkj
Last active August 17, 2022 16:38
Show Gist options
  • Save bkj/fd7bfbf0092a4776043a07b68a1625ed to your computer and use it in GitHub Desktop.
Save bkj/fd7bfbf0092a4776043a07b68a1625ed to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
"""
pytorch_dlib_metric_loss.py
"""
import torch
import torch.nn as nn
from torch.autograd import Variable
def dlib_metric_loss(score, target, margin=0.6, extra_margin=0.04):
    """
    `pytorch` implementation of `dlib`'s `loss_metric`:
    https://github.com/davisking/dlib/blob/0ef3b736fddad7601525278013105239237c42e5/dlib/dnn/loss.h#L932

    Pairs with equal labels are penalised when their Euclidean distance
    exceeds `margin - extra_margin`; pairs with different labels are penalised
    when their distance is below `margin + extra_margin`, but only the hardest
    (closest) negatives are used.

    Args:
        score: 2-D float tensor of embeddings, one row per sample.
        target: tensor holding one label per row of `score`; shapes `(bsz,)`
            and `(bsz, 1)` are both accepted.
        margin: distance threshold separating positive from negative pairs.
        extra_margin: slack added on both sides of `margin`.

    Returns:
        0-dim tensor attached to the autograd graph of `score`. It is zero
        when the batch contains no pair of samples sharing a label.
    """
    bsz = score.size(0)
    labels = target.reshape(-1)

    # Pairwise Euclidean distances via ||a||^2 + ||b||^2 - 2 * a.b
    sq_norm = (score ** 2).sum(1)
    gram = score.mm(score.transpose(0, 1))
    sq_dist = sq_norm.unsqueeze(0) + sq_norm.unsqueeze(1) - 2 * gram
    # relu guards against tiny negative values caused by rounding
    dist = torch.nn.functional.relu(sq_dist).sqrt()

    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    diff = ~same

    # Number of positive pairs, without the diagonal and without repetition
    n_pos = (same.sum().item() - bsz) // 2
    if n_pos <= 0:
        # Nothing to normalise by; return a zero that still supports backward()
        return score.sum() * 0.0

    # Mask selection must not be differentiated through
    dist_value = dist.detach()

    # Only negatives closer than the `n_pos`-th smallest negative entry count.
    # NOTE(review): `dist[diff]` holds every unordered negative pair twice,
    # while `n_pos` counts unordered positive pairs once -- kept as in the
    # original port; confirm against dlib.
    neg_dists = dist_value[diff]
    if neg_dists.numel() > n_pos:
        neg_thresh = neg_dists.sort()[0][n_pos]
    else:
        # Fewer negatives than the cut-off index: use all of them
        neg_thresh = float('inf')

    pos_limit = margin - extra_margin
    neg_limit = margin + extra_margin

    # Positive examples should be less than (margin - extra_margin)
    pos_viol = same & (dist_value > pos_limit)
    # Negative examples should be greater than (margin + extra_margin),
    # restricted to the hardest negative pairs
    neg_viol = diff & (dist_value < neg_limit) & (dist_value < neg_thresh)

    # Both (r, c) and (c, r) are summed, hence the factor 2 in the normaliser
    loss = (dist[pos_viol] - pos_limit).sum() + (neg_limit - dist[neg_viol]).sum()
    return loss / (2 * n_pos)
@Ashutosh1995
Copy link

Is this an exact implementation of the loss from the CVPR 2016 paper "Deep Metric Learning via Lifted Structured Feature Embedding"?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment