Skip to content

Instantly share code, notes, and snippets.

@peteflorence
Last active August 11, 2022 09:15
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save peteflorence/4c009e7dd5eee7b5c8caa2c9bae954d5 to your computer and use it in GitHub Desktop.
Save peteflorence/4c009e7dd5eee7b5c8caa2c9bae954d5 to your computer and use it in GitHub Desktop.
Pixelwise Contrastive Loss in PyTorch
import torch
class PixelwiseContrastiveLoss(torch.nn.Module):
    """Pixelwise contrastive loss between two dense descriptor images.

    Matched pixel pairs are pulled together with a summed squared-L2 penalty.
    Non-matched pixel pairs are pushed apart with a hinge on squared L2
    distance, ``max(0, margin - ||d_a - d_b||^2)``. The summed hinge is then
    divided by ``num_non_matches_per_match``.

    Args:
        num_non_matches_per_match: divisor applied to the summed non-match
            hinge loss. Defaults to 150, the value previously hard-coded.
        margin: hinge margin on the squared descriptor distance of
            non-matched pairs. Defaults to 0.5, the value previously
            hard-coded.
    """

    def __init__(self, num_non_matches_per_match=150, margin=0.5):
        super().__init__()
        self.num_non_matches_per_match = num_non_matches_per_match
        self.margin = margin

    def forward(self, image_a_pred, image_b_pred, matches_a, matches_b,
                non_matches_a, non_matches_b):
        """Compute the combined, match, and non-match losses.

        Args:
            image_a_pred, image_b_pred: descriptor tensors. Pixels are
                selected along dim 1 (``index_select(..., 1, ...)``) and
                descriptor channels are reduced along dim 2, so these are
                presumably shaped (batch, num_pixels, descriptor_dim).
                TODO confirm against the caller.
            matches_a, matches_b: integer (int64) index tensors into dim 1 of
                image_a_pred / image_b_pred. Element i of each pair is
                compared elementwise, so they are expected to have equal
                length.
            non_matches_a, non_matches_b: index tensors like the above, for
                pixel pairs that should NOT match.

        Returns:
            Tuple ``(loss, match_loss, non_match_loss)``. ``loss`` is a
            differentiable scalar tensor (match term + non-match term). The
            other two are Python floats holding each term's value, for
            logging.
        """
        # Match term: sum of squared differences over all selected pairs.
        matches_a_descriptors = torch.index_select(image_a_pred, 1, matches_a)
        matches_b_descriptors = torch.index_select(image_b_pred, 1, matches_b)
        match_term = (matches_a_descriptors - matches_b_descriptors).pow(2).sum()

        # Non-match term: per-pair squared distance (reduced over dim 2),
        # hinged at the margin so pairs already far apart contribute nothing.
        non_matches_a_descriptors = torch.index_select(image_a_pred, 1, non_matches_a)
        non_matches_b_descriptors = torch.index_select(image_b_pred, 1, non_matches_b)
        non_match_sq_dist = (
            non_matches_a_descriptors - non_matches_b_descriptors
        ).pow(2).sum(dim=2)
        hinge = torch.clamp(self.margin - non_match_sq_dist, min=0.0)
        non_match_term = hinge.sum() / self.num_non_matches_per_match

        loss = match_term + non_match_term
        # .item() replaces the old `.data[0]`, which raises on the 0-dim
        # tensors that `.sum()` returns in modern PyTorch.
        return loss, match_term.item(), non_match_term.item()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment