@Fhrozen
Last active April 8, 2024 10:04
Routine to generate an ONNX model for ESPnet 2 - Text2Speech model
#!/usr/bin/env python3
"""Convert TTS to ONNX
Using ESPnet.
Test command:
python convert_tts2onnx.py --tts-tag espnet/kan-bayashi_ljspeech_vits
"""
import argparse
import logging
import sys
import numpy as np
import torch
import time
from typing import Dict
from typing import Optional
from espnet2.bin.tts_inference import Text2Speech
from espnet2.utils.types import str_or_none
import torch.nn.functional as F


def get_parser():
    parser = argparse.ArgumentParser(
        description="",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--tts-tag",
        required=True,
        type=str,
        help="TTS tag (or Directory) for model located at huggingface/zenodo/local"
    )
    return parser


### Add this at espnet2/gan_tts/vits/vits.py
def inference_onnx(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    durations: Optional[torch.Tensor] = None,
    noise_scale: float = 0.667,
    noise_scale_dur: float = 0.8,
    alpha: float = 1.0,
    max_len: Optional[int] = None,
    use_teacher_forcing: bool = False,
) -> torch.Tensor:
    """Run inference for ONNX.

    Returns the generated waveform as a 1-D tensor.
    """
    if sids is not None:
        sids = sids.view(1)
    if lids is not None:
        lids = lids.view(1)
    if durations is not None:
        durations = durations.view(1, 1, -1)

    # inference
    if use_teacher_forcing:
        raise NotImplementedError
    else:
        wav, _, _ = self.generator.inference(
            text=text,
            text_lengths=text_lengths,
            sids=sids,
            spembs=spembs,
            lids=lids,
            dur=durations,
            noise_scale=noise_scale,
            noise_scale_dur=noise_scale_dur,
            alpha=alpha,
            max_len=max_len,
        )
    return wav.view(-1)


def test_onnx():
    logging.info('Test ONNX')
    import onnxruntime as ort

    this_text = 'Hello world, how are you doing'
    # `preprocessing` is the module-level name assigned in the __main__ block.
    this_text = preprocessing("<dummy>", dict(text=this_text))['text']
    this_text = this_text[None]
    # this_len = np.array([this_text.shape[1]], dtype=int)

    ort_sess = ort.InferenceSession('tts_model.onnx')
    inname = [input.name for input in ort_sess.get_inputs()]
    outname = [output.name for output in ort_sess.get_outputs()]
    logging.info("inputs name: %s || outputs name: %s", inname, outname)
    outputs = ort_sess.run(None, {'input_text': this_text})
    logging.info(type(outputs))


if __name__ == "__main__":
    # Logger
    parser = get_parser()
    args = parser.parse_args()
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(filename='onnx.log', encoding='utf-8', level=logging.INFO, format=logfmt)

    # Load Pretrained model and testing wav generation
    logging.info("Preparing pretrained model from: %s", args.tts_tag)
    text2speech = Text2Speech.from_pretrained(
        model_tag=str_or_none(args.tts_tag),
        vocoder_tag=None,
        device="cuda",
        # Only for Tacotron 2 & Transformer
        threshold=0.5,
        # Only for Tacotron 2
        minlenratio=0.0,
        maxlenratio=10.0,
        use_att_constraint=False,
        backward_window=1,
        forward_window=3,
        # Only for FastSpeech & FastSpeech2 & VITS
        speed_control_alpha=1.0,
        # Only for VITS
        noise_scale=0.667,
        noise_scale_dur=0.8,
    )

    text = 'Hello world'
    logging.info("Generating test wav using the sequence: %s", text)
    with torch.no_grad():
        start = time.time()
        wav = text2speech(text)["wav"]
    rtf = (time.time() - start) / (len(wav) / text2speech.fs)
    logging.info(f"RTF = {rtf:5f}")

    # Prepare modules for conversion
    logging.info("Generate ONNX models")
    with torch.no_grad():
        device = text2speech.device
        preprocessing = text2speech.preprocess_fn
        model_tts = text2speech.tts

        # Replace forward with inference to avoid problems at ONNX generation
        model_tts.forward = model_tts.inference_onnx

        # Preprocessing data
        preproc_text = preprocessing("<dummy>", dict(text=text))['text']
        preproc_text = torch.from_numpy(preproc_text).to(device).unsqueeze(0)
        text_lengths = torch.tensor(
            [preproc_text.size(1)],
            dtype=torch.long,
            device=preproc_text.device,
        )
        wav = model_tts(preproc_text, text_lengths)
        logging.info(wav.shape)
        inputs = (preproc_text, text_lengths)

        # Generate TTS Model
        torch.onnx.export(
            model_tts,
            inputs,
            'tts_model.onnx',
            export_params=True,
            opset_version=13,
            do_constant_folding=True,
            verbose=True,
            input_names=['input_text'],
            output_names=['wav'],
            dynamic_axes={
                'input_text': {
                    1: 'length'
                },
                'wav': {
                    0: 'length'
                }
            }
        )

    test_onnx()
    sys.exit(0)
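
The script logs a real-time factor (RTF) for the test utterance: the synthesis time divided by the duration of the generated audio. A minimal sketch of the same arithmetic, using hypothetical numbers (the function name `real_time_factor` and the sample values are illustrative and not part of the gist):

```python
def real_time_factor(elapsed_s: float, num_samples: int, fs: int) -> float:
    """Synthesis time divided by the duration of the generated audio."""
    if num_samples <= 0 or fs <= 0:
        raise ValueError("num_samples and fs must be positive")
    audio_duration_s = num_samples / fs
    return elapsed_s / audio_duration_s


# Hypothetical numbers: 0.5 s to synthesize 2 s of audio at 22050 Hz.
rtf = real_time_factor(elapsed_s=0.5, num_samples=44100, fs=22050)
print(f"RTF = {rtf:.5f}")
```

An RTF below 1 means the audio is generated faster than it would take to play back.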
@Fhrozen
Author

Fhrozen commented Nov 30, 2021

A command for test:

python convert_onnx.py --tts-tag espnet/kan-bayashi_ljspeech_vits

The model is a single-speaker VITS model. The script has not been tested on other models.

Installation:

conda install pytorch=1.10.0 cudatoolkit=11.1 -c pytorch -c nvidia
pip install espnet_model_zoo
pip install onnxruntime-gpu  # for GPU
pip install onnxruntime  # for CPU use
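
Both packages above are imported under the same module name, `onnxruntime`, so it can be useful to check which execution providers the installed build exposes before running the test. A small sketch, assuming nothing beyond the `onnxruntime.get_available_providers()` call (the helper name `available_providers` is illustrative):

```python
import importlib.util
from typing import List


def available_providers() -> List[str]:
    """Return ONNX Runtime execution providers, or [] if it is not installed."""
    if importlib.util.find_spec("onnxruntime") is None:
        return []
    import onnxruntime as ort

    return list(ort.get_available_providers())


providers = available_providers()
if not providers:
    print("onnxruntime is not installed")
elif "CUDAExecutionProvider" in providers:
    print("CUDA provider available:", providers)
else:
    print("CPU-only providers:", providers)
```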

Issues:

The script can generate the ONNX model, but the model cannot be used yet, so some internal processing still needs to be checked.
I suspect the problem is that some variables are converted to Python lists or dicts during export.

@sciai-ai

@Fhrozen I am getting the same error with opset_version=11/12/13

RuntimeError: Exporting the operator _thnn_fused_lstm_cell to ONNX opset version 12 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.

@Fhrozen
Author

Fhrozen commented Nov 30, 2021

Could you tell me which pretrained model you are using?

@sciai-ai

I am using Tacotron 2, which I believe uses an LSTM.

@Fhrozen
Author

Fhrozen commented Nov 30, 2021

@sciai-ai

sciai-ai commented Dec 2, 2021

@Fhrozen Thanks for sharing these. I googled it too and found the same threads. It does not seem very straightforward to do this :(

If your ONNX conversion works with the joint model, then I might retrain on my data with it.

@Fhrozen
Author

Fhrozen commented Dec 2, 2021

This code is for the joint model (Transformer/FastSpeech + Parallel WaveGAN), but it still has the Python list/dict problems, which I will check later. It is probably easier for VITS (I have not tried it, so I cannot confirm).

@sciai-ai

sciai-ai commented Dec 7, 2021

@Fhrozen I was wondering if you managed to get it working. Thanks.

@Fhrozen
Author

Fhrozen commented Dec 7, 2021

@sciai-ai, I think it is about 50% of the way there. Many parts need to change, because some variables are converted during the torch -> ONNX export, and those conversions produce constants that later cause errors.
If you can, test the updated code and see whether you are able to fix any of the remaining issues.
The warnings I am currently getting are:

/export/db/espnet/converter/espnet/nets/pytorch_backend/nets_utils.py:154: TracerWarning: Converting a tensor to a Python list might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  lengths = lengths.tolist()
/export/db/espnet/converter/espnet/nets/pytorch_backend/transformer/embedding.py:198: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.pe.size(1) >= x.size(1) * 2 - 1:
/export/db/espnet/converter/espnet/nets/pytorch_backend/transformer/embedding.py:239: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
/export/db/espnet/converter/espnet/nets/pytorch_backend/transformer/attention.py:256: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  :, :, :, : x.size(-1) // 2 + 1
/export/db/espnet/converter/espnet/nets/pytorch_backend/transformer/attention.py:81: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
/export/db/espnet/converter/espnet/nets/pytorch_backend/transformer/attention.py:81: TracerWarning: Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
/export/db/espnet/converter/espnet2/gan_tts/vits/text_encoder.py:139: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  m, logs = stats.split(stats.size(1) // 2, dim=1)
/export/db/espnet/converter/espnet2/gan_tts/vits/flow.py:285: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  xa, xb = x.split(x.size(1) // 2, 1)
/export/db/espnet/converter/espnet2/gan_tts/vits/transform.py:118: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if torch.min(inputs) < left or torch.max(inputs) > right:
/export/db/espnet/converter/espnet2/gan_tts/vits/transform.py:123: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if min_bin_width * num_bins > 1.0:
/export/db/espnet/converter/espnet2/gan_tts/vits/transform.py:125: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if min_bin_height * num_bins > 1.0:
/export/db/espnet/converter/espnet2/gan_tts/vits/transform.py:175: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert (discriminant >= 0).all()
/export/db/espnet/converter/espnet2/gan_tts/vits/residual_coupling.py:211: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  xa, xb = x.split(x.size(1) // 2, dim=1)
/export/db/espnet/converter/espnet2/gan_tts/wavenet/residual_block.py:140: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
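
The `__floordiv__` warnings above come down to the difference between truncating and flooring a quotient, which only matters for negative values. A plain-Python sketch of that difference (no PyTorch involved; the helper names are illustrative):

```python
import math


def div_trunc(a: int, b: int) -> int:
    """Round the quotient toward zero (what the warning calls 'trunc')."""
    return math.trunc(a / b)


def div_floor(a: int, b: int) -> int:
    """Round the quotient toward negative infinity ('floor')."""
    return math.floor(a / b)


# Print dividend, divisor, truncated result, floored result.
for a, b in [(7, 2), (-7, 2)]:
    print(a, b, div_trunc(a, b), div_floor(a, b))
```

In the lines flagged above, the dividends are tensor sizes, which are never negative, so both rounding modes give the same result there.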

When you run the test_onnx method, you will have an error:

onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where_165' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:497 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 19

I will try to look into it at the weekend.
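
The ONNX Runtime message above is a broadcasting failure: two shapes can only be combined when, axis by axis, the sizes are equal or one of them is 1. A minimal pure-Python sketch of that rule follows; it is not ONNX Runtime's implementation, and `broadcast_shape` is a hypothetical helper. One plausible reading, which this thread does not confirm, is that the 8 is the token length of the sentence used during export, frozen into the graph as a constant at the TracerWarning sites listed above, while the 19 is the length of the longer sentence used in test_onnx.

```python
from itertools import zip_longest
from typing import Tuple


def broadcast_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Combine two shapes under the usual right-aligned broadcasting rule."""
    result = []
    # Compare axes from the last one backwards; missing axes count as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"cannot broadcast {x} by {y}")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


print(broadcast_shape((1, 1, 19), (1, 19, 19)))  # compatible shapes

try:
    # A mask frozen at length 8 meeting an input of length 19.
    broadcast_shape((1, 1, 8), (1, 1, 19))
except ValueError as err:
    print(err)
```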

@sciai-ai

sciai-ai commented Dec 7, 2021

Where is the test_onnx method?

@Fhrozen
Author

Fhrozen commented Dec 7, 2021

It is the test_onnx function in the script (L83-L97 of the gist).

@sciai-ai

sciai-ai commented Dec 8, 2021

@Fhrozen I am getting this error:


AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_1834/548140431.py in <module>
      5
      6     # Replace forward with inference to avoid problems at ONNX generation
----> 7     model_tts.forward = model_tts.inference_onnx
      8
      9     # Preprocessing data

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1176                 return modules[name]
   1177         raise AttributeError("'{}' object has no attribute '{}'".format(
-> 1178             type(self).__name__, name))
   1179
   1180     def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

AttributeError: 'VITS' object has no attribute 'inference_onnx'
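
This AttributeError means the installed VITS class has no inference_onnx method; the script expects that function to have been added by hand to espnet2/gan_tts/vits/vits.py. A possible alternative, sketched here with a stand-in class rather than the real ESPnet model, is to bind the function to the model instance at run time with `types.MethodType`. Whether that is enough for the export itself to succeed is not verified here.

```python
import types


class FakeVITS:
    """Stand-in for the real model; only here to keep the sketch runnable."""

    def inference(self, text):
        return text.upper()


def inference_onnx(self, text):
    """Simplified wrapper playing the role of the gist's inference_onnx."""
    return self.inference(text)


model_tts = FakeVITS()
# Bind the plain function to this one instance, so `self` is filled in.
model_tts.inference_onnx = types.MethodType(inference_onnx, model_tts)
# The line from the script that raised AttributeError now resolves.
model_tts.forward = model_tts.inference_onnx
print(model_tts.forward("hello"))
```

Binding on the instance leaves the class itself untouched, so nothing in the installed package has to be edited.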

@sciai-ai

Hi @Fhrozen did you get a chance to look into it?

@Fhrozen
Author

Fhrozen commented Dec 15, 2021

@sciai-ai, sorry, I need a little more time. I expect to implement the fixes this weekend or, at the latest, next weekend.

@sciai-ai

Hi @Fhrozen did you get a chance to work on it?

@Fhrozen
Author

Fhrozen commented Dec 31, 2021

Sorry for the late response. I am looking into solutions and will update you once I have finished.

@martin3252

@Fhrozen Any updates?

@Fhrozen
Author

Fhrozen commented Mar 14, 2022

@unparalleled-ysj

@Fhrozen I also encountered this error. How did you solve it?
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where_165' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:497 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 19
