Skip to content

Instantly share code, notes, and snippets.

@Raibows
Last active November 30, 2021 02:50
Show Gist options
  • Save Raibows/7fc20ad76b38164fefd0c77795390572 to your computer and use it in GitHub Desktop.
Save Raibows/7fc20ad76b38164fefd0c77795390572 to your computer and use it in GitHub Desktop.
SimCSE loss function: PyTorch implementation
# Note: this is a copy of https://paste.ubuntu.com/p/Nx5CcSmhHn/, kept here for convenience.
import torch
import torch.nn.functional as F
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def SimCSE_loss(pred, tau=0.05):
    """Compute the SimCSE contrastive loss for a batch of paired embeddings.

    Rows of ``pred`` must be laid out as adjacent positive pairs,
    ``[a, a', b, b', ...]``: row ``2k`` and row ``2k + 1`` are each other's
    positive, and every other row in the batch acts as a negative.

    Args:
        pred: 2-D floating-point tensor of shape ``(batch, dim)`` with an
            even ``batch``. All intermediate tensors are created on
            ``pred.device``, so the input may live on any device.
        tau: Temperature that divides the cosine similarities before the
            softmax.

    Returns:
        A scalar tensor: the cross-entropy of picking each row's positive
        among all other rows, averaged over the batch.

    Raises:
        ValueError: If ``pred`` is not 2-D or its batch size is odd (some
            row would have no partner).
    """
    if pred.dim() != 2:
        raise ValueError(
            f"pred must be 2-D (batch, dim), got {pred.dim()} dimension(s)"
        )
    batch_size = pred.shape[0]
    if batch_size % 2 != 0:
        raise ValueError(
            f"batch size must be even (adjacent positive pairs), got {batch_size}"
        )

    # Allocate on the input's device rather than a module-level global, so a
    # CPU tensor on a CUDA-capable machine (or the reverse) does not fail.
    ids = torch.arange(batch_size, device=pred.device)
    # Index of each row's positive: even i -> i + 1, odd i -> i - 1.
    y_true = ids + 1 - 2 * (ids % 2)

    # Pairwise cosine similarity as a matmul of L2-normalised rows; this
    # avoids materialising a (batch, batch, dim) broadcast intermediate.
    unit = F.normalize(pred, p=2, dim=1)
    similarities = unit @ unit.t() / tau

    # Exclude each row's similarity with itself from the softmax. Using -inf
    # after temperature scaling works for every float dtype, whereas
    # subtracting a large finite constant overflows in half precision.
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=pred.device)
    similarities = similarities.masked_fill(self_mask, float('-inf'))

    # cross_entropy already averages over the batch by default.
    return F.cross_entropy(similarities, y_true)
# Sentence embeddings laid out as [A, A, B, B]: rows 0/1 are one positive
# pair and rows 2/3 are another.
# The tensor is created on `device` so the input lives on the same device as
# the tensors SimCSE_loss builds internally (avoids a CPU/CUDA mismatch).
pred = torch.tensor([[0.3, 0.2, 2.1, 3.1],
                     [0.3, 0.2, 2.1, 3.1],
                     [-1.79, -3, 2.11, 0.89],
                     [-1.79, -3, 2.11, 0.89]], device=device)
# Print the loss: a bare expression statement shows nothing when this file is
# run as a script.
print(SimCSE_loss(pred))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment