@JulesBelveze · Created September 3, 2021 12:43
Script to export any `SentenceTransformers` model to ONNX
# Copyright (c) 2021, Hypefactors A/S
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
# following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
# disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# If you use 'ConvertibleSentenceTransformer', you may need to swap the
# parent class for the one matching your model's architecture.
#
# Examples:
# - LaBSE: transformers.BertModel
# - Distiluse: transformers.DistilBertModel
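#
# Example invocation (the model name is illustrative; any checkpoint using
# mean pooling, optionally followed by a dense layer, should work, and the
# script filename is whatever you saved this file as):
#
#   python export_sbert_onnx.py \
#       --model sentence-transformers/distiluse-base-multilingual-cased-v2 \
#       --output-name distiluse \
#       --torchscript true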
import argparse
from pathlib import Path
from pprint import pprint
import numpy as np
import torch
import transformers
from sentence_transformers import SentenceTransformer, models
from sentence_transformers.models import Pooling, Dense
from transformers import convert_graph_to_onnx
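
# NOTE: this script targets the APIs it was written against in 2021;
# 'transformers.convert_graph_to_onnx' and the 'use_external_data_format' /
# 'enable_onnx_checker' arguments of 'torch.onnx.export' have since been
# deprecated or removed in newer releases.
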
class ConvertibleSentenceTransformer(transformers.DistilBertModel):
"""
This class aims at converting manually a 'SentenceTransformer' model into a 'Transformer' one.
It turned out that directly exporting a 'SentenceTransformer' model to ONNX lead to quite
different embeddings that the ones of the original model.
NOTE: this only works for model using mean pooling.
"""
def __init__(self, config):
super().__init__(config)
# Naming alias for ONNX output specification
# Makes it easier to identify the layer
self.sentence_embedding = torch.nn.Identity()

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        # Not every architecture accepts 'token_type_ids' (e.g. DistilBERT
        # does not), so only forward it when it is actually provided
        kwargs = {"token_type_ids": token_type_ids} if token_type_ids is not None else {}
        # The first element of the model output holds the token embeddings
        token_embeddings = super().forward(input_ids, attention_mask=attention_mask, **kwargs)[0]
        # Mean pooling: average the token embeddings, masking out padding
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        pooled = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
        # Route the result through the 'sentence_embedding' alias declared above
        return self.sentence_embedding(pooled)

class ConvertibleSentenceTransformerWithDenseLayer(torch.nn.Module):
    """
    Same idea as 'ConvertibleSentenceTransformer', but for checkpoints whose
    last module is a dense projection. Directly exporting a
    'SentenceTransformer' model to ONNX turned out to produce embeddings
    noticeably different from those of the original model, hence the manual
    re-assembly of the pooling and dense layers below.

    NOTE: this only works for models using mean pooling followed by a dense layer.
    """

    def __init__(self, model_name, init_weight: torch.Tensor = None, init_bias: torch.Tensor = None):
        super().__init__()
        self.model = models.Transformer(model_name)
        self.model.auto_model.config.output_hidden_states = True
        embedding_dim = self.model.get_word_embedding_dimension()
        self.pool = Pooling(word_embedding_dimension=embedding_dim, pooling_mode_cls_token=False,
                            pooling_mode_mean_tokens=True, pooling_mode_max_tokens=False,
                            pooling_mode_mean_sqrt_len_tokens=False)
        # Read the dense dimensions off the provided weights; fall back to the
        # distiluse-style 768 -> 512 head otherwise. Note that the 'Tanh'
        # activation is likewise model-specific.
        out_features, in_features = init_weight.shape if init_weight is not None else (512, 768)
        self.dense = Dense(in_features=in_features, out_features=out_features, bias=True,
                           init_bias=init_bias, init_weight=init_weight,
                           activation_function=torch.nn.Tanh())

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        features = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:  # only present for some architectures
            features["token_type_ids"] = token_type_ids
        output = self.model(features)
        sentence_embeddings = self.pool(output)["sentence_embedding"]
        # 'Dense' returns the updated feature dict; unwrap it so that both
        # convertible classes return a plain tensor
        return self.dense({"sentence_embedding": sentence_embeddings})["sentence_embedding"]
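
# A 'SentenceTransformer' is a sequential container of modules, so printing a
# checkpoint (or indexing it, as done in 'run' below) reveals which module
# comes last, e.g. Transformer -> Pooling [-> Dense]:
#
#   print(SentenceTransformer("sentence-transformers/distiluse-base-multilingual-cased-v2"))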


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default=None, required=True)
parser.add_argument("--output-name", type=str, default=None, required=True)
parser.add_argument("--torchscript", type=lambda x: (str(x).lower() == 'true'), default=False)
return parser.parse_args()


def run(args):
model_pipeline = transformers.FeatureExtractionPipeline(
model=transformers.AutoModel.from_pretrained(args["model"]),
tokenizer=transformers.AutoTokenizer.from_pretrained(args["model"], use_fast=True),
framework="pt",
device=-1
)
tokenizer = model_pipeline.tokenizer
with torch.no_grad():
input_names, output_names, dynamic_axes, tokens = convert_graph_to_onnx.infer_shapes(
model_pipeline,
"pt"
)
ordered_input_names, model_args = convert_graph_to_onnx.ensure_valid_input(
model_pipeline.model, tokens, input_names
)
del dynamic_axes["output_0"] # Delete unused output
del dynamic_axes["output_1"] # Delete unused output
output_names = ["sentence_embedding"]
dynamic_axes["sentence_embedding"] = {0: 'batch'}
# Check that everything worked
pprint(output_names)
pprint(dynamic_axes)

        model_raw = SentenceTransformer(args["model"])
        if isinstance(model_raw[-1], Dense):
            # Re-use the weights of the original dense layer
            linear_weights = model_raw[-1].linear.weight
            linear_biases = model_raw[-1].linear.bias
            model = ConvertibleSentenceTransformerWithDenseLayer(args["model"], init_weight=linear_weights,
                                                                 init_bias=linear_biases)
        elif isinstance(model_raw[-1], Pooling):
            model = ConvertibleSentenceTransformer.from_pretrained(args["model"])
        else:
            raise NotImplementedError("Unsupported architecture: the last module must be 'Pooling' or 'Dense'.")
span = "I am a span. A short span, but nonetheless a span"
assert np.allclose(
model_raw.encode(span),
model(**tokenizer(span, return_tensors="pt"))["sentence_embedding"].squeeze().detach().numpy(),
atol=1e-6,
)
outdir = Path(args["output_name"])
output = outdir / f"{args['output_name']}.onnx"
outdir.mkdir(parents=True, exist_ok=True)
if output.exists():
print(f"Model {args['output_name']} exists. Skipping creation")
else:
print(f"Saving to {output}")
# This is essentially a copy of transformers.convert_graph_to_onnx.convert
torch.onnx.export(
model,
model_args,
f=output.as_posix(),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
            use_external_data_format=False,  # no longer accepted by newer torch releases
            enable_onnx_checker=True,  # no longer accepted by newer torch releases
opset_version=12,
)
if args["torchscript"]:
traced_model = torch.jit.trace(model, model_args)
assert np.allclose(
model_raw.encode(span),
traced_model(**tokenizer(span, return_tensors="pt")).squeeze().detach().numpy(),
atol=1e-6,
)
torch.jit.save(traced_model, f"{outdir}/traced_{args['output_name']}.pt")


if __name__ == "__main__":
args = parse_args()
run(vars(args))
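
# Example: querying the exported model with ONNX Runtime (a sketch, assuming
# 'onnxruntime' is installed and the script was run with the illustrative
# arguments from the header; re-use the tokenizer of the original checkpoint):
#
#   import onnxruntime as ort
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained(
#       "sentence-transformers/distiluse-base-multilingual-cased-v2"
#   )
#   session = ort.InferenceSession("distiluse/distiluse.onnx")
#   inputs = dict(tokenizer("I am a span", return_tensors="np"))
#   (embedding,) = session.run(["sentence_embedding"], inputs)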