Skip to content

Instantly share code, notes, and snippets.

@Fhrozen
Last active April 8, 2024 10:04
Show Gist options
  • Save Fhrozen/60b38bd3ee23492d28b602a0c9f92217 to your computer and use it in GitHub Desktop.
Save Fhrozen/60b38bd3ee23492d28b602a0c9f92217 to your computer and use it in GitHub Desktop.
Routine to generate an ONNX model for ESPnet 2 - Text2Speech model
#!/usr/bin/env python3
"""Convert TTS to ONNX
Using ESPnet.
Test command:
python convert_tts2onnx.py --tts-tag espnet/kan-bayashi_ljspeech_vits
"""
import argparse
import logging
import sys
import numpy as np
import torch
import time
from typing import Dict
from typing import Optional
from espnet2.bin.tts_inference import Text2Speech
from espnet2.utils.types import str_or_none
import torch.nn.functional as F
def get_parser():
    """Build the command-line parser for the TTS-to-ONNX conversion script.

    Returns:
        argparse.ArgumentParser: Parser exposing the required ``--tts-tag``
        option (available as ``args.tts_tag`` after parsing).
    """
    parser = argparse.ArgumentParser(
        description="Convert an ESPnet2 Text2Speech model to ONNX.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--tts-tag",
        required=True,
        type=str,
        help="TTS tag (or Directory) for model located at huggingface/zenodo/local",
    )
    return parser
### Add this at espnet2/gan_tts/vits/vits.py
def inference_onnx(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    durations: Optional[torch.Tensor] = None,
    noise_scale: float = 0.667,
    noise_scale_dur: float = 0.8,
    alpha: float = 1.0,
    max_len: Optional[int] = None,
    use_teacher_forcing: bool = False,
) -> torch.Tensor:
    """Run inference for ONNX export.

    Thin wrapper around ``self.generator.inference`` that returns only the
    waveform, flattened to one dimension, so it can stand in for ``forward``
    during ``torch.onnx.export``.

    Args:
        text: Input token tensor, forwarded unchanged to the generator.
        text_lengths: Lengths tensor, forwarded unchanged to the generator.
        sids: Optional speaker index; reshaped to ``(1,)`` when given.
        spembs: Optional speaker embedding, forwarded unchanged.
        lids: Optional language index; reshaped to ``(1,)`` when given.
        durations: Optional durations; reshaped to ``(1, 1, -1)`` when given
            and passed to the generator as ``dur``.
        noise_scale: Forwarded unchanged to the generator.
        noise_scale_dur: Forwarded unchanged to the generator.
        alpha: Forwarded unchanged to the generator.
        max_len: Forwarded unchanged to the generator.
        use_teacher_forcing: Must be False; teacher forcing is not supported
            on this path.

    Returns:
        The generated waveform flattened with ``view(-1)``.

    Raises:
        NotImplementedError: If ``use_teacher_forcing`` is True.
    """
    if use_teacher_forcing:
        raise NotImplementedError(
            "Teacher forcing is not supported by inference_onnx."
        )

    # The optional conditioning inputs are normalized to the fixed shapes
    # that are handed to the generator below.
    if sids is not None:
        sids = sids.view(1)
    if lids is not None:
        lids = lids.view(1)
    if durations is not None:
        durations = durations.view(1, 1, -1)

    # inference: the generator returns a 3-tuple; only the first element
    # (the waveform) is kept.
    wav, _, _ = self.generator.inference(
        text=text,
        text_lengths=text_lengths,
        sids=sids,
        spembs=spembs,
        lids=lids,
        dur=durations,
        noise_scale=noise_scale,
        noise_scale_dur=noise_scale_dur,
        alpha=alpha,
        max_len=max_len,
    )
    return wav.view(-1)
def test_onnx(
    model_path='tts_model.onnx',
    text='Hello world, how are you doing',
    preprocess_fn=None,
):
    """Smoke-test an exported ONNX TTS model with onnxruntime.

    Preprocesses ``text``, feeds it to the model under the input name
    ``'input_text'`` and logs the model's input/output names and the type of
    the result.

    Args:
        model_path: Path of the ONNX file to load.
        text: Sentence to synthesize.
        preprocess_fn: Callable invoked as ``preprocess_fn("<dummy>",
            dict(text=text))`` whose result is indexed with ``['text']``.
            When omitted, the module-level ``preprocessing`` object (set by
            the ``__main__`` block of this script) is used.

    Returns:
        The value returned by ``InferenceSession.run``.

    Raises:
        RuntimeError: If no ``preprocess_fn`` is given and no module-level
            ``preprocessing`` exists.
    """
    import onnxruntime as ort

    logging.info('Test ONNX')

    if preprocess_fn is None:
        # Backward compatibility: the no-argument call relies on the global
        # created by the script's main section.
        preprocess_fn = globals().get('preprocessing')
        if preprocess_fn is None:
            raise RuntimeError(
                "test_onnx() needs a preprocess_fn (no global 'preprocessing' found)."
            )

    tokens = preprocess_fn("<dummy>", dict(text=text))['text']
    tokens = tokens[None]  # prepend a leading axis of size 1

    # Recent onnxruntime builds that ship non-CPU providers require the
    # provider list to be explicit; CPU is sufficient for a sanity check.
    ort_sess = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    inname = [node.name for node in ort_sess.get_inputs()]
    outname = [node.name for node in ort_sess.get_outputs()]
    logging.info("inputs name: %s || outputs name: %s", inname, outname)

    outputs = ort_sess.run(None, {'input_text': tokens})
    logging.info(type(outputs))
    return outputs
