PyTorch NCE Loss
import torch
import torch.nn as nn
import pytorch_lightning as pl


class NCE(pl.LightningModule):
    """
    This implementation is taken from
    https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py
    The mask_correlated_samples function has been modified to be much faster
    to compute, so it can be called at train time without a predefined
    batch size.
    """

    def __init__(self, temperature=0.1):
        super(NCE, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size):
        # Boolean mask of shape (2N, 2N): False on the diagonal
        # (self-similarity) and at the positions of each positive pair,
        # True everywhere else (the negatives).
        N = 2 * batch_size
        mask = torch.ones((N, N), dtype=torch.bool, device=self.device)
        mask = mask.fill_diagonal_(0)
        # Copy the zeroed diagonal of the top-left quadrant into the two
        # off-diagonal quadrants, masking out the positive pairs
        # (i, i + batch_size) and (i + batch_size, i).
        mask[:batch_size, batch_size:] = mask[:batch_size, :batch_size]
        mask[batch_size:, :batch_size] = mask[:batch_size, :batch_size]
        return mask

    def forward(self, z_i, z_j):
        """
        We do not sample negative examples explicitly.
        Instead, given a positive pair, similar to (Chen et al., 2017),
        we treat the other 2(N − 1) augmented examples within a minibatch
        as negative examples.
        """
        batch_size = z_i.shape[0]
        mask = self.mask_correlated_samples(batch_size)
        N = 2 * batch_size

        z = torch.cat((z_i, z_j), dim=0)

        # Pairwise cosine-similarity matrix of shape (N, N), scaled by temperature
        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        # Similarity between each i and its positive j, and the reverse
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)

        # N = 2 * batch_size positive pairs in total, reshaped to (N, 1)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        # N x (N - 2) negatives: everything except self and the positive pair
        negative_samples = sim[mask].reshape(N, -1)

        # The positive logit sits in column 0, so the target class is always 0
        labels = torch.zeros(N, device=positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss
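A minimal usage sketch: the loss expects two batches of embeddings, `z_i` and `z_j`, where row `k` of each is a different augmented view of the same input. The `encoder` below is a hypothetical placeholder; any module producing `(batch_size, dim)` embeddings would do.

```python
import torch
import torch.nn as nn

# Hypothetical encoder/projection head standing in for a real SimCLR backbone
encoder = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32))

nce = NCE(temperature=0.1)

x_i = torch.randn(8, 128)  # first augmented view of the batch
x_j = torch.randn(8, 128)  # second augmented view of the same batch

z_i = encoder(x_i)
z_j = encoder(x_j)

loss = nce(z_i, z_j)
loss.backward()
```

Note that because the criterion uses `reduction="sum"` and the result is divided by `N = 2 * batch_size`, the loss is averaged over all 2N anchor samples.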