Last active: July 23, 2024 16:56
-
-
Save eustlb/63eeb3eae1777cf3f0bc7a06b2623ad5 to your computer and use it in GitHub Desktop.
Reproduce generation error on dev branch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reproduction script: generate speech twice with Parler-TTS (fp16, eager
# attention) and write each result to a WAV file. The second run exposes the
# bug described in the "## 2" section.
import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import torch

# Ask torch to log graph breaks and recompilations, so that any
# compilation-related activity during generate() is visible in the output.
torch._logging.set_logs(graph_breaks=True, recompiles=True)
# Seed the global torch RNG so that the first generation is reproducible.
torch.manual_seed(0)

CUDA_DEVICE = 0
torch_device = f"cuda:{CUDA_DEVICE}"
attn_implementation = "eager"
model_name = "parler-tts/parler-tts-large-v1"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Text to be spoken.
prompt = "Hey, how are you doing today?"
# Natural-language description of the desired voice and recording conditions.
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."

# The description and the prompt are tokenized separately and moved to the
# GPU. They are later passed to generate() as `input_ids` and
# `prompt_input_ids` respectively.
tokenized_description = tokenizer(description, return_tensors="pt")
input_ids = tokenized_description.input_ids.to(torch_device)
tokenized_prompt = tokenizer(prompt, return_tensors="pt")
prompt_input_ids = tokenized_prompt.input_ids.to(torch_device)
## 1
# First run: load the model, move it to the GPU in fp16, and generate from
# the RNG state seeded at the top of the script.
torch_dtype = torch.float16
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation,
).to(torch_device, dtype=torch_dtype)
# The fp16 output is cast to fp32 before writing, because soundfile's
# write() does not accept float16 arrays.
generation = model.generate(
    input_ids=input_ids,
    prompt_input_ids=prompt_input_ids,
).to(torch.float32)
# squeeze() drops singleton dimensions (presumably the batch of 1) to leave
# a 1-D waveform for soundfile.
audio_arr = generation.cpu().numpy().squeeze()
sf.write("./out_1.wav", audio_arr, model.config.sampling_rate)
## 2
# Second run: reload the model from scratch and generate again. The RNG
# continues from whatever state the first run left behind.
# Bug being reproduced: if the re-seed line below is uncommented, the
# generated audio array is empty and sf.write raises an error.
# torch.manual_seed(0)
torch_dtype = torch.float16
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation,
).to(torch_device, dtype=torch_dtype)
# The fp16 output is cast to fp32 before writing, because soundfile's
# write() does not accept float16 arrays.
generation = model.generate(
    input_ids=input_ids,
    prompt_input_ids=prompt_input_ids,
).to(torch.float32)
# squeeze() drops singleton dimensions (presumably the batch of 1) to leave
# a 1-D waveform for soundfile.
audio_arr = generation.cpu().numpy().squeeze()
sf.write("./out_2.wav", audio_arr, model.config.sampling_rate)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.