if __name__ == "__main__":
    # Script entry point: load a pretrained ESPnet2 Text2Speech model,
    # synthesize a test utterance, export the TTS module to 'tts_model.onnx'
    # and run the exported model once through test_onnx().
    import types  # local import: only used for the fallback method binding below

    # Logger
    parser = get_parser()
    args = parser.parse_args()
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(filename='onnx.log', encoding='utf-8', level=logging.INFO, format=logfmt)

    # Use the GPU when one is present instead of failing on CPU-only hosts.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Pretrained model and testing wav generation
    logging.info("Preparing pretrained model from: %s", args.tts_tag)
    text2speech = Text2Speech.from_pretrained(
        model_tag=str_or_none(args.tts_tag),
        vocoder_tag=None,
        device=run_device,
        # Only for Tacotron 2 & Transformer
        threshold=0.5,
        # Only for Tacotron 2
        minlenratio=0.0,
        maxlenratio=10.0,
        use_att_constraint=False,
        backward_window=1,
        forward_window=3,
        # Only for FastSpeech & FastSpeech2 & VITS
        speed_control_alpha=1.0,
        # Only for VITS
        noise_scale=0.667,
        noise_scale_dur=0.8,
    )

    text = 'Hello world'
    logging.info("Generating test wav using the sequence: %s", text)
    with torch.no_grad():
        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments.
        start = time.perf_counter()
        wav = text2speech(text)["wav"]
        # Real-time factor: processing time divided by generated audio duration.
        rtf = (time.perf_counter() - start) / (len(wav) / text2speech.fs)
    logging.info("RTF = %.5f", rtf)

    # Prepare modules for conversion
    logging.info("Generate ONNX models")
    with torch.no_grad():
        device = text2speech.device
        # NOTE: `preprocessing` must stay a module-level name; test_onnx()
        # reads it as a global.
        preprocessing = text2speech.preprocess_fn
        model_tts = text2speech.tts

        # inference_onnx is meant to be added to espnet2/gan_tts/vits/vits.py.
        # If the loaded model does not provide it, bind this file's copy.
        if not hasattr(model_tts, "inference_onnx"):
            model_tts.inference_onnx = types.MethodType(inference_onnx, model_tts)

        # Replace forward with inference to avoid problems at ONNX generation
        model_tts.forward = model_tts.inference_onnx

        # Preprocessing data
        preproc_text = preprocessing("<dummy>", dict(text=text))['text']
        preproc_text = torch.from_numpy(preproc_text).to(device).unsqueeze(0)
        text_lengths = torch.tensor(
            [preproc_text.size(1)],
            dtype=torch.long,
            device=preproc_text.device,
        )

        # Dry run through the patched forward before exporting.
        wav = model_tts(preproc_text, text_lengths)
        logging.info(wav.shape)

        inputs = (preproc_text, text_lengths)

        # Generate TTS Model
        torch.onnx.export(
            model_tts,
            inputs,
            'tts_model.onnx',
            export_params=True,
            opset_version=13,
            do_constant_folding=True,
            verbose=True,
            input_names=['input_text'],
            output_names=['wav'],
            # Distinct symbolic names: the token axis and the waveform axis
            # are different sizes and must not share one dimension symbol.
            dynamic_axes={
                'input_text': {
                    1: 'text_length',
                },
                'wav': {
                    0: 'wav_length',
                },
            },
        )

    test_onnx()
    sys.exit(0)
@Fhrozen
Copy link
Author

Fhrozen commented Dec 31, 2021

Sorry for the late response. I am checking for solutions. I will update you once finished.

@martin3252
Copy link

@Fhrozen Any updates?

@Fhrozen
Copy link
Author

Fhrozen commented Mar 14, 2022

@unparalleled-ysj
Copy link

@Fhrozen I also encountered this error. How did you solve it?
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where_165' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:497 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 19

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment