Test generation for a Parler-TTS branch across different combinations of parameters.
test_parler.py — generates one output per parameter combination:
from dataclasses import dataclass, asdict

import torch
import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import torch._dynamo.config
import torch._inductor.config
import numpy as np
import time
import os
import gc
from datetime import datetime, timedelta

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
torch._logging.set_logs(graph_breaks=True, recompiles=True)

torch.manual_seed(0)

CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
def prepare_model_inputs(
    description,
    prompt,
    description_tokenizer,
    prompt_tokenizer,
    device,
    max_length_description=50,
    max_length_prompt=50,
    pad=False,
):
    pad_args_description = {"padding": "max_length", "max_length": max_length_description} if pad else {}
    pad_args_prompt = {"padding": "max_length", "max_length": max_length_prompt} if pad else {}

    tokenized_description = description_tokenizer(description, return_tensors="pt", **pad_args_description)
    input_ids = tokenized_description.input_ids.to(device)
    attention_mask = tokenized_description.attention_mask.to(device)

    tokenized_prompt = prompt_tokenizer(prompt, return_tensors="pt", **pad_args_prompt)
    prompt_input_ids = tokenized_prompt.input_ids.to(device)
    prompt_attention_mask = tokenized_prompt.attention_mask.to(device)

    # Attention masks are only passed along when the inputs are padded.
    if pad:
        model_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompt_input_ids": prompt_input_ids,
            "prompt_attention_mask": prompt_attention_mask,
        }
    else:
        model_kwargs = {
            "input_ids": input_ids,
            "prompt_input_ids": prompt_input_ids,
        }

    return model_kwargs
def test_model_generation(model, model_kwargs):
    # Reseed before every run so generations are comparable across configurations.
    torch.manual_seed(0)
    generation = model.generate(**model_kwargs).to(torch.float32)
    audio_arr = generation.cpu().numpy().squeeze()
    return audio_arr
@dataclass
class TestConfig:
    model_name: str = None
    padding_side: str = None
    attn_implementation: str = None
    static: bool = None
    torch_dtype: str = "float32"
    device: str = "cuda:0"
    max_length_prompt: int = 20
    max_length_description: int = 40
    # No type annotations here: prompt and description are plain class
    # attributes rather than dataclass fields, so asdict() skips them when
    # the output filename is built below.
    prompt = "Hey, how are you doing today?"
    description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."
class Test:
    def __init__(self, test_config: TestConfig, output_dir_path):
        self.output_dir_path = output_dir_path
        self.test_config = test_config
        self.torch_dtype = getattr(torch, self.test_config.torch_dtype)

        self.description_tokenizer = AutoTokenizer.from_pretrained(self.test_config.model_name)
        if self.test_config.padding_side:
            self.prompt_tokenizer = AutoTokenizer.from_pretrained(
                self.test_config.model_name, padding_side=self.test_config.padding_side
            )
        else:
            self.prompt_tokenizer = AutoTokenizer.from_pretrained(self.test_config.model_name)

        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            self.test_config.model_name,
            attn_implementation=self.test_config.attn_implementation,
        ).to(self.test_config.device, dtype=self.torch_dtype)

        if self.test_config.static:
            self.model.generation_config.cache_implementation = "static"

    def test_generation(self):
        model_kwargs = prepare_model_inputs(
            self.test_config.description,
            self.test_config.prompt,
            self.description_tokenizer,
            self.prompt_tokenizer,
            self.test_config.device,
            max_length_description=self.test_config.max_length_description,
            max_length_prompt=self.test_config.max_length_prompt,
            pad=bool(self.test_config.padding_side),
        )

        # Encode the configuration into the output filename, replacing None
        # values with "no-<key>" so every field stays visible.
        conf_dict = asdict(self.test_config)
        for key, value in conf_dict.items():
            if value is None:
                conf_dict[key] = f"no-{key}"
        conf_list = [str(val) for val in conf_dict.values()]
        name = f"{'_'.join(conf_list)}.npy".replace("/", "_")
        saving_path = os.path.join(self.output_dir_path, name)

        audio_arr = test_model_generation(self.model, model_kwargs)
        np.save(saving_path, audio_arr)
        sf.write(saving_path.replace("npy", "wav"), audio_arr, self.model.config.sampling_rate)
if __name__ == "__main__":
    # create output directory
    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_dir_path = os.path.join(CURRENT_DIR, f"test-{current_datetime}")
    os.makedirs(output_dir_path)

    bench_start = time.time()

    # device for all benchmarks
    device = "cuda:2"

    # first round: original models without padding -> 12 runs
    model_names = ["parler-tts/parler_tts_mini_v0.1", "parler-tts/parler-tts-large-v1"]
    attn_implementations = ["eager", "sdpa"]
    torch_dtypes = ["float16", "float32", "bfloat16"]

    for model_name in model_names:
        model_output_dir_path = os.path.join(output_dir_path, model_name.split("/")[-1])
        os.makedirs(model_output_dir_path, exist_ok=True)
        for attn_implementation in attn_implementations:
            for torch_dtype in torch_dtypes:
                test_config = TestConfig(
                    model_name=model_name,
                    device=device,
                    attn_implementation=attn_implementation,
                    torch_dtype=torch_dtype,
                )
                bench = Test(test_config, model_output_dir_path)
                bench.test_generation()
                # free GPU memory between runs
                del bench
                gc.collect()
                torch.cuda.empty_cache()

    # second round: original models with padding -> 12 runs
    model_names = [
        ("parler-tts/parler_tts_mini_v0.1", "left"),
        ("parler-tts/parler-tts-large-v1", "right"),
    ]
    attn_implementations = ["eager", "sdpa"]
    torch_dtypes = ["float16", "float32", "bfloat16"]

    for model_name, padding_side in model_names:
        model_output_dir_path = os.path.join(output_dir_path, model_name.split("/")[-1])
        os.makedirs(model_output_dir_path, exist_ok=True)
        for attn_implementation in attn_implementations:
            for torch_dtype in torch_dtypes:
                test_config = TestConfig(
                    model_name=model_name,
                    device=device,
                    attn_implementation=attn_implementation,
                    padding_side=padding_side,
                    torch_dtype=torch_dtype,
                )
                bench = Test(test_config, model_output_dir_path)
                bench.test_generation()
                del bench
                gc.collect()
                torch.cuda.empty_cache()

    run_time = time.time() - bench_start
    elapsed_time = timedelta(seconds=run_time)
    print(f"Benchmark time: {elapsed_time}")
Comparison script — checks that matching .npy outputs from two runs are identical:
import argparse
import glob
import os

import numpy as np


def extract_config(path):
    # Keep only "<model-dir>/<filename>.npy" so outputs from the two
    # directories can be matched against each other.
    return "/".join(path.split("/")[-2:])


def main(outputs_dir_1, outputs_dir_2):
    paths1 = {extract_config(path) for path in glob.glob(f"{outputs_dir_1}/*/*.npy")}
    paths2 = {extract_config(path) for path in glob.glob(f"{outputs_dir_2}/*/*.npy")}

    pairs = paths1 & paths2
    print(f"Found {len(pairs)} matches.")

    for config in pairs:
        path1 = os.path.join(outputs_dir_1, config)
        path2 = os.path.join(outputs_dir_2, config)

        audio_array_1 = np.load(path1)
        audio_array_2 = np.load(path2)

        # Only report mismatches; identical pairs stay silent.
        arrays_equal = np.array_equal(audio_array_1, audio_array_2)
        if not arrays_equal:
            print(f"{config}: {arrays_equal}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Compare .npy files from two directories.')
    parser.add_argument('outputs_dir_1', type=str, help='First output directory')
    parser.add_argument('outputs_dir_2', type=str, help='Second output directory')
    args = parser.parse_args()

    main(args.outputs_dir_1, args.outputs_dir_2)
Will generate outputs for all combinations of attention implementation (eager, sdpa) and dtype (float16, float32, bfloat16), with and without padding.
Running the original models with and without padding produces 24 outputs per branch; the outputs from main and add-static-cache are then compared to make sure they are identical.
Since main has no reference for outputs generated with the static cache, that case is tested by listening to the wavs.
Envs
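The gist doesn't spell this section out; a minimal sketch, assuming two checkouts of huggingface/parler-tts (one on main, one on the add-static-cache branch mentioned above), each installed into its own environment with a copy of test_parler.py:
git clone https://github.com/huggingface/parler-tts.git testmain
git clone -b add-static-cache https://github.com/huggingface/parler-tts.git teststaticcache
cp test_parler.py testmain/
cp test_parler.py teststaticcache/
pip install -e testmain         # in the first environment
pip install -e teststaticcache  # in the second environment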
Usage
cd testmain
python test_parler.py

cd teststaticcache
python test_parler.py
When finished, compare outputs:
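For example, assuming the comparison script above is saved as compare_outputs.py (its filename isn't shown here), pass the two timestamped output directories:
python compare_outputs.py testmain/test-<datetime-1> teststaticcache/test-<datetime-2>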