@bkj
Last active July 5, 2019 02:07
keras lifted loss
#!/usr/bin/env python

"""
    keras_lifted_loss.py
"""

from keras import backend as K


def lifted_loss(margin=1):
    """
        Lifted loss, per "Deep Metric Learning via Lifted Structured Feature Embedding" by Song et al.
        Implemented in `keras`.
        See also the `pytorch` implementation at: https://gist.github.com/bkj/565c5e145786cfd362cffdbd8c089cf4
    """
    def f(target, score):
        # `target` holds the class label of each row of `score`; flatten in case
        # it arrives with shape (n, 1) instead of (n,)
        target = K.flatten(target)
        
        # Compute mask (-1 for different class, 1 for same class, 0 on the diagonal)
        same = K.cast(K.equal(0, target - K.reshape(target, (-1, 1))), K.floatx())
        idx  = K.arange(0, K.shape(score)[0])
        diag = K.cast(K.equal(K.reshape(idx, (-1, 1)), idx), K.floatx())
        mask = (2 * same - 1) - diag
        
        pos = K.cast(K.equal(mask, 1), K.floatx())   # same class, excluding self-pairs
        neg = K.cast(K.equal(mask, -1), K.floatx())  # different class
        
        # Compute Euclidean distance between rows of `score`
        mag  = K.sum(score ** 2, axis=-1, keepdims=True)
        dist = mag + K.transpose(mag) - 2 * K.dot(score, K.transpose(score))
        dist = K.sqrt(K.maximum(dist, 0.0))
        
        # Negative component (points from different classes should be far apart)
        l_n = K.sum(K.exp(margin - dist) * neg, axis=-1, keepdims=True)
        l_n = K.log(l_n + K.transpose(l_n))
        l_n = l_n * pos
        
        # Positive component (points from the same class should be close together)
        l_p = dist * pos
        
        # Hinge, square, and normalize by the number of positive pairs
        loss  = K.sum(K.maximum(l_n + l_p, 0.0) ** 2)
        n_pos = K.sum(pos)
        loss /= (2 * n_pos)
        
        return loss
    
    return f
# --

if __name__ == "__main__":
    import numpy as np
    
    np.random.seed(123)
    
    score  = K.constant(np.random.uniform(0, 1, (20, 3)))
    target = K.constant(np.random.choice(range(3), 20))
    
    print(K.eval(lifted_loss(1)(target, score)))
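
As a sanity check, here is a plain-numpy sketch of the same computation (my own reference, not part of the original gist); for the example inputs above it should agree with the Keras value up to floating-point precision:

# Hedged sketch: pure-numpy re-implementation of the same lifted loss,
# useful for cross-checking the Keras backend version above.
import numpy as np

def lifted_loss_np(target, score, margin=1):
    n = score.shape[0]
    
    # Same mask convention: 1 same class, -1 different class, 0 on the diagonal
    same = (target[:, None] == target[None, :]).astype(float)
    mask = 2 * same - 1 - np.eye(n)
    pos  = (mask == 1).astype(float)
    neg  = (mask == -1).astype(float)
    
    # Pairwise Euclidean distances between rows of `score`
    mag  = (score ** 2).sum(axis=-1, keepdims=True)
    dist = np.sqrt(np.maximum(mag + mag.T - 2 * score.dot(score.T), 0))
    
    # Sum of exp(margin - dist) over negatives for each row, combined per positive pair
    l_n = (np.exp(margin - dist) * neg).sum(axis=-1, keepdims=True)
    l_n = np.log(l_n + l_n.T) * pos
    
    # Distance term for positive pairs, then hinge, square, and normalize
    l_p = dist * pos
    return (np.maximum(l_n + l_p, 0) ** 2).sum() / (2 * pos.sum())

np.random.seed(123)
score  = np.random.uniform(0, 1, (20, 3))
target = np.random.choice(range(3), 20)
print(lifted_loss_np(target, score))  # should match the Keras value above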

dandanmylady commented Jul 5, 2019

@sstolpovskiy
Just realized it's been a year and a half since you asked the question. I guess you've already figured it out by now.
Lifted structured loss needs the labels because, for each sample, it computes the loss against all the other samples in the batch, and it needs the labels to build a label matrix of size (batch_size, batch_size).
Triplet loss is easier to compute because it only involves an (a, p, n) tuple: the ordering of (a, p, n) is itself a kind of label, and that is enough to compute the loss.
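
To make that concrete, here is a small sketch (mine, not from the thread; the toy labels and embeddings are made up):

# Hedged sketch of the point above, using made-up toy data.
import numpy as np

labels = np.array([0, 1, 0, 2])                    # one class label per sample in the batch

# Lifted structured loss: needs this (batch_size, batch_size) same-class matrix,
# because every sample is compared against every other sample in the batch.
label_matrix = (labels[:, None] == labels[None, :])
print(label_matrix.astype(int))

# Triplet loss: no explicit labels needed at loss time -- the convention that the
# rows arrive ordered as (anchor, positive, negative) already encodes the supervision.
anchor, positive, negative = np.random.rand(3, 4)  # three embeddings of dimension 4
d_ap = np.linalg.norm(anchor - positive)
d_an = np.linalg.norm(anchor - negative)
print(max(d_ap - d_an + 1.0, 0))                   # hinge with margin = 1.0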
