Skip to content

Instantly share code, notes, and snippets.

@oborchers
Created April 3, 2021 13:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save oborchers/369b89fa831f27ea8d6b7a76ebf08aee to your computer and use it in GitHub Desktop.
Save oborchers/369b89fa831f27ea8d6b7a76ebf08aee to your computer and use it in GitHub Desktop.
class SentenceTransformer(transformers.BertModel):
    """BERT encoder with attention-masked mean pooling stacked on top.

    ``forward`` runs the base ``BertModel`` and takes the first element of
    its output as the token embeddings. It then averages those embeddings
    over dimension 1, counting only positions where ``attention_mask`` is
    non-zero. The pooled result is passed through ``self.sentence_embedding``
    (a no-op ``Identity`` layer), so the output has a named module to refer
    to when specifying ONNX outputs.
    """

    def __init__(self, config):
        super().__init__(config)
        # Naming alias for the ONNX output specification: a no-op layer that
        # makes the pooled sentence embedding easy to identify in the graph.
        self.sentence_embedding = torch.nn.Identity()

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Return the attention-masked mean of the base model's token embeddings.

        Args:
            input_ids: Token ids, forwarded to ``BertModel.forward``.
            token_type_ids: Segment ids, forwarded to ``BertModel.forward``.
            attention_mask: Forwarded to ``BertModel.forward`` and reused as
                the pooling weights. Assumed to be (batch, seq) with 1 for real
                tokens and 0 for padding. TODO: confirm against the tokenizer.

        Returns:
            The token embeddings summed over dim 1, with masked positions
            zeroed, divided by the number of unmasked positions (floored at
            1e-9).
        """
        # Element 0 of the base model output holds the token embeddings.
        token_embeddings = super().forward(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]

        # Broadcast the mask across the embedding dimension. The cast matters:
        # a presumably-integer mask would make the token count below an
        # integer tensor. Depending on the PyTorch version, clamping that with
        # a float floor truncates the floor to 0 (division by zero on
        # all-padding rows) or fails type checks during export.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1)
            .expand(token_embeddings.size())
            .to(token_embeddings.dtype)
        )
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Floor the count so a row with no unmasked tokens never divides by 0.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return self.sentence_embedding(sum_embeddings / sum_mask)
# Create the new model from the pretrained weights, reusing the config of the
# original pipeline's model (assumes `nlp` was built from `model_name`, so the
# weights and config agree; verify if they can differ).
# `from_pretrained` is a classmethod, so call it on the class itself. Calling
# it on an instance built with `SentenceTransformer(config=...)` would first
# construct a randomly initialised model, then discard it along with its
# config.
model = SentenceTransformer.from_pretrained(model_name, config=nlp.model.config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment