Skip to content

Instantly share code, notes, and snippets.

@seanbenhur
Last active January 8, 2021 10:39
Show Gist options
  • Save seanbenhur/81ba3830d8e2e3fb53653f0c5c7918ed to your computer and use it in GitHub Desktop.
Save seanbenhur/81ba3830d8e2e3fb53653f0c5c7918ed to your computer and use it in GitHub Desktop.
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    For every pair the loss is::

        0.5 * (y * d**2 + (1 - y) * max(margin - d, 0)**2)

    where ``d`` is the Euclidean distance between ``x0`` and ``x1`` computed
    over dimension 1. Pairs labelled ``y = 1`` are pulled together; pairs
    labelled ``y = 0`` are pushed apart until they are at least ``margin``
    away. The per-pair values are summed and divided by ``x0.size(0)``.

    NOTE(review): the formula matches the contrastive loss of Hadsell,
    Chopra & LeCun (2006), with the label convention flipped (here 1 means
    "similar") -- confirm this is the intended reference.

    Args:
        margin: Distance beyond which dissimilar pairs stop contributing.
        eps: Lower bound applied to the squared distance before the square
            root. It keeps the gradient finite when ``x0 == x1``; values of
            the squared distance at or above ``eps`` are unaffected. May
            need to be raised for low-precision dtypes.
    """

    def __init__(self, margin: float = 1.0, eps: float = 1e-9) -> None:
        super().__init__()
        self.margin = margin
        self.eps = eps

    def forward(self, x0, x1, y):
        """Return the scalar contrastive loss for a batch of pairs.

        Args:
            x0: First embeddings; the distance is reduced over dimension 1.
            x1: Second embeddings, same shape as ``x0``.
            y: Pair labels (1 = similar, 0 = dissimilar). Any layout with
                one label per pair is accepted (e.g. ``(N,)`` or ``(N, 1)``);
                non-floating labels are cast to the distance dtype.

        Returns:
            A 0-dim tensor: sum of per-pair losses / 2 / ``x0.size(0)``.
        """
        diff = x0 - x1
        dist_sq = torch.sum(torch.pow(diff, 2), 1)

        # sqrt has an infinite derivative at 0, which turns into NaN during
        # backward whenever a pair is identical (0 * inf). Clamping first
        # makes the gradient exactly zero there instead.
        dist = torch.sqrt(torch.clamp(dist_sq, min=self.eps))
        margin_gap = torch.clamp(self.margin - dist, min=0.0)

        labels = torch.as_tensor(y, device=dist_sq.device)
        if not labels.is_floating_point():
            # bool/int labels: `1 - labels` is unsupported for bool tensors.
            labels = labels.to(dist_sq.dtype)
        if labels.numel() == dist_sq.numel():
            # Align e.g. (N, 1) labels with (N,) distances; otherwise the
            # product below would broadcast to an (N, N) matrix.
            labels = labels.reshape(dist_sq.shape)

        per_pair = labels * dist_sq + (1 - labels) * torch.pow(margin_gap, 2)
        return torch.sum(per_pair) / 2.0 / x0.size(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment