@eustlb
Last active July 23, 2024 16:56
Reproduce generation error on dev branch
import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import torch
torch._logging.set_logs(graph_breaks=True, recompiles=True)
torch.manual_seed(0)
CUDA_DEVICE = 0
torch_device = f"cuda:{CUDA_DEVICE}"
attn_implementation = "eager"
model_name = "parler-tts/parler-tts-large-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt = "Hey, how are you doing today?"
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."
tokenized_description = tokenizer(description, return_tensors="pt")
input_ids = tokenized_description.input_ids.to(torch_device)
tokenized_prompt = tokenizer(prompt, return_tensors="pt")
prompt_input_ids = tokenized_prompt.input_ids.to(torch_device)
## 1: first generation (expected to succeed and write out_1.wav)
torch_dtype = torch.float16
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation,
).to(torch_device, dtype=torch_dtype)
generation = model.generate(
    input_ids=input_ids,
    prompt_input_ids=prompt_input_ids,
).to(torch.float32)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("./out_1.wav", audio_arr, model.config.sampling_rate)
## 2: second generation, reusing the RNG state left by the first run
# If this stays commented out, the generated audio array will be empty and sf.write will throw an error.
# torch.manual_seed(0)
torch_dtype = torch.float16
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation,
).to(torch_device, dtype=torch_dtype)
generation = model.generate(
    input_ids=input_ids,
    prompt_input_ids=prompt_input_ids,
).to(torch.float32)
audio_arr = generation.cpu().numpy().squeeze()
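# Illustrative check (added, not part of the original gist): with the re-seeding above
# left commented out, this array is expected to be empty on the dev branch, which is
# what makes the sf.write call below fail.
print(f"generation 2: {audio_arr.size} samples")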
sf.write("./out_2.wav", audio_arr, model.config.sampling_rate)