@rjurney
Created July 1, 2024 01:03
Cosine similarity adaptation of Sentence-BERT
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Assumed default: the original gist references SBERT_MODEL without defining it.
# all-MiniLM-L6-v2 produces 384-dimensional embeddings, matching dim=384 below.
SBERT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"


class CosineSentenceBERT(nn.Module):
    def __init__(self, model_name=SBERT_MODEL, dim=384):
        super().__init__()
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Project the mean-pooled token embeddings to the output embedding dimension
        self.ffnn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(0.1),
        )

    @staticmethod
    def mean_pool(token_embeds, attention_mask):
        # Average token embeddings over the sequence, ignoring padding positions
        in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
        pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
        return pool

    def encode(self, input_ids, attention_mask):
        outputs = self.model(input_ids, attention_mask=attention_mask)[0]
        embeddings = self.mean_pool(outputs, attention_mask)
        return self.ffnn(embeddings)

    def forward(self, input_ids_a, input_ids_b, attention_mask_a=None, attention_mask_b=None, labels=None):
        # Encode both sentences
        embed_a = self.encode(input_ids_a, attention_mask_a)
        embed_b = self.encode(input_ids_b, attention_mask_b)

        # Compute cosine similarity
        cosine_sim = F.cosine_similarity(embed_a, embed_b)

        loss = None
        if labels is not None:
            loss_fct = nn.CosineEmbeddingLoss()
            # CosineEmbeddingLoss expects 1 for similar pairs and -1 for dissimilar
            # pairs, so map 0/1 labels to -1/1 (cast to float for the loss)
            loss = loss_fct(embed_a, embed_b, (labels.float() * 2) - 1)

        return {"loss": loss, "similarity": cosine_sim}

    def predict(self, a: str, b: str):
        encoded_a = self.tokenizer(a, padding=True, truncation=True, return_tensors="pt")
        encoded_b = self.tokenizer(b, padding=True, truncation=True, return_tensors="pt")

        self.eval()  # disable dropout so predictions are deterministic
        with torch.no_grad():
            embed_a = self.encode(
                encoded_a["input_ids"].to(self.model.device),
                encoded_a["attention_mask"].to(self.model.device),
            )
            embed_b = self.encode(
                encoded_b["input_ids"].to(self.model.device),
                encoded_b["attention_mask"].to(self.model.device),
            )
        similarity = F.cosine_similarity(embed_a, embed_b).item()
        return similarity
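
A minimal usage sketch, assuming the SBERT_MODEL default defined above. The example sentences, label, learning rate, and optimizer choice are illustrative placeholders, not part of the original gist; they just show how predict() and the forward() loss are meant to be called.

# Sketch only: sentences, label, and hyperparameters are assumed, not from the gist.
model = CosineSentenceBERT()

# Inference: cosine similarity between two sentences, in [-1, 1]
print(model.predict("A man is eating food.", "A man is eating a meal."))

# One training step on a toy batch of one similar pair (label 1)
enc_a = model.tokenizer(["A man is eating food."], padding=True, truncation=True, return_tensors="pt")
enc_b = model.tokenizer(["A man is eating a meal."], padding=True, truncation=True, return_tensors="pt")
labels = torch.tensor([1.0])

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
model.train()
out = model(
    enc_a["input_ids"], enc_b["input_ids"],
    attention_mask_a=enc_a["attention_mask"],
    attention_mask_b=enc_b["attention_mask"],
    labels=labels,
)
out["loss"].backward()
optimizer.step()
optimizer.zero_grad()

Note that forward() returns both the loss (for training) and the raw cosine similarities (useful for evaluation), so the same call works in a Trainer-style loop or a hand-rolled one.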