Skip to content

Instantly share code, notes, and snippets.

@jobergum
Created April 23, 2021 18:50
Show Gist options
  • Save jobergum/6797a4421596b6c1a2fba76a50cdff64 to your computer and use it in GitHub Desktop.
Save jobergum/6797a4421596b6c1a2fba76a50cdff64 to your computer and use it in GitHub Desktop.
import torch
from transformers import BertPreTrainedModel
from transformers import BertModel
class SentenceEncoder(BertPreTrainedModel):
    """BERT encoder that yields one mean-pooled embedding per input sequence.

    The per-token embeddings produced by the wrapped ``BertModel`` are
    averaged over the sequence axis, weighting each position by
    ``attention_mask`` so that masked-out (zero) positions do not contribute.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask):
        """Run BERT on ``input_ids`` and mean-pool the token embeddings.

        Args:
            input_ids: token id tensor; presumably (batch, seq_len) int64 —
                TODO confirm against callers.
            attention_mask: tensor whose rank is one less than the token
                embeddings'; non-zero entries mark positions to include in
                the average, zeros are excluded.

        Returns:
            The masked mean of the token embeddings over axis 1, as float32
            (e.g. (batch, hidden_size) for 2-D inputs — assumption, confirm).
        """
        bert_out = self.bert(input_ids, attention_mask=attention_mask)
        # Element 0 of the BERT output holds the per-token embeddings.
        per_token = bert_out[0]
        return self._masked_mean(per_token, attention_mask)

    @staticmethod
    def _masked_mean(per_token, attention_mask):
        """Average ``per_token`` over axis 1, weighting by ``attention_mask``."""
        # Broadcast the mask to the embedding shape and cast to float32 so the
        # sums and the division below are done in single precision.
        weights = attention_mask.unsqueeze(-1).expand(per_token.size()).float()
        weighted_total = torch.sum(per_token * weights, 1)
        # Lower-bound the token count so an all-zero mask row does not divide
        # by zero.
        token_count = torch.clamp(weights.sum(1), min=1e-9)
        return weighted_total / token_count
# Load pretrained weights for the mean-pooling encoder from the Hugging Face hub.
s = SentenceEncoder.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens")

input_names = ["input_ids", "attention_mask"]
output_names = ["contextual"]

# Dummy inputs used only to trace the graph: batch of 1, 32 tokens (the
# intended maximum query length). Their values do not matter for export.
# Both the batch and the sequence axes are declared dynamic below, so the
# exported model is not frozen to exactly 32 tokens: callers may pass shorter
# (unpadded) sequences as well as (batch, 32) inputs.
input_ids = torch.ones(1, 32, dtype=torch.int64)
attention_mask = torch.ones(1, 32, dtype=torch.int64)
args = (input_ids, attention_mask)

torch.onnx.export(
    s,
    args=args,
    f="sentence_mean_encoder.onnx",
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        # SentenceEncoder.forward sums over axis 1 (the sequence axis), so
        # the pooled output only carries a dynamic batch axis.
        "contextual": {0: "batch"},
    },
    opset_version=11,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment