Skip to content

Instantly share code, notes, and snippets.

@attila-dusnoki-htec
Created November 28, 2023 16:04
Show Gist options
  • Save attila-dusnoki-htec/83a69e2bb33cf4602a6c501a1980bc64 to your computer and use it in GitHub Desktop.
Save attila-dusnoki-htec/83a69e2bb33cf4602a6c501a1980bc64 to your computer and use it in GitHub Desktop.
whisper_with_attn_mask_onnx — export Whisper to ONNX with an explicit decoder attention-mask input
from optimum.exporters.onnx import main_export
from optimum.exporters.onnx.model_configs import WhisperOnnxConfig
from transformers import AutoConfig
from optimum.exporters.onnx.base import ConfigBehavior
from typing import Dict
class CustomWhisperOnnxConfig(WhisperOnnxConfig):
    """Whisper ONNX export config that additionally exposes a
    ``decoder_attention_mask`` input on the decoder graph.

    Compared to the stock ``WhisperOnnxConfig``, the decoder is given an
    explicit attention-mask input so padded decoder batches can be handled
    at inference time.
    """

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Map each ONNX input name to its dynamic-axis labels.

        Returns:
            Dict of input name -> {axis index: dynamic axis name}, selected
            according to ``self._behavior`` (encoder-only, decoder-only, or
            the full model).
        """
        common_inputs: Dict[str, Dict[int, str]] = {}
        # The raw audio features feed the encoder, so they are present for
        # every behavior except a standalone decoder.
        if self._behavior is not ConfigBehavior.DECODER:
            common_inputs["input_features"] = {
                0: "batch_size",
                1: "feature_size",
                2: "encoder_sequence_length",
            }
        # Decoder-side inputs are present for every behavior except a
        # standalone encoder.
        if self._behavior is not ConfigBehavior.ENCODER:
            common_inputs["decoder_input_ids"] = {
                0: "batch_size",
                1: "decoder_sequence_length",
            }
            # Extra input relative to the base class: the decoder
            # attention mask.
            common_inputs["decoder_attention_mask"] = {
                0: "batch_size",
                1: "decoder_sequence_length",
            }
        # A standalone decoder also consumes the encoder's hidden states.
        if self._behavior is ConfigBehavior.DECODER:
            common_inputs["encoder_outputs"] = {
                0: "batch_size",
                1: "encoder_sequence_length",
            }
        return common_inputs

    @property
    def torch_to_onnx_input_map(self) -> Dict[str, str]:
        """Rename PyTorch forward() argument names to ONNX input names.

        Only the decoder graph needs remapping (e.g. the encoder result is
        passed to the decoder's ``encoder_hidden_states`` argument); other
        behaviors use the names as-is.
        """
        if self._behavior is ConfigBehavior.DECODER:
            return {
                "decoder_input_ids": "decoder_input_ids",
                "decoder_attention_mask": "decoder_attention_mask",
                "attention_mask": "encoder_attention_mask",
                "encoder_outputs": "encoder_hidden_states",
            }
        return {}
# Export openai/whisper-tiny.en to ONNX using the custom config above,
# producing separate encoder and decoder models whose decoder accepts an
# explicit attention-mask input.
model_id = "openai/whisper-tiny.en"
config = AutoConfig.from_pretrained(model_id)

custom_whisper_onnx_config = CustomWhisperOnnxConfig(
    config=config,
    task="automatic-speech-recognition",
)

# Derive per-subgraph configs: one encoder-only, one decoder-only
# (use_past=False exports the decoder without KV-cache inputs).
encoder_config = custom_whisper_onnx_config.with_behavior("encoder")
decoder_config = custom_whisper_onnx_config.with_behavior("decoder", use_past=False)

custom_onnx_configs = {
    "encoder_model": encoder_config,
    "decoder_model": decoder_config,
}

# no_post_process=True keeps the exported encoder/decoder graphs as-is
# instead of letting optimum merge or optimize them afterwards.
main_export(
    model_id,
    output="whisper_with_attn_mask_onnx",
    no_post_process=True,
    custom_onnx_configs=custom_onnx_configs,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment