Created
April 3, 2021 13:30
-
-
Save oborchers/369b89fa831f27ea8d6b7a76ebf08aee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SentenceTransformer(transformers.BertModel):
    """BERT encoder with a masked mean-pooling head, intended for ONNX export.

    ``forward`` returns one vector per input sequence. Each vector is the
    average of the base model's token embeddings over the positions where
    ``attention_mask`` is nonzero.
    """

    def __init__(self, config):
        super().__init__(config)
        # Naming alias for the ONNX output specification: wrapping the pooled
        # result in a named, parameter-free module makes the final layer easy
        # to identify in the exported graph.
        self.sentence_embedding = torch.nn.Identity()

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Return masked mean-pooled sentence embeddings.

        Args:
            input_ids: Token ids forwarded to ``BertModel.forward``.
            token_type_ids: Segment ids forwarded to ``BertModel.forward``.
            attention_mask: Per-token mask, assumed to be 1 for real tokens
                and 0 for padding. Presumably shaped ``(batch, seq_len)`` --
                TODO confirm against the export inputs.

        Returns:
            The token embeddings summed over dim 1 (weighted by the mask),
            divided by the mask count (floored at 1e-9).
        """
        # Element 0 of the base model's output holds the per-token hidden states.
        token_embeddings = super().forward(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]

        # Broadcast the mask over the hidden dimension, then cast it to the
        # embeddings' dtype. Without the cast, the mask sum below could stay
        # an integer tensor, and the 1e-9 floor might not apply, leaving a
        # possible division by zero for fully masked rows.
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size())
        mask = mask.to(token_embeddings.dtype)

        summed = torch.sum(token_embeddings * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return self.sentence_embedding(summed / counts)
# Create the new model from the pretrained weights. ``from_pretrained`` is a
# classmethod, so call it on the class itself. Instantiating
# ``SentenceTransformer(config=nlp.model.config)`` first would build a
# throwaway, randomly initialised model whose config is ignored by the
# classmethod call. If the pipeline's config should override the one stored
# with ``model_name``, pass ``config=nlp.model.config`` here explicitly --
# TODO confirm which is intended.
model = SentenceTransformer.from_pretrained(model_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment