An address-matching SentenceBERT class that Claude helped me write.
import torch
from transformers import AutoModel, AutoTokenizer

# Local directory holding the fine-tuned SBERT checkpoint
SBERT_MODEL = "data/fine-tuned-sbert-paraphrase-multilingual-MiniLM-L12-v2-original/checkpoint-2400/"


class SentenceBERT(torch.nn.Module):
    def __init__(self, model_name=SBERT_MODEL, dim=384):
        super().__init__()
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Linear head over the concatenated [u, v, |u - v|] sentence features
        self.ffnn = torch.nn.Linear(dim * 3, 1)

        # Freeze the weights of the pre-trained model; only the head trains
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def mean_pool(token_embeds, attention_mask):
        in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
        pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
        return pool
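
    # Shape note (illustrative, not from the gist): for token_embeds of shape
    # (batch, seq_len, 384) and attention_mask of shape (batch, seq_len), this
    # returns one (batch, 384) embedding, averaging only non-padded positions;
    # the clamp guards against division by zero for all-padding rows.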

    def _check_similarity(self, a, b, mask_a, mask_b):
        # No gradients are needed for the frozen encoder
        with torch.no_grad():
            u = self.model(a, attention_mask=mask_a)[0]
            v = self.model(b, attention_mask=mask_b)[0]
            u = SentenceBERT.mean_pool(u, mask_a)
            v = SentenceBERT.mean_pool(v, mask_b)
        # The head runs outside no_grad so its weights can still be trained
        uv = torch.abs(u - v)
        x = torch.cat([u, v, uv], dim=-1)
        x = self.ffnn(x).float()
        return x
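
    # The (u, v, |u - v|) feature vector mirrors the classification objective
    # of the original Sentence-BERT paper (Reimers & Gurevych, 2019), with a
    # single linear layer mapping it to a match/no-match logit.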

    def check_similarity(self, a, b):
        encoded_a = self.tokenizer(a, padding=True, truncation=True, return_tensors="pt")
        encoded_b = self.tokenizer(b, padding=True, truncation=True, return_tensors="pt")
        a = encoded_a["input_ids"]
        b = encoded_b["input_ids"]
        mask_a = encoded_a["attention_mask"]
        mask_b = encoded_b["attention_mask"]
        with torch.no_grad():
            return self._check_similarity(a, b, mask_a, mask_b)

    def forward(self, input_ids_a, input_ids_b, attention_mask_a=None, attention_mask_b=None, labels=None):
        logits = self._check_similarity(input_ids_a, input_ids_b, attention_mask_a, attention_mask_b)
        loss = None
        if labels is not None:
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1), labels.float().view(-1))
        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

    def predict(self, a: str, b: str):
        with torch.no_grad():
            logits = self.check_similarity(a, b)
            probabilities = torch.sigmoid(logits)
            predicted_class = (probabilities > 0.5).long().item()
            return predicted_class, probabilities.item()
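
A minimal usage sketch: the checkpoint path baked into SBERT_MODEL is assumed to exist locally, and the address pair below is made up for illustration.

model = SentenceBERT()
model.eval()

# Hypothetical address pair; predict() returns (match_label, probability)
label, probability = model.predict(
    "123 Main St. Springfield IL 62701",
    "123 Main Street, Springfield, Illinois 62701",
)
print(f"match={label}, probability={probability:.3f}")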