Created
May 15, 2023 15:17
-
-
Save MiniXC/319c092871adf711ce85cab691f16404 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import sys | |
from accelerate import Accelerator | |
import torch | |
from tqdm.auto import tqdm | |
from datasets import load_dataset | |
from speech_collator import SpeechCollator | |
from speech_collator.measures import EnergyMeasure, PitchMeasure, SRMRMeasure, SNRMeasure | |
from transformers import HfArgumentParser | |
import wandb | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import evaluate | |
from vocex import Vocex | |
from vocex.utils import NoamLR | |
from training.arguments import Args | |
# Maps a measure name (as passed via the comma-separated ``--measures``
# argument) to its speech_collator measure class; instantiated per name
# in main() when the SpeechCollator is built.
MEASURE_DICT = {
    "energy": EnergyMeasure,
    "pitch": PitchMeasure,
    "srmr": SRMRMeasure,
    "snr": SNRMeasure,
}
# NOTE(review): loaded at import time but never referenced anywhere in this
# file — presumably a leftover from an earlier evaluation path; confirm
# before removing.
mae_metric = evaluate.load("mae")
def eval_loop(accelerator, model, eval_ds, step, distributed=False):
    """Run one full pass over the evaluation dataloader.

    Accumulates the total loss and the per-measure compound losses, then
    restores the model to train mode.  The wandb/seaborn logging from an
    earlier revision is intentionally disabled (kept as comments below).

    Args:
        accelerator: ``accelerate.Accelerator``; only used to gather
            tensors across processes when ``distributed`` is True.
        model: the model under evaluation; must expose ``measures`` and
            support ``model(**batch, inference=True)`` returning a dict
            with ``loss``, ``compound_losses`` and ``measures`` keys.
        eval_ds: the evaluation dataloader.
        step: global training step (retained for the disabled logging).
        distributed: gather per-measure tensors from all processes first.
    """
    progress = tqdm(eval_ds, desc="Evaluating", total=len(eval_ds))
    model.eval()
    total_loss = 0.0
    loss_totals = {}
    # torch.no_grad() disables autograd graph construction during eval,
    # cutting memory use — the original evaluated with grad tracking on.
    with torch.no_grad():
        for batch_idx, batch in enumerate(progress):
            outputs = model(**batch, inference=True)
            if batch_idx == 0:
                # First batch only: gather and convert the per-measure
                # curves.  The seaborn lineplot + wandb.Image logging that
                # consumed these values is currently disabled.
                for measure in model.measures:
                    pred_vals = outputs["measures"][measure][0]
                    true_vals = batch["measures"][measure][0]
                    if distributed:
                        pred_vals, true_vals = accelerator.gather_for_metrics(
                            (pred_vals, true_vals)
                        )
                    pred_vals = pred_vals.detach().cpu().numpy()
                    true_vals = true_vals.detach().cpu().numpy()
            total_loss += outputs["loss"].item()
            for key, val in outputs["compound_losses"].items():
                loss_totals[key] = loss_totals.get(key, 0.0) + val.item()
    # wandb.log({"eval/loss": total_loss / len(eval_ds)}, step=step)
    # wandb.log({f"eval/{k}_loss": v / len(eval_ds) for k, v in loss_totals.items()}, step=step)
    model.train()
def main():
    """Train a Vocex model; config comes from a YAML file or CLI args."""
    parser = HfArgumentParser([Args])
    # A single positional ``.yaml`` argument selects config-file mode.
    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        args = parser.parse_yaml_file(sys.argv[1])[0]
    else:
        args = parser.parse_args_into_dataclasses()[0]
    args.measures = args.measures.split(",")

    wandb.init(
        name=args.wandb_run_name,
        project=args.wandb_project,
        mode=args.wandb_mode,
    )
    wandb.config.update(args)

    libritts = load_dataset(args.dataset)
    train_ds = libritts[args.train_split]
    eval_ds = libritts[args.eval_split]

    # mixed_precision=None is the Accelerator default, so this single call
    # replaces the original if/else over args.bf16.
    accelerator = Accelerator(mixed_precision="bf16" if args.bf16 else None)

    # Context managers close the mapping files promptly — the original
    # ``json.load(open(...))`` leaked both file handles.
    with open(args.speaker2idx) as f:
        speaker2idx = json.load(f)
    with open(args.phone2idx) as f:
        phone2idx = json.load(f)

    collator = SpeechCollator(
        speaker2idx=speaker2idx,
        phone2idx=phone2idx,
        measures=[MEASURE_DICT[measure]() for measure in args.measures],
        return_keys=[
            "mel",
            "dvector",
            "measures",
        ],
        overwrite_max_length=True,
    )
    model = Vocex(
        measure_nlayers=args.measure_nlayers,
        dvector_nlayers=args.dvector_nlayers,
        depthwise=args.depthwise,
        noise_factor=args.noise_factor,
        filter_size=args.filter_size,
        kernel_size=args.kernel_size,
        dropout=args.dropout,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=collator.collate_fn,
        prefetch_factor=args.prefetch_factor,
    )
    eval_dataloader = torch.utils.data.DataLoader(
        eval_ds,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=collator.collate_fn,
        prefetch_factor=args.prefetch_factor,
    )
    # Fit the model's feature scalers on (a prefix of) the training data
    # before optimization starts.
    model.fit_scalers(train_dataloader, args.fit_scalers_steps)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )
    num_epochs = args.max_epochs
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = NoamLR(
        optimizer,
        warmup_steps=args.warmup_steps,
    )
    progress_bar = tqdm(range(num_training_steps))
    train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(
        train_dataloader, eval_dataloader, model, optimizer
    )

    model.train()
    step = 0
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                step += 1
                outputs = model(**batch)
                loss = outputs["loss"]
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                if step % args.log_every == 0:
                    # Current learning rate as set by the Noam scheduler.
                    lr = lr_scheduler.get_last_lr()[0]
                    wandb.log({"train/loss": loss.item(), "lr": lr}, step=step)
                    wandb.log(
                        {f"train/{k}": v.item() for k, v in outputs["compound_losses"].items()},
                        step=step,
                    )
                    wandb.log({"train/global_step": step}, step=step)
                if step % args.eval_every == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        model.eval()
                        eval_loop(accelerator, model, eval_dataloader, step)
                        model.train()
                    accelerator.wait_for_everyone()
                if step % args.save_every == 0:
                    accelerator.wait_for_everyone()
                    # Only rank 0 writes the checkpoint — in the original,
                    # every process raced to write the same file.
                    if accelerator.is_main_process:
                        unwrapped_model = accelerator.unwrap_model(model)
                        torch.save(
                            unwrapped_model.state_dict(),
                            f"{args.checkpoint_dir}/model_{step}.pt",
                        )
                    accelerator.wait_for_everyone()
                progress_bar.update(1)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment