@rjurney
Created June 24, 2024 21:19
An address-matching SentenceBERT class Claude helped me write
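The class follows the SBERT classification pattern: each address is tokenized, encoded, and mean-pooled into sentence vectors u and v, and a single linear layer scores the concatenation (u, v, |u - v|). The encoder weights are frozen, so only that linear head is trained.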
import torch
from transformers import AutoModel, AutoTokenizer

# The gist references SBERT_MODEL without defining it; this base model id is an
# assumption, inferred from the fine-tuned checkpoint path used below.
SBERT_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"


class SentenceBERT(torch.nn.Module):
    """Siamese SBERT for address matching: a frozen sentence encoder feeds a
    trainable linear head over the concatenation (u, v, |u - v|)."""

    def __init__(self, model_name=SBERT_MODEL, dim=384):
        super().__init__()
        self.model_name = model_name
        # Note: model_name is stored but not used here; the tokenizer and
        # encoder load from a local fine-tuned checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "data/fine-tuned-sbert-paraphrase-multilingual-MiniLM-L12-v2-original/checkpoint-2400/"
        )
        self.model = AutoModel.from_pretrained(
            "data/fine-tuned-sbert-paraphrase-multilingual-MiniLM-L12-v2-original/checkpoint-2400/"
        )
        # Classification head over [u; v; |u - v|], hence dim * 3 inputs.
        self.ffnn = torch.nn.Linear(dim * 3, 1)
        # Freeze the weights of the pre-trained model; only the head trains.
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def mean_pool(token_embeds, attention_mask):
        """Mean-pool token embeddings, ignoring padding positions."""
        in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
        pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
        return pool

    def _check_similarity(self, a, b, mask_a, mask_b):
        with torch.no_grad():  # ensure no gradients are computed for the frozen encoder
            u = self.model(a, attention_mask=mask_a)[0]
            v = self.model(b, attention_mask=mask_b)[0]
        # Pooling and the head stay outside no_grad so the head can train.
        u = SentenceBERT.mean_pool(u, mask_a)
        v = SentenceBERT.mean_pool(v, mask_b)
        uv = torch.abs(u - v)
        x = torch.cat([u, v, uv], dim=-1)
        x = self.ffnn(x).float()
        return x

    def check_similarity(self, a, b):
        """Tokenize two raw strings and score them (inference only)."""
        encoded_a = self.tokenizer(a, padding=True, truncation=True, return_tensors="pt")
        encoded_b = self.tokenizer(b, padding=True, truncation=True, return_tensors="pt")
        a = encoded_a["input_ids"]
        b = encoded_b["input_ids"]
        mask_a = encoded_a["attention_mask"]
        mask_b = encoded_b["attention_mask"]
        with torch.no_grad():
            return self._check_similarity(a, b, mask_a, mask_b)

    def forward(self, input_ids_a, input_ids_b, attention_mask_a=None, attention_mask_b=None, labels=None):
        logits = self._check_similarity(input_ids_a, input_ids_b, attention_mask_a, attention_mask_b)
        loss = None
        if labels is not None:
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1), labels.float().view(-1))
        # Dict output keeps the model compatible with the Hugging Face Trainer.
        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

    def predict(self, a: str, b: str):
        """Return (predicted_class, match_probability) for a pair of addresses."""
        with torch.no_grad():
            logits = self.check_similarity(a, b)
            probabilities = torch.sigmoid(logits)
            predicted_class = (probabilities > 0.5).long().item()
        return predicted_class, probabilities.item()
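A minimal inference sketch, assuming the fine-tuned checkpoint directory hard-coded in __init__ exists locally; the address strings are made up for illustration:

model = SentenceBERT()
model.eval()
label, probability = model.predict(
    "123 Main St, Springfield, IL 62701",
    "123 Main Street, Springfield, Illinois 62701",
)
print(f"match={bool(label)} p={probability:.3f}")

For training, forward() returns a {"loss", "logits"} dict when labels are supplied, the shape the Hugging Face Trainer expects. A hand-rolled step would look like this sketch (batch contents again illustrative):

enc_a = model.tokenizer(["1 First Ave"], padding=True, truncation=True, return_tensors="pt")
enc_b = model.tokenizer(["1st Avenue, Apt 1"], padding=True, truncation=True, return_tensors="pt")
out = model(
    input_ids_a=enc_a["input_ids"],
    input_ids_b=enc_b["input_ids"],
    attention_mask_a=enc_a["attention_mask"],
    attention_mask_b=enc_b["attention_mask"],
    labels=torch.tensor([1.0]),
)
out["loss"].backward()  # gradients reach only the linear head; the encoder is frozen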