@eustlb
Last active August 2, 2024 12:18
Test generation of a parler-tts branch for different combinations of parameters.

test_parler.py:
from dataclasses import dataclass, asdict
import torch
import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import torch._dynamo.config
import torch._inductor.config
import numpy as np
import time
import os
import gc
from datetime import datetime, timedelta
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
torch._logging.set_logs(graph_breaks=True, recompiles=True)
torch.manual_seed(0)
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))


def prepare_model_inputs(
    description,
    prompt,
    description_tokenizer,
    prompt_tokenizer,
    device,
    max_length_description=50,
    max_length_prompt=50,
    pad=False,
):
    pad_args_description = {"padding": "max_length", "max_length": max_length_description} if pad else {}
    pad_args_prompt = {"padding": "max_length", "max_length": max_length_prompt} if pad else {}

    tokenized_description = description_tokenizer(description, return_tensors="pt", **pad_args_description)
    input_ids = tokenized_description.input_ids.to(device)
    attention_mask = tokenized_description.attention_mask.to(device)

    tokenized_prompt = prompt_tokenizer(prompt, return_tensors="pt", **pad_args_prompt)
    prompt_input_ids = tokenized_prompt.input_ids.to(device)
    prompt_attention_mask = tokenized_prompt.attention_mask.to(device)

    if pad:
        model_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompt_input_ids": prompt_input_ids,
            "prompt_attention_mask": prompt_attention_mask,
        }
    else:
        model_kwargs = {
            "input_ids": input_ids,
            "prompt_input_ids": prompt_input_ids,
        }
    return model_kwargs


def test_model_generation(
    model,
    model_kwargs,
):
    # fixed seed so runs with identical configs are bit-for-bit comparable
    torch.manual_seed(0)
    generation = model.generate(**model_kwargs).to(torch.float32)
    audio_arr = generation.cpu().numpy().squeeze()
    return audio_arr


@dataclass
class TestConfig:
    model_name: str = None
    padding_side: str = None
    attn_implementation: str = None
    static: bool = None
    torch_dtype: str = "float32"
    device: str = "cuda:0"
    max_length_prompt: int = 20
    max_length_description: int = 40
    # prompt and description carry no type annotation, so they are plain class
    # attributes rather than dataclass fields; asdict() leaves them out of the
    # generated output file names.
    prompt = "Hey, how are you doing today?"
    description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."


class Test:
    def __init__(self, test_config: TestConfig, output_dir_path):
        self.output_dir_path = output_dir_path
        self.test_config = test_config
        self.torch_dtype = getattr(torch, self.test_config.torch_dtype)

        self.description_tokenizer = AutoTokenizer.from_pretrained(self.test_config.model_name)
        if self.test_config.padding_side:
            self.prompt_tokenizer = AutoTokenizer.from_pretrained(self.test_config.model_name, padding_side=self.test_config.padding_side)
        else:
            self.prompt_tokenizer = AutoTokenizer.from_pretrained(self.test_config.model_name)

        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            self.test_config.model_name,
            attn_implementation=self.test_config.attn_implementation,
        ).to(self.test_config.device, dtype=self.torch_dtype)

        if self.test_config.static:
            self.model.generation_config.cache_implementation = "static"
    def test_generation(self):
        model_kwargs = prepare_model_inputs(
            self.test_config.description,
            self.test_config.prompt,
            self.description_tokenizer,
            self.prompt_tokenizer,
            self.test_config.device,
            max_length_description=self.test_config.max_length_description,
            max_length_prompt=self.test_config.max_length_prompt,
            pad=bool(self.test_config.padding_side),
        )

        conf_dict = asdict(self.test_config)
        for key, value in conf_dict.items():
            if value is None:
                conf_dict[key] = f"no-{key}"
        conf_list = [str(val) for val in conf_dict.values()]
        name = f"{'_'.join(conf_list)}.npy".replace("/", "_")
        saving_path = os.path.join(self.output_dir_path, name)

        audio_arr = test_model_generation(self.model, model_kwargs)
        np.save(saving_path, audio_arr)
        sf.write(saving_path.replace("npy", "wav"), audio_arr, self.model.config.sampling_rate)


if __name__ == "__main__":
    # create output directory
    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_dir_path = os.path.join(CURRENT_DIR, f"test-{current_datetime}")
    os.makedirs(output_dir_path)

    bench_start = time.time()

    # device for all benchmarks
    device = "cuda:2"

    # first round: original models without padding -> 12 runs
    model_names = ["parler-tts/parler_tts_mini_v0.1", "parler-tts/parler-tts-large-v1"]
    attn_implementations = ["eager", "sdpa"]
    torch_dtypes = ["float16", "float32", "bfloat16"]

    for model_name in model_names:
        model_output_dir_path = os.path.join(output_dir_path, model_name.split("/")[-1])
        os.makedirs(model_output_dir_path, exist_ok=True)
        for attn_implementation in attn_implementations:
            for torch_dtype in torch_dtypes:
                test_config = TestConfig(
                    model_name=model_name,
                    device=device,
                    attn_implementation=attn_implementation,
                    torch_dtype=torch_dtype,
                )
                bench = Test(test_config, model_output_dir_path)
                bench.test_generation()
                del bench
                gc.collect()
                torch.cuda.empty_cache()

    # second round: original models with padding -> 12 runs
    model_names = [
        ("parler-tts/parler_tts_mini_v0.1", "left"),
        ("parler-tts/parler-tts-large-v1", "right"),
    ]
    attn_implementations = ["eager", "sdpa"]
    torch_dtypes = ["float16", "float32", "bfloat16"]

    for model_name, padding_side in model_names:
        model_output_dir_path = os.path.join(output_dir_path, model_name.split("/")[-1])
        os.makedirs(model_output_dir_path, exist_ok=True)
        for attn_implementation in attn_implementations:
            for torch_dtype in torch_dtypes:
                test_config = TestConfig(
                    model_name=model_name,
                    device=device,
                    attn_implementation=attn_implementation,
                    padding_side=padding_side,
                    torch_dtype=torch_dtype,
                )
                bench = Test(test_config, model_output_dir_path)
                bench.test_generation()
                del bench
                gc.collect()
                torch.cuda.empty_cache()

    run_time = time.time() - bench_start
    elapsed_time = timedelta(seconds=run_time)
    print(f"Benchmark time: {elapsed_time}")

zcompare_outputs.py:

import argparse
import glob
import os

import numpy as np


def extract_config(path):
    # keep only the last two path components: "<model-dir>/<config>.npy"
    return "/".join(path.split("/")[-2:])


def main(outputs_dir_1, outputs_dir_2):
    paths1 = {extract_config(path) for path in glob.glob(f"{outputs_dir_1}/*/*.npy")}
    paths2 = {extract_config(path) for path in glob.glob(f"{outputs_dir_2}/*/*.npy")}
    pairs = paths1 & paths2
    print(f"Found {len(pairs)} matches.")

    for config in pairs:
        path1 = os.path.join(outputs_dir_1, config)
        path2 = os.path.join(outputs_dir_2, config)
        audio_array_1 = np.load(path1)
        audio_array_2 = np.load(path2)
        arrays_equal = np.array_equal(audio_array_1, audio_array_2)
        # only report mismatching configurations
        if not arrays_equal:
            print(f"{config}: {arrays_equal}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Compare .npy files from two directories.')
    parser.add_argument('outputs_dir_1', type=str, help='First output directory')
    parser.add_argument('outputs_dir_2', type=str, help='Second output directory')
    args = parser.parse_args()
    main(args.outputs_dir_1, args.outputs_dir_2)

eustlb commented Aug 2, 2024

This generates outputs for all combinations of attention implementation and dtype, with and without padding (see the run-count sketch after the list below):

  • with the main branch of parler-tts
  • with the add-static-cache branch of parler-tts
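
The run count per environment multiplies out as sketched below; the lists mirror the loops in test_parler.py:

from itertools import product

model_names = ["parler-tts/parler_tts_mini_v0.1", "parler-tts/parler-tts-large-v1"]
attn_implementations = ["eager", "sdpa"]
torch_dtypes = ["float16", "float32", "bfloat16"]
padded = [False, True]  # second round pads the prompts, first round does not

configs = list(product(model_names, attn_implementations, torch_dtypes, padded))
print(len(configs))  # 2 * 2 * 3 * 2 = 24 runs per environment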

Running the original models with and without padding creates 24 outputs per environment, which are compared afterwards to make sure the main and add-static-cache branches produce identical results.
Since there is no reference on the main branch for outputs generated with the static cache, those are tested by listening to the wavs.
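
A minimal sketch for inspecting one of those wavs before listening (the path below is only a placeholder; real names are built from the TestConfig values):

import soundfile as sf

# hypothetical path; substitute an actual generated file
audio, sampling_rate = sf.read("parler-tts-large-v1/example_config.wav")
print(f"{len(audio) / sampling_rate:.2f} s at {sampling_rate} Hz")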

Envs

conda create -n testmain python --yes
conda activate testmain
pip install git+https://github.com/huggingface/parler-tts.git
conda create -n teststaticcache python --yes
conda activate teststaticcache
pip install git+https://github.com/eustlb/parler-tts.git@add-static-cache
git clone https://gist.github.com/d0834ad7e34eed16f2a829f0fb4ff580.git testmain
git clone https://gist.github.com/d0834ad7e34eed16f2a829f0fb4ff580.git teststaticcache
git checkout test-static-cache

Usage

cd testmain
python test_parler.py
cd teststaticcache
python test_parler.py

When finished, compare outputs:

python zcompare_outputs.py outputs_dir_1 outputs_dir_2
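
For example, with the timestamped test-<datetime> directories that test_parler.py creates (timestamps here are illustrative):

python zcompare_outputs.py testmain/test-2024-08-02_10-00-00 teststaticcache/test-2024-08-02_12-00-00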
