Skip to content

Instantly share code, notes, and snippets.

@harveyslash
Created July 20, 2017 06:25
Show Gist options
  • Save harveyslash/725fcc68df112980328951b3426c0e0b to your computer and use it in GitHub Desktop.
Save harveyslash/725fcc68df112980328951b3426c0e0b to your computer and use it in GitHub Desktop.
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    For a pair of embeddings at Euclidean distance ``d`` with label ``y``:

        loss = (1 - y) * d**2 + y * max(margin - d, 0)**2

    averaged over all pairs. ``y == 0`` marks a similar pair, whose distance
    is penalised directly (pulling the pair together); ``y == 1`` marks a
    dissimilar pair, which is only penalised while it is closer than
    ``margin``. Unlike the paper, the 1/2 factor on each term is omitted.
    """

    def __init__(self, margin=2.0):
        """
        Args:
            margin: distance at or beyond which dissimilar pairs stop
                contributing to the loss.
        """
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """
        Compute the mean contrastive loss over a batch of embedding pairs.

        Args:
            output1: embeddings of the first element of each pair.
            output2: embeddings of the second element of each pair, same
                shape as ``output1``; distances are taken over the last
                dimension (via ``F.pairwise_distance``).
            label: 0 for a similar pair, 1 for a dissimilar pair. Either one
                value per pair in any layout (e.g. ``(N,)`` or ``(N, 1)``),
                or a single value applied to every pair. Tensors, Python
                numbers and sequences are accepted; int/bool labels are
                converted to the distances' floating dtype.

        Returns:
            A 0-dim tensor holding the mean loss.
        """
        euclidean_distance = F.pairwise_distance(output1, output2)

        label = torch.as_tensor(
            label,
            dtype=euclidean_distance.dtype,
            device=euclidean_distance.device,
        )
        # pairwise_distance drops the feature dimension, so the distances have
        # shape (N,). A per-pair label shaped (N, 1) would otherwise broadcast
        # against them into an (N, N) matrix and the mean below would silently
        # average the wrong quantity. Align per-pair labels to the distance
        # shape; a single scalar label is left alone and broadcasts as before.
        if label.numel() == euclidean_distance.numel():
            label = label.reshape(euclidean_distance.shape)

        # Similar pairs (label 0): penalise any distance.
        similar_term = (1 - label) * torch.pow(euclidean_distance, 2)
        # Dissimilar pairs (label 1): penalise only while closer than margin.
        dissimilar_term = label * torch.pow(
            torch.clamp(self.margin - euclidean_distance, min=0.0), 2
        )
        loss_contrastive = torch.mean(similar_term + dissimilar_term)
        return loss_contrastive
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment