# Source: GitHub Gist attila-dusnoki-htec/83a69e2bb33cf4602a6c501a1980bc64
# ("whisper_with_attn_mask_onnx").
# GitHub flagged that this file may contain bidirectional Unicode text that
# could be interpreted or compiled differently than it appears. Review it in
# an editor that reveals hidden Unicode characters (see GitHub's notes on
# bidirectional Unicode characters).
from typing import Dict

from optimum.exporters.onnx import main_export
from optimum.exporters.onnx.base import ConfigBehavior
from optimum.exporters.onnx.model_configs import WhisperOnnxConfig
from transformers import AutoConfig


class CustomWhisperOnnxConfig(WhisperOnnxConfig):
    """Whisper ONNX export config that adds a ``decoder_attention_mask`` input.

    Only ``inputs`` and ``torch_to_onnx_input_map`` are overridden; everything
    else is inherited from ``WhisperOnnxConfig``.
    """

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the declared ONNX input names and their dynamic axes.

        Each value maps an axis index to a symbolic dimension name
        (presumably used by optimum as ONNX dynamic axes).

        The mapping is built from scratch instead of extending the parent's,
        so no past-key-value inputs are ever declared. This config therefore
        looks suitable only for ``use_past=False`` exports — confirm before
        reusing it for a cached ("with past") decoder.
        """
        common_inputs: Dict[str, Dict[int, str]] = {}

        # Every behavior except the decoder-only export consumes the features.
        if self._behavior is not ConfigBehavior.DECODER:
            common_inputs["input_features"] = {
                0: "batch_size",
                1: "feature_size",
                2: "encoder_sequence_length",
            }

        # Every behavior except the encoder-only export takes decoder token
        # ids plus the extra decoder attention mask this config exists to add.
        if self._behavior is not ConfigBehavior.ENCODER:
            common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
            common_inputs["decoder_attention_mask"] = {0: "batch_size", 1: "decoder_sequence_length"}

        # The standalone decoder receives the encoder's output as an input.
        if self._behavior is ConfigBehavior.DECODER:
            common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}

        return common_inputs

    @property
    def torch_to_onnx_input_map(self) -> Dict[str, str]:
        """Return renames from declared input names to ONNX graph input names.

        Only the decoder behavior renames anything: ``encoder_outputs`` is
        exposed as ``encoder_hidden_states``. Every other behavior returns an
        empty mapping.
        """
        if self._behavior is ConfigBehavior.DECODER:
            return {
                # Identity entries: these names are kept unchanged.
                "decoder_input_ids": "decoder_input_ids",
                "decoder_attention_mask": "decoder_attention_mask",
                # NOTE(review): ``inputs`` never declares "attention_mask", so
                # this entry presumably has no effect unless such an input is
                # added.
                "attention_mask": "encoder_attention_mask",
                "encoder_outputs": "encoder_hidden_states",
            }
        return {}


def main(
    model_id: str = "openai/whisper-tiny.en",
    output: str = "whisper_with_attn_mask_onnx",
) -> None:
    """Export a Whisper checkpoint to ONNX with a decoder attention-mask input.

    Builds an encoder config and a no-past decoder config from
    ``CustomWhisperOnnxConfig`` and hands them to optimum's ``main_export``.

    Args:
        model_id: Hugging Face model id (or local path) of the Whisper
            checkpoint to export.
        output: Output location passed to ``main_export`` for the exported
            ONNX models.
    """
    config = AutoConfig.from_pretrained(model_id)
    custom_whisper_onnx_config = CustomWhisperOnnxConfig(
        config=config,
        task="automatic-speech-recognition",
    )

    encoder_config = custom_whisper_onnx_config.with_behavior("encoder")
    # Only a no-past decoder is exported: CustomWhisperOnnxConfig.inputs does
    # not declare any past-key-value inputs.
    decoder_config = custom_whisper_onnx_config.with_behavior("decoder", use_past=False)

    custom_onnx_configs = {
        "encoder_model": encoder_config,
        "decoder_model": decoder_config,
    }

    main_export(
        model_id,
        output=output,
        # NOTE(review): presumably skips optimum's post-processing step (e.g.
        # decoder merging), which would expect a decoder-with-past model that
        # is not exported here — confirm for the optimum version in use.
        no_post_process=True,
        custom_onnx_configs=custom_onnx_configs,
    )


# Guard the entry point so importing this module does not trigger a download
# and an ONNX export as a side effect.
if __name__ == "__main__":
    main()
# GitHub gist page footer: Sign up for free to join this conversation on
# GitHub. Already have an account? Sign in to comment.