
@matpalm, created July 25, 2020 08:34
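import tensorflow as tf


def triplet_loss(anchor_embeddings,
                 positive_embeddings,
                 negative_embeddings,
                 margin=0.0):
    # Euclidean distance from each anchor to its positive and to its negative.
    dist_a_p = tf.norm(anchor_embeddings - positive_embeddings, axis=1)  # (B)
    dist_a_n = tf.norm(anchor_embeddings - negative_embeddings, axis=1)  # (B)
    # > 0 when the negative is not at least `margin` further away than the positive.
    constraint = dist_a_p - dist_a_n + margin  # (B)
    per_element_hinge_loss = tf.maximum(0.0, constraint)  # (B)
    # Mean hinge loss over the batch; note that with margin=0.0, embeddings that
    # all collapse to a single point already achieve zero loss.
    return tf.reduce_mean(per_element_hinge_loss)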
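To check the TensorFlow function on small hand-computed cases, here is a minimal plain-Python sketch of the same computation: the batch mean of max(0, ‖a − p‖ − ‖a − n‖ + margin). The helper name `triplet_loss_reference` and the toy vectors are illustrative assumptions, not part of the original gist; the sketch uses only the standard library (`math.dist`, Python 3.8+).

```python
import math


def triplet_loss_reference(anchors, positives, negatives, margin=0.0):
    """Hypothetical plain-Python mirror of triplet_loss, for hand-checking.

    Each argument is a list of equal-length embedding vectors (lists of floats);
    the i-th anchor, positive and negative form one triplet.
    """
    hinge_losses = []
    for a, p, n in zip(anchors, positives, negatives):
        dist_a_p = math.dist(a, p)  # Euclidean distance anchor -> positive
        dist_a_n = math.dist(a, n)  # Euclidean distance anchor -> negative
        hinge_losses.append(max(0.0, dist_a_p - dist_a_n + margin))
    # Mean over the batch, matching tf.reduce_mean in the TensorFlow version.
    return sum(hinge_losses) / len(hinge_losses)


# Toy batch of two 2-D triplets (illustrative values only).
anchors = [[0.0, 0.0], [0.0, 0.0]]
positives = [[3.0, 4.0], [1.0, 0.0]]   # anchor-positive distances 5.0 and 1.0
negatives = [[6.0, 8.0], [0.0, 2.0]]   # anchor-negative distances 10.0 and 2.0

print(triplet_loss_reference(anchors, positives, negatives))              # 0.0
print(triplet_loss_reference(anchors, positives, negatives, margin=2.0))  # 0.5
print(triplet_loss_reference(anchors, positives, negatives, margin=6.0))  # 3.0

# Collapsed embeddings: both distances are 0, so the loss equals the margin.
collapsed = [[1.0, 1.0]]
print(triplet_loss_reference(collapsed, collapsed, collapsed))               # 0.0
print(triplet_loss_reference(collapsed, collapsed, collapsed, margin=0.2))   # 0.2
```

With margin=2.0 only the second triplet violates the constraint (1.0 − 2.0 + 2.0 = 1.0), giving a batch mean of 0.5. The collapsed case shows why the margin matters: with the default margin=0.0, mapping every input to the same point is not penalised at all.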