@skirdey
Created November 15, 2024 09:53
voicerestore training code
import os
import argparse

import torch
from datasets import get_dataset_config_names, load_dataset
from schedulefree import AdamWScheduleFree
from torch.utils.data import ConcatDataset

from trainer import VoiceEnhancementDataset, VoiceRestoreTrainer
from voice_restore import VoiceRestore

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

print("CUDA Available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA Version:", torch.version.cuda)
    print("PyTorch CUDA Runtime Version:", torch._C._cuda_getCompiledVersion())
    print("Device Name:", torch.cuda.get_device_name(0))

def human_readable_number(num):
    for unit in ["", "thousand", "million", "billion", "trillion"]:
        if abs(num) < 1000.0:
            return f"{num:3.1f} {unit}"
        num /= 1000.0
    return f"{num:.1f} quadrillion"


def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

# Define a function to load and split datasets
def load_and_split_dataset(
    dataset_name,
    config_name=None,
    split="train",
    cache_dir=None,
    split_ratio=0.00001,
    token=None,
):
    if token:
        dataset = load_dataset(
            dataset_name,
            config_name,
            split=split,
            cache_dir=cache_dir,
            num_proc=4,
            save_infos=True,
            token=token,
            trust_remote_code=True,
            verification_mode="no_checks",
        ).train_test_split(test_size=split_ratio)
    else:
        dataset = load_dataset(
            dataset_name,
            config_name,
            split=split,
            cache_dir=cache_dir,
            num_proc=4,
            save_infos=True,
            trust_remote_code=True,
            verification_mode="no_checks",
        ).train_test_split(test_size=split_ratio)
    train_dataset = VoiceEnhancementDataset(dataset["train"])
    val_dataset = VoiceEnhancementDataset(dataset["test"])
    return train_dataset, val_dataset
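

# The full (non --quick_run) path presumably concatenates several corpora; the unused
# ConcatDataset import above suggests as much. A minimal sketch, assuming the caller
# supplies the Hugging Face dataset ids (build_train_set is an illustrative name, not
# part of the original script):
def build_train_set(dataset_names, token=None):
    parts = []
    for name in dataset_names:
        train_part, _ = load_and_split_dataset(name, split="train", token=token)
        parts.append(train_part)
    return ConcatDataset(parts)
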
if __name__ == "__main__":
    from multiprocessing import freeze_support

    freeze_support()
    print("Starting main execution")

    parser = argparse.ArgumentParser(description="Voice Restore Training Script")
    parser.add_argument(
        "--quick_run",
        action="store_true",
        help="Run with a single small dataset for quick testing",
    )
    args = parser.parse_args()

    train_dataset = None
    if args.quick_run:
        print("Running in quick mode with a single dataset")
        expresso_train, _ = load_and_split_dataset("ylacombe/expresso", split="train")
        train_dataset = expresso_train

    grad_accumulation_steps = 1
    epochs = 1000
    batch_size = 2

    # Start training
    try:
        print("Initializing trainer")
        voice_restore = VoiceRestore(
            num_channels=100,
            transformer=dict(
                dim=768,
                depth=20,
                skip_connect_type="concat",
                heads=16,
                dim_head=64,
                max_seq_len=2000,
            ),
            sigma=0.0,
        )
        total_params, trainable_params = count_parameters(voice_restore)
        print(f"Total parameters: {human_readable_number(total_params)}")
        print(f"Trainable parameters: {human_readable_number(trainable_params)}")

        optimizer = AdamWScheduleFree(
            voice_restore.parameters(), lr=3e-5, warmup_steps=5000
        )
        trainer = VoiceRestoreTrainer(
            voice_restore,
            optimizer,
            grad_accumulation_steps=grad_accumulation_steps,
            checkpoint_path="./checkpoints/voicerestore_9_28_24.pt",
            log_file="./logs.txt",
            accelerate_kwargs=dict(mixed_precision="bf16"),
        )
        print("Trainer initialized successfully")
        print("Starting training")
        trainer.train(train_dataset, epochs, batch_size, save_step=5_000)
        print("Training completed successfully")
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        import traceback

        traceback.print_exc()

---
from __future__ import annotations

import os
import random
from types import SimpleNamespace

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.checkpoint
import torchaudio
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from einops import rearrange
from ema_pytorch import EMA
from loguru import logger
from pedalboard import (
    Bitcrush,
    Chorus,
    Compressor,
    Distortion,
    GSMFullRateCompressor,
    HighpassFilter,
    LowpassFilter,
    MP3Compressor,
    Pedalboard,
    Resample,
    Reverb,
)
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchaudio.transforms import FrequencyMasking, TimeMasking
from tqdm import tqdm

from ambient_noise import (
    load_ambient_noise_data,
    mix_speech_with_ambient_noise,
)
from gen_noise import generate_acoustic_noise
from meldataset import get_mel_spectrogram

config_h = {
    "resblock": "1",
    "num_gpus": 0,
    "batch_size": 32,
    "learning_rate": 0.0001,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.9999996,
    "seed": 1234,
    "upsample_rates": [4, 4, 2, 2, 2, 2],
    "upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
    "upsample_initial_channel": 1536,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "use_tanh_at_final": False,
    "use_bias_at_final": False,
    "activation": "snakebeta",
    "snake_logscale": True,
    "use_cqtd_instead_of_mrd": True,
    "cqtd_filters": 128,
    "cqtd_max_filters": 1024,
    "cqtd_filters_scale": 1,
    "cqtd_dilations": [1, 2, 4],
    "cqtd_hop_lengths": [512, 256, 256],
    "cqtd_n_octaves": [9, 9, 9],
    "cqtd_bins_per_octaves": [24, 36, 48],
    "mpd_reshapes": [2, 3, 5, 7, 11],
    "use_spectral_norm": False,
    "discriminator_channel_mult": 1,
    "use_multiscale_melloss": True,
    "lambda_melloss": 15,
    "clip_grad_norm": 500,
    "segment_size": 65536,
    "num_mels": 100,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 256,
    "win_size": 1024,
    "sampling_rate": 24000,
    "fmin": 0,
    "fmax": None,
    "fmax_for_loss": None,
    "normalize_volume": True,
    "num_workers": 4,
    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1,
    },
}
config_h = SimpleNamespace(**config_h)
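
# Handy numbers implied by the config above: at a 24 kHz sampling rate with hop_size 256
# there are 24000 / 256 = 93.75 mel frames per second, so the dataset's default
# max_seq_len of 1024 frames corresponds to roughly 1024 * 256 / 24000 ≈ 10.9 s of audio.
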
def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(10, 4))
    im = ax.imshow(spectrogram.T, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()
    fig.canvas.draw()
    plt.close()
    return fig


def exists(v):
    return v is not None


def default(v, d):
    return v if exists(v) else d

EFFECTS = {
    "mp3lame": [
        MP3Compressor(vbr_quality=5.0),
        MP3Compressor(vbr_quality=6.0),
        MP3Compressor(vbr_quality=7.0),
        MP3Compressor(vbr_quality=8.0),
        MP3Compressor(vbr_quality=9.0),
    ],
    "reverb": [
        Reverb(
            room_size=0.5,
            damping=0.5,
            wet_level=0.33,
            dry_level=0.4,
            width=1.0,
            freeze_mode=0.0,
        ),
        Reverb(
            room_size=0.2,
            damping=0.9,
            wet_level=0.66,
            dry_level=0.2,
            width=0.5,
            freeze_mode=0.0,
        ),
        Reverb(
            room_size=0.8,
            damping=0.2,
            wet_level=0.6,
            dry_level=0.4,
            width=1.0,
            freeze_mode=0.0,
        ),
        Reverb(room_size=0.1),
        Reverb(room_size=0.2),
        Reverb(room_size=0.3),
        Reverb(room_size=0.4),
        Reverb(room_size=0.5),
    ],
    "bitcrush": [
        Bitcrush(bit_depth=4.0),
        Bitcrush(bit_depth=8.0),
        Bitcrush(bit_depth=16.0),
    ],
    "compressor": [
        Compressor(threshold_db=-20, ratio=3),
        GSMFullRateCompressor(),
        GSMFullRateCompressor(quality=Resample.Quality(0)),
        GSMFullRateCompressor(quality=Resample.Quality(1)),
        GSMFullRateCompressor(quality=Resample.Quality(2)),
        GSMFullRateCompressor(quality=Resample.Quality(3)),
        GSMFullRateCompressor(quality=Resample.Quality(4)),
    ],
    "resample": [
        Resample(target_sample_rate=8000),
        Resample(target_sample_rate=12000),
    ],
    "distortion": [
        Distortion(drive_db=100.0),
        Distortion(drive_db=75.0),
        Distortion(drive_db=50.0),
        Distortion(drive_db=25.0),
        Distortion(drive_db=5.0),
        Distortion(drive_db=10.0),
        Distortion(drive_db=1.0),
    ],
    "filter": [
        LowpassFilter(cutoff_frequency_hz=3000.0),
        LowpassFilter(cutoff_frequency_hz=4000.0),
        LowpassFilter(cutoff_frequency_hz=5000.0),
        LowpassFilter(cutoff_frequency_hz=6000.0),
        HighpassFilter(cutoff_frequency_hz=500.0),
        HighpassFilter(cutoff_frequency_hz=200.0),
    ],
    "chorus": [
        Chorus(rate_hz=1.0, depth=0.5, centre_delay_ms=7.0, feedback=0.0, mix=0.5)
    ],
    "40s": [
        LowpassFilter(cutoff_frequency_hz=4000.0),
        HighpassFilter(cutoff_frequency_hz=400.0),
        Resample(target_sample_rate=8000),
    ],
    "50s": [
        LowpassFilter(cutoff_frequency_hz=8000.0),
        HighpassFilter(cutoff_frequency_hz=200.0),
        Resample(target_sample_rate=8000),
    ],
    "mp3": [
        HighpassFilter(cutoff_frequency_hz=120.0),
        LowpassFilter(cutoff_frequency_hz=10_000.0),
        Compressor(threshold_db=-10, ratio=4),
        Bitcrush(bit_depth=16.0),
        Resample(target_sample_rate=8000),
    ],
}

ambient_noise_data = load_ambient_noise_data()


def apply_vst_effects(
    audio: np.ndarray,
    sample_rate: int = 24000,
    min_effects: int = 0,
    max_effects: int = 2,
) -> np.ndarray:
    try:
        # Randomly select a number of effect categories to apply (between min_effects and max_effects)
        num_effects = random.randint(min_effects, max_effects)
        if num_effects == 0:
            return audio
        chosen_categories = random.sample(list(EFFECTS.keys()), num_effects)
        # Randomly pick one effect from each chosen category
        chosen_effects = [
            random.choice(EFFECTS[category]) for category in chosen_categories
        ]
        # Apply the chosen effects to the audio
        board = Pedalboard(chosen_effects)
        effected_audio = board(audio, sample_rate)
    except Exception as e:
        logger.error(f"Error in apply_vst_effects: {str(e)}")
        return audio
    return effected_audio
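

# A small, hedged smoke test for apply_vst_effects; the helper name and the one-second
# 220 Hz test tone below are made up for illustration and are not part of the pipeline.
def _demo_apply_vst_effects(sr: int = 24000) -> np.ndarray:
    t = np.linspace(0.0, 1.0, sr, endpoint=False, dtype=np.float32)
    clean = 0.5 * np.sin(2.0 * np.pi * 220.0 * t)  # one second of a 220 Hz tone
    degraded = apply_vst_effects(clean, sample_rate=sr, min_effects=1, max_effects=2)
    logger.debug(f"demo degradation: in={clean.shape}, out={degraded.shape}")
    return degraded

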
def collate_fn(batch):
    # Filter out None values and ensure all items are dictionaries
    batch = [item for item in batch if item is not None and isinstance(item, dict)]
    if len(batch) == 0:
        return None

    # Separate the data into different lists
    original_mels = []
    processed_mels = []
    mel_lengths = []
    for sample in batch:
        if (
            "original_mel" in sample
            and "processed_mel" in sample
            and "mel_length" in sample
        ):
            original_mels.append(sample["original_mel"])
            processed_mels.append(sample["processed_mel"])
            mel_lengths.append(sample["mel_length"])
    if len(original_mels) == 0:
        return None

    # Find max length in this batch
    max_len = max(mel_lengths)

    # Pad sequences to max length
    original_mels_padded = torch.nn.utils.rnn.pad_sequence(
        [mel.transpose(0, 1) for mel in original_mels],
        batch_first=True,
        padding_value=0,
    ).transpose(1, 2)
    processed_mels_padded = torch.nn.utils.rnn.pad_sequence(
        [mel.transpose(0, 1) for mel in processed_mels],
        batch_first=True,
        padding_value=0,
    ).transpose(1, 2)

    # Create mask for padded regions
    mask = torch.arange(max_len).expand(len(mel_lengths), max_len) < torch.tensor(
        mel_lengths
    ).unsqueeze(1)

    return {
        "original_mel": original_mels_padded,
        "processed_mel": processed_mels_padded,
        "mel_length": torch.tensor(mel_lengths),
        "mask": mask,
    }
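

# A hedged illustration of what collate_fn returns: two fake samples with different
# frame counts (the 50- and 80-frame shapes are made-up values) get padded to a common
# length, and a boolean mask marks the real, non-padded frames.
def _demo_collate_fn(num_mels: int = 100) -> dict:
    fake_batch = [
        {
            "original_mel": torch.randn(num_mels, 50),
            "processed_mel": torch.randn(num_mels, 50),
            "mel_length": 50,
        },
        {
            "original_mel": torch.randn(num_mels, 80),
            "processed_mel": torch.randn(num_mels, 80),
            "mel_length": 80,
        },
    ]
    out = collate_fn(fake_batch)
    # original_mel / processed_mel -> (2, num_mels, 80); mask -> (2, 80) with 50 and 80 True entries
    return out

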
class VoiceEnhancementDataset(Dataset):
    def __init__(
        self,
        hf_dataset,
        target_sample_rate=24000,
        effects_function=None,
        downsample_rate=8000,
        max_seq_len=1024,
    ):
        self.data = hf_dataset
        self.target_sample_rate = target_sample_rate
        self.downsample_rate = downsample_rate
        self.effects_function = effects_function or apply_vst_effects
        # Initialize time and frequency masking
        self.time_mask = TimeMasking(time_mask_param=200)
        self.freq_mask = FrequencyMasking(freq_mask_param=60)
        self.max_seq_len = max_seq_len  # Set max_seq_len here

    def __len__(self):
        return len(self.data)

    def get_ambient_noise(self):
        return next(iter(ambient_noise_data))

    def process_audio(self, audio, sample_rate):
        if sample_rate != self.target_sample_rate:
            audio = torchaudio.functional.resample(
                audio, sample_rate, self.target_sample_rate
            )
        return audio

    def handle_nans(self, tensor, name):
        if torch.isnan(tensor).any():
            logger.warning(f"NaN detected in {name}. Replacing with zeros.")
            tensor = torch.nan_to_num(tensor, nan=0.0)
        return tensor

    def generate_non_affected_sample(self, processed_audio, original_mel):
        return {
            "original_mel": original_mel,
            "processed_mel": original_mel,
            "mel_length": original_mel.shape[-1],
        }

    def generate_4khz_downsample(self, processed_audio, original_mel):
        # Simulate narrow-band (4 kHz) audio: resample down to 4 kHz and back up to the
        # target rate, so the length is preserved but the high frequencies are lost.
        processed_audio = torchaudio.functional.resample(
            processed_audio, self.target_sample_rate, 4000
        )
        processed_audio = torchaudio.functional.resample(
            processed_audio, 4000, self.target_sample_rate
        )
        processed_audio = torch.clamp(processed_audio, min=-1.0, max=1.0)
        processed_audio = self.handle_nans(processed_audio, "processed_audio")
        processed_mel = get_mel_spectrogram(
            processed_audio.unsqueeze(0), config_h
        ).squeeze(0)
        processed_mel = self.handle_nans(processed_mel, "processed_mel")
        return {
            "original_mel": original_mel,
            "processed_mel": processed_mel,
            "mel_length": original_mel.shape[-1],
        }

    def generate_5khz_downsample(self, processed_audio, original_mel):
        # Same idea as above, with a 5 kHz intermediate rate.
        processed_audio = torchaudio.functional.resample(
            processed_audio, self.target_sample_rate, 5000
        )
        processed_audio = torchaudio.functional.resample(
            processed_audio, 5000, self.target_sample_rate
        )
        processed_audio = torch.clamp(processed_audio, min=-1.0, max=1.0)
        processed_audio = self.handle_nans(processed_audio, "processed_audio")
        processed_mel = get_mel_spectrogram(
            processed_audio.unsqueeze(0), config_h
        ).squeeze(0)
        processed_mel = self.handle_nans(processed_mel, "processed_mel")
        return {
            "original_mel": original_mel,
            "processed_mel": processed_mel,
            "mel_length": original_mel.shape[-1],
        }

    def __getitem__(self, index):
        try:
            row = self.data[index]
            audio = torch.tensor(row["audio"]["array"], dtype=torch.float32)
            sample_rate = row["audio"]["sampling_rate"]
            processed_audio = self.process_audio(audio, sample_rate)

            # Truncate the audio to max_seq_len before further processing
            max_audio_len = (
                self.max_seq_len * config_h.hop_size
            )  # Convert max_seq_len to a number of audio samples
            if len(processed_audio) > max_audio_len:
                processed_audio = processed_audio[:max_audio_len]

            processed_audio = self.handle_nans(processed_audio, "processed_audio")
            original_mel = get_mel_spectrogram(
                processed_audio.unsqueeze(0), config_h
            ).squeeze(0)
            original_mel = self.handle_nans(original_mel, "original_mel")

            # List of all possible degradation methods
            degradation_methods = [
                self.generate_non_affected_sample,
                self.generate_effects_sample,
                self.generate_time_masked_effects,
                self.generate_freq_masked_effects,
                self.generate_noise_mixed_sample,
                self.generate_choppy_audio,
                self.generate_ambient_noise_mixed_sample,
                self.generate_4khz_downsample,
                self.generate_5khz_downsample,
            ]
            # Randomly select one degradation method
            chosen_method = random.choice(degradation_methods)
            # Apply the chosen degradation
            result = chosen_method(processed_audio, original_mel)
            if result is None:
                # If the chosen method fails, fall back to the original mel spectrogram
                logger.error("empty result")
                result = {
                    "original_mel": original_mel,
                    "processed_mel": original_mel,
                    "mel_length": original_mel.shape[-1],
                }

            # Trim both mels to a common number of frames, and keep mel_length consistent
            # with the trimmed tensors so the collate-time mask is never too long.
            min_length = min(
                result["original_mel"].shape[-1], result["processed_mel"].shape[-1]
            )
            result["original_mel"] = result["original_mel"][:, :min_length]
            result["processed_mel"] = result["processed_mel"][:, :min_length]
            result["mel_length"] = min_length
            return result
        except Exception as e:
            logger.error(f"Error in __getitem__ at index {index}: {str(e)}")
            # Return a default value instead of None
            return {
                "original_mel": torch.zeros((config_h.num_mels, 1)),
                "processed_mel": torch.zeros((config_h.num_mels, 1)),
                "mel_length": 1,
            }

    def generate_ambient_noise_mixed_sample(self, speech, original_mel):
        try:
            ambient_noise = self.get_ambient_noise()
            mixed_audio = mix_speech_with_ambient_noise(speech, ambient_noise)
            mixed_mel = get_mel_spectrogram(mixed_audio.unsqueeze(0), config_h).squeeze(0)
            mixed_mel = self.handle_nans(mixed_mel, "mixed_mel")
            return {
                "original_mel": original_mel,
                "processed_mel": mixed_mel,
                "mel_length": original_mel.size(-1),
            }
        except Exception as e:
            logger.error(f"Error in generate_ambient_noise_mixed_sample: {str(e)}")
            return None

    def generate_noise_mixed_sample(self, processed_audio, original_mel):
        try:
            audio_duration = len(processed_audio) / self.target_sample_rate
            noise_type = random.choice(
                [
                    "city",
                    "crowd",
                    "speech-like",
                    "nature",
                    "office",
                    "restaurant",
                    "construction",
                    "traffic",
                    "low_quality_mp3",
                    "low_quality_vorbis",
                    "vintage_1940s",
                    "vintage_1960s",
                ]
            )
            noise = generate_acoustic_noise(
                audio_duration, self.target_sample_rate, noise_type
            )

            # Debug: check the shape of the noise
            logger.debug(f"Original noise shape: {noise.shape}")
            # Ensure noise is 1D. If multi-channel, flatten or select a single channel.
            if noise.ndim > 1:
                noise = noise.flatten()
                logger.debug(f"Flattened noise shape: {noise.shape}")

            # Ensure noise and processed_audio have the same length
            processed_length = int(len(processed_audio))
            noise_length = int(len(noise))

            # Log types and values
            logger.debug(
                f"processed_length type: {type(processed_length)}, value: {processed_length}"
            )
            logger.debug(
                f"noise_length type: {type(noise_length)}, value: {noise_length}"
            )

            if noise_length > processed_length:
                noise = noise[:processed_length]
                logger.debug(f"Noise truncated to length: {processed_length}")
            elif noise_length < processed_length:
                pad_width = processed_length - noise_length
                pad_width = int(pad_width)  # Ensure pad_width is an int
                # For 1D noise, pad_width should be a tuple of two integers
                noise = np.pad(noise, (0, pad_width), mode="constant")
                logger.debug(f"Noise padded to length: {len(noise)}")

            # Mix noise with processed audio
            mix_ratio = random.uniform(0.001, 0.5)  # Adjust this range as needed
            mixed_audio = (1 - mix_ratio) * processed_audio.numpy() + mix_ratio * noise

            # Debug: check mixed_audio shape and type
            logger.debug(
                f"Mixed audio shape: {mixed_audio.shape}, dtype: {mixed_audio.dtype}"
            )

            mixed_audio = torch.from_numpy(mixed_audio).float()
            mixed_audio = torch.clamp(mixed_audio, min=-1.0, max=1.0)
            mixed_audio = self.handle_nans(mixed_audio, "mixed_audio")
            mixed_mel = get_mel_spectrogram(mixed_audio.unsqueeze(0), config_h).squeeze(0)
            mixed_mel = self.handle_nans(mixed_mel, "mixed_mel")
            return {
                "original_mel": original_mel,
                "processed_mel": mixed_mel,
                "mel_length": original_mel.shape[-1],
            }
        except Exception as e:
            logger.error(f"Error in generate_noise_mixed_sample: {str(e)}")
            return None

    def generate_effects_sample(self, processed_audio, original_mel):
        try:
            audio_np = processed_audio.clone().numpy()
            effects_audio = self.effects_function(audio_np, self.target_sample_rate)
            effects_audio = torch.from_numpy(effects_audio)
            effects_audio = torch.clamp(effects_audio, min=-1.0, max=1.0)
            effects_audio = self.handle_nans(effects_audio, "effects_audio")
            effects_mel = get_mel_spectrogram(
                effects_audio.unsqueeze(0), config_h
            ).squeeze(0)
            effects_mel = self.handle_nans(effects_mel, "effects_mel")
            return {
                "original_mel": original_mel,
                "processed_mel": effects_mel,
                "mel_length": original_mel.shape[-1],
            }
        except Exception as e:
            logger.error(f"Error in generate_effects_sample: {str(e)}")
            return None

    def generate_time_masked_effects(self, processed_audio, original_mel):
        try:
            audio_np = processed_audio.clone().numpy()
            effects_audio = self.effects_function(audio_np, self.target_sample_rate)
            effects_audio = torch.from_numpy(effects_audio)
            effects_audio = torch.clamp(effects_audio, min=-1.0, max=1.0)
            effects_audio = self.handle_nans(effects_audio, "effects_audio")
            effects_mel = get_mel_spectrogram(
                effects_audio.unsqueeze(0), config_h
            ).squeeze(0)
            time_masked_effects_mel = self.time_mask(effects_mel.unsqueeze(0)).squeeze(0)
            time_masked_effects_mel = self.handle_nans(
                time_masked_effects_mel, "time_masked_effects_mel"
            )
            return {
                "original_mel": original_mel,
                "processed_mel": time_masked_effects_mel,
                "mel_length": original_mel.shape[-1],
            }
        except Exception as e:
            logger.error(f"Error in generate_time_masked_effects: {str(e)}")
            return None

    def generate_freq_masked_effects(self, processed_audio, original_mel):
        try:
            audio_np = processed_audio.clone().numpy()
            effects_audio = self.effects_function(audio_np, self.target_sample_rate)
            effects_audio = torch.from_numpy(effects_audio)
            effects_audio = torch.clamp(effects_audio, min=-1.0, max=1.0)
            effects_audio = self.handle_nans(effects_audio, "effects_audio")
            effects_mel = get_mel_spectrogram(
                effects_audio.unsqueeze(0), config_h
            ).squeeze(0)
            freq_masked_effects_mel = self.freq_mask(effects_mel.unsqueeze(0)).squeeze(0)
            freq_masked_effects_mel = self.handle_nans(
                freq_masked_effects_mel, "freq_masked_effects_mel"
            )
            return {
                "original_mel": original_mel,
                "processed_mel": freq_masked_effects_mel,
                "mel_length": original_mel.shape[-1],
            }
        except Exception as e:
            logger.error(f"Error in generate_freq_masked_effects: {str(e)}")
            return None

    def generate_choppy_audio(self, processed_audio, original_mel):
        try:
            # Convert to numpy array for easier manipulation
            audio_np = processed_audio.numpy().copy()

            # Define parameters for choppiness
            min_cut_length = int(0.01 * self.target_sample_rate)  # 10 ms
            max_cut_length = int(0.05 * self.target_sample_rate)  # 50 ms
            max_total_zero = int(len(audio_np) * 0.2)  # Max 20% of audio zeroed
            total_zero = 0
            attempts = 0
            max_attempts = 1000  # To prevent infinite loops

            while total_zero < max_total_zero and attempts < max_attempts:
                cut_length = random.randint(min_cut_length, max_cut_length)
                cut_start = random.randint(0, len(audio_np) - cut_length)
                # Check if this segment is already silent to avoid overlapping
                if np.all(audio_np[cut_start : cut_start + cut_length] == 0.0):
                    attempts += 1
                    continue
                # Apply the cut
                audio_np[cut_start : cut_start + cut_length] = 0.0
                total_zero += cut_length
                attempts += 1

            if total_zero == 0:
                logger.warning("No cuts were applied. The audio remains unchanged.")

            # Convert back to torch tensor
            choppy_audio = torch.from_numpy(audio_np).float()
            choppy_audio = torch.clamp(choppy_audio, min=-1.0, max=1.0)
            choppy_audio = self.handle_nans(choppy_audio, "choppy_audio")
            choppy_mel = get_mel_spectrogram(
                choppy_audio.unsqueeze(0), config_h
            ).squeeze(0)
            choppy_mel = self.handle_nans(choppy_mel, "choppy_mel")
            return {
                "original_mel": original_mel,
                "processed_mel": choppy_mel,
                "mel_length": original_mel.shape[-1],
            }
        except Exception as e:
            logger.error(f"Error in generate_choppy_audio: {str(e)}")
            return None


class VoiceRestoreTrainer:
    def __init__(
        self,
        model,
        optimizer,
        num_warmup_steps=20000,
        grad_accumulation_steps=32,
        checkpoint_path=None,
        log_file="logs.txt",
        max_grad_norm=1.0,
        sample_rate=24000,
        tensorboard_log_dir="runs/e2_tts_experiment",
        accelerate_kwargs=None,
        ema_kwargs=None,
    ):
        self.grad_accumulation_steps = grad_accumulation_steps
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = Accelerator(
            log_with="all",
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=self.grad_accumulation_steps,
            **(accelerate_kwargs or {}),
        )
        self.device = self.accelerator.device
        self.target_sample_rate = sample_rate
        self.model = model.to(self.device)
        if self.is_main:
            self.ema_model = EMA(
                model, include_online_model=False, **(ema_kwargs or {})
            )
            self.ema_model.to(self.device)
            self.ema_model = self.accelerator.prepare(
                self.ema_model
            )  # If EMA supports Accelerator
        self.optimizer = optimizer
        self.num_warmup_steps = num_warmup_steps
        self.checkpoint_path = checkpoint_path or "model.pth"
        self.max_grad_norm = max_grad_norm
        self.writer = SummaryWriter(log_dir=tensorboard_log_dir)
        self.model, self.optimizer = self.accelerator.prepare(
            self.model, self.optimizer
        )

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    def save_checkpoint(self, step, finetune=False):
        self.accelerator.wait_for_everyone()
        if self.is_main:
            checkpoint = dict(
                model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
                optimizer_state_dict=self.optimizer.state_dict(),
                ema_model_state_dict=self.ema_model.state_dict(),
                step=step,
            )
            self.accelerator.save(checkpoint, self.checkpoint_path)

    def load_checkpoint(self):
        if not os.path.exists(self.checkpoint_path):
            return 0
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        self.accelerator.unwrap_model(self.model).load_state_dict(
            checkpoint["model_state_dict"]
        )
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if self.is_main:
            self.ema_model.load_state_dict(
                checkpoint["ema_model_state_dict"]
            )  # If EMA supports Accelerator
        return checkpoint["step"]

    def plot_batch_comparison_spectrograms(
        self,
        original_mel: torch.Tensor,
        processed_mel: torch.Tensor,
        predicted_mel: torch.Tensor,
        title: str,
        max_samples: int = 16,  # Limit the number of samples to plot
    ):
        """
        Plots Original, Degraded, and Predicted spectrograms side-by-side for each sample in the batch.

        Args:
            original_mel (torch.Tensor): Tensor of shape (batch_size, num_mels, time_frames)
            processed_mel (torch.Tensor): Tensor of shape (batch_size, num_mels, time_frames)
            predicted_mel (torch.Tensor): Tensor of shape (batch_size, num_mels, time_frames)
            title (str): Title for the entire figure.
            max_samples (int, optional): Maximum number of samples to plot. Defaults to 16.

        Returns:
            matplotlib.figure.Figure: The generated figure.
        """
        batch_size = original_mel.shape[0]
        num_samples = min(batch_size, max_samples)
        cols = 3  # Original, Degraded, Predicted
        rows = num_samples
        fig, axes = plt.subplots(rows, cols, figsize=(15, 3 * rows))
        fig.suptitle(title, fontsize=16)

        # If there's only one sample, axes might not be a 2D array
        if num_samples == 1:
            axes = np.expand_dims(axes, axis=0)

        for i in range(num_samples):
            # Original spectrogram
            ax = axes[i, 0]
            mel_spec = original_mel[i].detach().cpu().numpy()
            im = ax.imshow(
                mel_spec.T,
                aspect="auto",
                origin="lower",
                interpolation="nearest",
                cmap="viridis",
                vmin=mel_spec.min(),
                vmax=mel_spec.max(),
            )
            if i == 0:
                ax.set_title("Original", fontsize=12)
            ax.set_xlabel("Time")
            ax.set_ylabel("Mel bins")

            # Degraded spectrogram
            ax = axes[i, 1]
            mel_spec = processed_mel[i].detach().cpu().numpy()
            im = ax.imshow(
                mel_spec.T,
                aspect="auto",
                origin="lower",
                interpolation="nearest",
                cmap="viridis",
                vmin=mel_spec.min(),
                vmax=mel_spec.max(),
            )
            if i == 0:
                ax.set_title("Degraded", fontsize=12)
            ax.set_xlabel("Time")
            ax.set_ylabel("Mel bins")

            # Predicted spectrogram
            ax = axes[i, 2]
            mel_spec = predicted_mel[i].detach().cpu().numpy()
            im = ax.imshow(
                mel_spec.T,
                aspect="auto",
                origin="lower",
                interpolation="nearest",
                cmap="viridis",
                vmin=mel_spec.min(),
                vmax=mel_spec.max(),
            )
            if i == 0:
                ax.set_title("Predicted", fontsize=12)
            ax.set_xlabel("Time")
            ax.set_ylabel("Mel bins")

        # Adjust layout and add a single colorbar
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Make room for the suptitle
        # Add a single colorbar to the right of the subplots
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
        fig.colorbar(im, cax=cbar_ax, orientation="vertical", label="Amplitude")
        return fig

    def train(self, train_dataset, epochs, batch_size, num_workers=2, save_step=100):
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
        )
        train_dataloader = self.accelerator.prepare(train_dataloader)
        start_step = self.load_checkpoint()
        global_step = start_step

        for epoch in range(epochs):
            self.model.train()
            progress_bar = tqdm(
                train_dataloader,
                desc=f"Epoch {epoch+1}/{epochs}",
                unit="step",
                disable=not self.accelerator.is_local_main_process,
            )
            epoch_loss = 0.0
            accumulated_loss = 0.0
            for batch in progress_bar:
                if batch is None:
                    global_step += 1
                    continue  # Skip the batch

                with self.accelerator.accumulate(self.model):
                    original_mel = batch["original_mel"].to(self.device)
                    processed_mel = batch["processed_mel"].to(self.device)
                    mask = batch["mask"].to(self.device)

                    original_mel = rearrange(original_mel, "b d n -> b n d")
                    processed_mel = rearrange(processed_mel, "b d n -> b n d")

                    loss, pred = self.model(original_mel, processed_mel, mask)
                    accumulated_loss += loss.item()
                    self.accelerator.backward(loss)

                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.model.parameters(), self.max_grad_norm
                        )

                    if self.accelerator.sync_gradients:
                        avg_loss = accumulated_loss / self.grad_accumulation_steps
                        if self.accelerator.is_main_process:
                            logger.info(f"step {global_step}: loss = {avg_loss:.4f}")
                            self.writer.add_scalar("loss", avg_loss, global_step)
                        accumulated_loss = 0.0

                    self.optimizer.step()
                    self.optimizer.zero_grad()

                    if self.is_main:
                        self.ema_model.update()

                global_step += 1
                epoch_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item())

                if global_step % save_step == 0:
                    self.save_checkpoint(global_step)

        self.writer.flush()
        self.writer.close()
---
# Flow-matching training objective: the forward pass of the VoiceRestore model driven by
# the trainer above. Imports used by this excerpt:
import torch
import torch.nn.functional as F
from einops import rearrange


def forward(self, original, processed, mask):
    times = torch.rand((original.size(0),), device=original.device)
    t = rearrange(times, "b -> b 1 1")

    # Get the maximum length in the batch
    max_length = max(original.size(1), processed.size(1))

    # Pad original and processed to max_length
    original_padded = F.pad(original, (0, 0, 0, max_length - original.size(1)))
    processed_padded = F.pad(processed, (0, 0, 0, max_length - processed.size(1)))

    # Extend mask to max_length (0 for padded elements)
    mask_extended = F.pad(mask, (0, max_length - mask.size(1)))

    x0 = torch.randn_like(processed_padded)
    x1 = original_padded

    # Compute the noisy intermediate representation `w` at time step `t`
    w = (1.0 - t) * x0 + t * x1
    # Target flow: the difference between the clean data and the noise sample
    flow = x1 - x0

    # Predict the vector field with the transformer-based architecture
    pred = self.transformer_with_pred_head(
        w, times=times, cond=processed_padded, mask=mask_extended
    )

    # Apply mask to both pred and flow (only compute loss for non-padded elements)
    pred = pred * mask_extended.unsqueeze(-1)
    flow = flow * mask_extended.unsqueeze(-1)

    # Calculate MSE loss on non-padded elements
    loss = F.mse_loss(pred, flow, reduction="none")
    loss = loss * mask_extended.unsqueeze(-1)
    num_unmasked = mask_extended.sum() * pred.size(-1)

    # Avoid division by zero
    if num_unmasked == 0:
        loss = torch.tensor(0.0, device=loss.device)
    else:
        loss = loss.sum() / num_unmasked

    return loss, pred
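

# The objective above teaches the network to predict the constant flow x1 - x0 along the
# straight path w_t = (1 - t) * x0 + t * x1, conditioned on the degraded mel. A minimal
# inference sketch, assuming the same transformer_with_pred_head call seen in forward
# (restore_mel and n_steps are illustrative names, not part of the original code):
@torch.no_grad()
def restore_mel(model, processed_mel, n_steps: int = 32):
    # processed_mel: (batch, time, num_mels), the layout the trainer feeds to forward
    mask = torch.ones(processed_mel.shape[:2], dtype=torch.bool, device=processed_mel.device)
    x = torch.randn_like(processed_mel)  # start from noise at t = 0
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((x.size(0),), i * dt, device=x.device)
        v = model.transformer_with_pred_head(x, times=t, cond=processed_mel, mask=mask)
        x = x + dt * v  # Euler step along the predicted flow
    return x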