This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def print_transformers_shape_inference(name_or_path: str): | |
"""Prints the transformers shape inference for onnx.""" | |
res = {} | |
model_pipeline = transformers.FeatureExtractionPipeline( | |
model=transformers.AutoModel.from_pretrained(name_or_path), | |
tokenizer=transformers.AutoTokenizer.from_pretrained( | |
name_or_path, use_fast=True | |
), | |
framework="pt", |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SentenceTransformer(transformers.BertModel): | |
def __init__(self, config): | |
super().__init__(config) | |
# Naming alias for ONNX output specification | |
# Makes it easier to identify the layer | |
self.sentence_embedding = torch.nn.Identity() | |
def forward(self, input_ids, token_type_ids, attention_mask): | |
# Get the token embeddings from the base model | |
token_embeddings = super().forward( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Rewrite the export arguments so the graph exposes a single output named
# `sentence_embedding` instead of the default `output_0` / `output_1`.
# NOTE(review): model_args is defined outside this snippet; presumably it is the
# kwargs dict later passed to the ONNX export call — confirm.
# `del` (not `.pop(key, None)`) is kept on purpose: a KeyError here signals that
# the export config no longer contains the expected default outputs.
del model_args["dynamic_axes"]["output_0"]  # Delete unused output
del model_args["dynamic_axes"]["output_1"]  # Delete unused output
# Axis 0 of the embedding output is the batch dimension and stays dynamic.
model_args["dynamic_axes"]["sentence_embedding"] = {0: "batch"}
model_args["output_names"] = ["sentence_embedding"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
# Download the TF-Hub Universal Sentence Encoder (large, version 5) SavedModel
# and convert it to universal-sentence-encoder-5.onnx with tf2onnx.
#
# Fail fast: without this, a failed download, extract or `cd` would let the
# remaining commands run against the wrong directory or a missing archive.
set -euo pipefail

readonly MODEL_DIR="universal-sentence-encoder-5"
readonly ARCHIVE="5.tar.gz"
readonly MODEL_URL="https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder-large/5.tar.gz"

# -p lets the script be re-run instead of aborting on an existing directory.
mkdir -p "$MODEL_DIR"

# Run the download/extract steps in a subshell so the `cd` does not leak into
# the conversion step below (replaces the original trailing `cd ..`).
(
    cd "$MODEL_DIR"
    # -O pins the output file name; otherwise a leftover archive from an
    # earlier run makes wget save to 5.tar.gz.1 and tar unpacks the stale file.
    wget -O "$ARCHIVE" "$MODEL_URL"
    tar -xvzf "$ARCHIVE"
    rm "$ARCHIVE"
)

# --extra_opset ai.onnx.contrib:1 allows tf2onnx to emit ops from the
# onnxruntime custom-ops domain; a model using them must be loaded with that
# custom-ops library registered on the session.
python -m tf2onnx.convert \
    --saved-model "$MODEL_DIR" \
    --output "$MODEL_DIR.onnx" \
    --opset 12 \
    --extra_opset ai.onnx.contrib:1 \
    --tag serve
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from onnxruntime import InferenceSession, SessionOptions
# NOTE(review): newer releases ship this package as `onnxruntime_extensions`
# (same get_library_path helper) — confirm which one is installed.
from onnxruntime_customops import get_library_path

# Use the directly imported names: the original called `rt.SessionOptions()` /
# `rt.InferenceSession(...)`, but no `rt` module alias is imported (NameError).
opt = SessionOptions()
# Register the custom-ops shared library so ops from the ai.onnx.contrib domain
# (presumably present in this converted model) can be resolved when loading.
opt.register_custom_ops_library(get_library_path())
# NOTE(review): ONNX_PROVIDERS is defined outside this snippet — confirm it
# lists the intended execution providers.
sess = InferenceSession("universal-sentence-encoder-5.onnx", opt, providers=ONNX_PROVIDERS)
sess.run( | |
output_names=["outputs"], |
Older Newer