VoiceRestore training code
import os
import argparse

import torch
from datasets import get_dataset_config_names, load_dataset
from schedulefree import AdamWScheduleFree
from torch.utils.data import ConcatDataset

from trainer import VoiceEnhancementDataset, VoiceRestoreTrainer
from voice_restore import VoiceRestore

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

print("CUDA Available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA Version:", torch.version.cuda)
    print("PyTorch CUDA Runtime Version:", torch._C._cuda_getCompiledVersion())
    print("Device Name:", torch.cuda.get_device_name(0))


def human_readable_number(num):
    for unit in ["", "thousand", "million", "billion", "trillion"]:
        if abs(num) < 1000.0:
            return f"{num:3.1f} {unit}"
        num /= 1000.0
    return f"{num:.1f} quadrillion"


def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params


# Define a function to load and split datasets
def load_and_split_dataset(
    dataset_name,
    config_name=None,
    split="train",
    cache_dir=None,
    split_ratio=0.00001,
    token=None,
):
    if token:
        dataset = load_dataset(
            dataset_name,
            config_name,
            split=split,
            cache_dir=cache_dir,
            num_proc=4,
            save_infos=True,
            token=token,
            trust_remote_code=True,
            verification_mode="no_checks",
        ).train_test_split(test_size=split_ratio)
    else:
        dataset = load_dataset(
            dataset_name,
            config_name,
            split=split,
            cache_dir=cache_dir,
            num_proc=4,
            save_infos=True,
            trust_remote_code=True,
            verification_mode="no_checks",
        ).train_test_split(test_size=split_ratio)
    train_dataset = VoiceEnhancementDataset(dataset["train"])
    val_dataset = VoiceEnhancementDataset(dataset["test"])
    return train_dataset, val_dataset


if __name__ == "__main__":
    from multiprocessing import freeze_support

    freeze_support()
    print("Starting main execution")

    parser = argparse.ArgumentParser(description="Voice Restore Training Script")
    parser.add_argument(
        "--quick_run",
        action="store_true",
        help="Run with a single small dataset for quick testing",
    )
    args = parser.parse_args()

    train_dataset = None
    if args.quick_run:
        print("Running in quick mode with a single dataset")
        expresso_train, _ = load_and_split_dataset("ylacombe/expresso", split="train")
        train_dataset = expresso_train

    grad_accumulation_steps = 1
    epochs = 1000
    batch_size = 2

    # Start training
    try:
        print("Initializing trainer")
        voice_restore = VoiceRestore(
            num_channels=100,
            transformer=dict(
                dim=768,
                depth=20,
                skip_connect_type="concat",
                heads=16,
                dim_head=64,
                max_seq_len=2000,
            ),
            sigma=0.0,
        )
        total_params, trainable_params = count_parameters(voice_restore)
        print(f"Total parameters: {human_readable_number(total_params)}")
        print(f"Trainable parameters: {human_readable_number(trainable_params)}")

        optimizer = AdamWScheduleFree(
            voice_restore.parameters(), lr=3e-5, warmup_steps=5000
        )
        trainer = VoiceRestoreTrainer(
            voice_restore,
            optimizer,
            grad_accumulation_steps=grad_accumulation_steps,
            checkpoint_path="./checkpoints/voicerestore_9_28_24.pt",
            log_file="./logs.txt",
            accelerate_kwargs=dict(mixed_precision="bf16"),
        )
        print("Trainer initialized successfully")

        print("Starting training")
        trainer.train(train_dataset, epochs, batch_size, save_step=5_000)
        print("Training completed successfully")
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        import traceback

        traceback.print_exc()
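
Usage note: with --quick_run the script trains on the single Expresso split loaded above; without the flag, train_dataset remains None and trainer.train() will fail when the DataLoader is built, so this gist as published is only runnable in quick mode. Assuming the two parts of this gist are saved as train.py and trainer.py (the gist itself does not name the files, so these names are an assumption), a smoke test would be:

python train.py --quick_run
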
---
from __future__ import annotations

import os
import random
from types import SimpleNamespace

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.checkpoint
import torchaudio
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from einops import rearrange
from ema_pytorch import EMA
from loguru import logger
from pedalboard import (
    Bitcrush,
    Chorus,
    Compressor,
    Distortion,
    GSMFullRateCompressor,
    HighpassFilter,
    LowpassFilter,
    MP3Compressor,
    Pedalboard,
    Resample,
    Reverb,
)
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchaudio.transforms import FrequencyMasking, TimeMasking
from tqdm import tqdm

from ambient_noise import (
    load_ambient_noise_data,
    mix_speech_with_ambient_noise,
)
from gen_noise import generate_acoustic_noise
from meldataset import get_mel_spectrogram

config_h = {
    "resblock": "1",
    "num_gpus": 0,
    "batch_size": 32,
    "learning_rate": 0.0001,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.9999996,
    "seed": 1234,
    "upsample_rates": [4, 4, 2, 2, 2, 2],
    "upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
    "upsample_initial_channel": 1536,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "use_tanh_at_final": False,
    "use_bias_at_final": False,
    "activation": "snakebeta",
    "snake_logscale": True,
    "use_cqtd_instead_of_mrd": True,
    "cqtd_filters": 128,
    "cqtd_max_filters": 1024,
    "cqtd_filters_scale": 1,
    "cqtd_dilations": [1, 2, 4],
    "cqtd_hop_lengths": [512, 256, 256],
    "cqtd_n_octaves": [9, 9, 9],
    "cqtd_bins_per_octaves": [24, 36, 48],
    "mpd_reshapes": [2, 3, 5, 7, 11],
    "use_spectral_norm": False,
    "discriminator_channel_mult": 1,
    "use_multiscale_melloss": True,
    "lambda_melloss": 15,
    "clip_grad_norm": 500,
    "segment_size": 65536,
    "num_mels": 100,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 256,
    "win_size": 1024,
    "sampling_rate": 24000,
    "fmin": 0,
    "fmax": None,
    "fmax_for_loss": None,
    "normalize_volume": True,
    "num_workers": 4,
    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1,
    },
}
config_h = SimpleNamespace(**config_h)
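
For orientation, a back-of-the-envelope on the mel settings above (using only values from config_h and the VoiceEnhancementDataset default max_seq_len=1024 defined later in this file):

frame duration  = hop_size / sampling_rate = 256 / 24000 ≈ 10.7 ms per mel frame
max clip length = max_seq_len * hop_size / sampling_rate = 1024 * 256 / 24000 ≈ 10.9 s of audio

If the transformer's max_seq_len=2000 in the training script is also counted in mel frames (an assumption), that corresponds to roughly 21.3 s.
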
def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(10, 4))
    im = ax.imshow(spectrogram.T, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()
    fig.canvas.draw()
    plt.close()
    return fig


def exists(v):
    return v is not None


def default(v, d):
    return v if exists(v) else d


EFFECTS = {
    "mp3lame": [
        MP3Compressor(vbr_quality=5.0),
        MP3Compressor(vbr_quality=6.0),
        MP3Compressor(vbr_quality=7.0),
        MP3Compressor(vbr_quality=8.0),
        MP3Compressor(vbr_quality=9.0),
    ],
    "reverb": [
        Reverb(
            room_size=0.5,
            damping=0.5,
            wet_level=0.33,
            dry_level=0.4,
            width=1.0,
            freeze_mode=0.0,
        ),
        Reverb(
            room_size=0.2,
            damping=0.9,
            wet_level=0.66,
            dry_level=0.2,
            width=0.5,
            freeze_mode=0.0,
        ),
        Reverb(
            room_size=0.8,
            damping=0.2,
            wet_level=0.6,
            dry_level=0.4,
            width=1.0,
            freeze_mode=0.0,
        ),
        Reverb(room_size=0.1),
        Reverb(room_size=0.2),
        Reverb(room_size=0.3),
        Reverb(room_size=0.4),
        Reverb(room_size=0.5),
    ],
    "bitcrush": [
        Bitcrush(bit_depth=4.0),
        Bitcrush(bit_depth=8.0),
        Bitcrush(bit_depth=16.0),
    ],
    "compressor": [
        Compressor(threshold_db=-20, ratio=3),
        GSMFullRateCompressor(),
        GSMFullRateCompressor(quality=Resample.Quality(0)),
        GSMFullRateCompressor(quality=Resample.Quality(1)),
        GSMFullRateCompressor(quality=Resample.Quality(2)),
        GSMFullRateCompressor(quality=Resample.Quality(3)),
        GSMFullRateCompressor(quality=Resample.Quality(4)),
    ],
    "resample": [
        Resample(target_sample_rate=8000),
        Resample(target_sample_rate=12000),
    ],
    "distortion": [
        Distortion(drive_db=100.0),
        Distortion(drive_db=75.0),
        Distortion(drive_db=50.0),
        Distortion(drive_db=25.0),
        Distortion(drive_db=5.0),
        Distortion(drive_db=10.0),
        Distortion(drive_db=1.0),
    ],
    "filter": [
        LowpassFilter(cutoff_frequency_hz=3000.0),
        LowpassFilter(cutoff_frequency_hz=4000.0),
        LowpassFilter(cutoff_frequency_hz=5000.0),
        LowpassFilter(cutoff_frequency_hz=6000.0),
        HighpassFilter(cutoff_frequency_hz=500.0),
        HighpassFilter(cutoff_frequency_hz=200.0),
    ],
    "chorus": [
        Chorus(rate_hz=1.0, depth=0.5, centre_delay_ms=7.0, feedback=0.0, mix=0.5)
    ],
    "40s": [
        LowpassFilter(cutoff_frequency_hz=4000.0),
        HighpassFilter(cutoff_frequency_hz=400.0),
        Resample(target_sample_rate=8000),
    ],
    "50s": [
        LowpassFilter(cutoff_frequency_hz=8000.0),
        HighpassFilter(cutoff_frequency_hz=200.0),
        Resample(target_sample_rate=8000),
    ],
    "mp3": [
        HighpassFilter(cutoff_frequency_hz=120.0),
        LowpassFilter(cutoff_frequency_hz=10_000.0),
        Compressor(threshold_db=-10, ratio=4),
        Bitcrush(bit_depth=16.0),
        Resample(target_sample_rate=8000),
    ],
}

ambient_noise_data = load_ambient_noise_data()


def apply_vst_effects(
    audio: np.ndarray,
    sample_rate: int = 24000,
    min_effects: int = 0,
    max_effects: int = 2,
) -> np.ndarray:
    try:
        # Randomly select a number of effect categories to apply (between min_effects and max_effects)
        num_effects = random.randint(min_effects, max_effects)
        if num_effects == 0:
            return audio
        chosen_categories = random.sample(list(EFFECTS.keys()), num_effects)
        # Randomly pick one effect from each chosen category
        chosen_effects = [
            random.choice(EFFECTS[category]) for category in chosen_categories
        ]
        # Apply the chosen effects to the audio
        board = Pedalboard(chosen_effects)
        effected_audio = board(audio, sample_rate)
    except Exception as e:
        logger.error(f"Error in apply_vst_effects: {str(e)}")
        return audio
    return effected_audio
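
The degradation pipeline above can be exercised in isolation. A minimal sketch, assuming pedalboard and numpy are installed, using a synthetic 440 Hz tone instead of real speech (this snippet is not part of the original gist):

# Hypothetical smoke test for apply_vst_effects.
import numpy as np  # already imported at the top of this file

sr = 24000
t = np.linspace(0, 1.0, sr, endpoint=False, dtype=np.float32)
tone = 0.5 * np.sin(2 * np.pi * 440.0 * t)  # one second of a 440 Hz sine
degraded = apply_vst_effects(tone, sample_rate=sr, min_effects=1, max_effects=2)
print(tone.shape, degraded.shape, degraded.dtype)
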
def collate_fn(batch):
    # Filter out None values and ensure all items are dictionaries
    batch = [item for item in batch if item is not None and isinstance(item, dict)]
    if len(batch) == 0:
        return None
    # Separate the data into different lists
    original_mels = []
    processed_mels = []
    mel_lengths = []
    for sample in batch:
        if (
            "original_mel" in sample
            and "processed_mel" in sample
            and "mel_length" in sample
        ):
            original_mels.append(sample["original_mel"])
            processed_mels.append(sample["processed_mel"])
            mel_lengths.append(sample["mel_length"])
    if len(original_mels) == 0:
        return None
    # Find max length in this batch
    max_len = max(mel_lengths)
    # Pad sequences to max length
    original_mels_padded = torch.nn.utils.rnn.pad_sequence(
        [mel.transpose(0, 1) for mel in original_mels],
        batch_first=True,
        padding_value=0,
    ).transpose(1, 2)
    processed_mels_padded = torch.nn.utils.rnn.pad_sequence(
        [mel.transpose(0, 1) for mel in processed_mels],
        batch_first=True,
        padding_value=0,
    ).transpose(1, 2)
    # Create mask for padded regions
    mask = torch.arange(max_len).expand(len(mel_lengths), max_len) < torch.tensor(
        mel_lengths
    ).unsqueeze(1)
    return {
        "original_mel": original_mels_padded,
        "processed_mel": processed_mels_padded,
        "mel_length": torch.tensor(mel_lengths),
        "mask": mask,
    }
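A quick shape check for collate_fn, written as a hypothetical sketch (not part of the original gist) with random tensors standing in for real mel spectrograms; num_mels=100 matches config_h:

fake_batch = [
    {"original_mel": torch.randn(100, 80), "processed_mel": torch.randn(100, 80), "mel_length": 80},
    {"original_mel": torch.randn(100, 50), "processed_mel": torch.randn(100, 50), "mel_length": 50},
]
out = collate_fn(fake_batch)
print(out["original_mel"].shape)  # torch.Size([2, 100, 80]), padded to the longest item
print(out["mask"].shape, out["mask"].dtype)  # torch.Size([2, 80]) torch.bool
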
class VoiceEnhancementDataset(Dataset):
    def __init__(
        self,
        hf_dataset,
        target_sample_rate=24000,
        effects_function=None,
        downsample_rate=8000,
        max_seq_len=1024,
    ):
        self.data = hf_dataset
        self.target_sample_rate = target_sample_rate
        self.downsample_rate = downsample_rate
        self.effects_function = effects_function or apply_vst_effects
        # Initialize time and frequency masking
        self.time_mask = TimeMasking(time_mask_param=200)
        self.freq_mask = FrequencyMasking(freq_mask_param=60)
        self.max_seq_len = max_seq_len  # Set max_seq_len here

    def __len__(self):
        return len(self.data)

    def get_ambient_noise(self):
        return next(iter(ambient_noise_data))

    def process_audio(self, audio, sample_rate):
        if sample_rate != self.target_sample_rate:
            audio = torchaudio.functional.resample(
                audio, sample_rate, self.target_sample_rate
            )
        return audio

    def handle_nans(self, tensor, name):
        if torch.isnan(tensor).any():
            logger.warning(f"NaN detected in {name}. Replacing with zeros.")
            tensor = torch.nan_to_num(tensor, nan=0.0)
        return tensor

    def generate_non_affected_sample(self, processed_audio, original_mel):
        return {
            "original_mel": original_mel,
            "processed_mel": original_mel,
            "mel_length": original_mel.shape[-1],
        }

    def generate_4khz_downsample(self, processed_audio, original_mel):
        processed_audio = torchaudio.functional.resample(
            processed_audio, 4000, self.target_sample_rate
        )
        processed_audio = torch.clamp(processed_audio, min=-1.0, max=1.0)
        processed_audio = self.handle_nans(processed_audio, "processed_audio")
        processed_mel = get_mel_spectrogram(
            processed_audio.unsqueeze(0), config_h
        ).squeeze(0)
        processed_mel = self.handle_nans(processed_mel, "processed_mel")
        return {
            "original_mel": original_mel,
            "processed_mel": processed_mel,
            "mel_length": original_mel.shape[-1],
        }

    def generate_5khz_downsample(self, processed_audio, original_mel):
        processed_audio = torchaudio.functional.resample(
            processed_audio, 5000, self.target_sample_rate
        )
        processed_audio = torch.clamp(processed_audio, min=-1.0, max=1.0)
        processed_audio = self.handle_nans(processed_audio, "processed_audio")
        processed_mel = get_mel_spectrogram(
            processed_audio.unsqueeze(0), config_h
        ).squeeze(0)
        processed_mel = self.handle_nans(processed_mel, "processed_mel")
        return {
            "original_mel": original_mel,
            "processed_mel": processed_mel,
            "mel_length": original_mel.shape[-1],
        }

    def __getitem__(self, index):
        try:
            row = self.data[index]
            audio = torch.tensor(row["audio"]["array"], dtype=torch.float32)
            sample_rate = row["audio"]["sampling_rate"]
            processed_audio = self.process_audio(audio, sample_rate)
            # Truncate the audio to max_seq_len before further processing
            max_audio_len = (
                self.max_seq_len * config_h.hop_size
            )  # Convert max_seq_len to number of audio samples
            if len(processed_audio) > max_audio_len:
                processed_audio = processed_audio[:max_audio_len]
            processed_audio = self.handle_nans(processed_audio, "processed_audio")
            original_mel = get_mel_spectrogram(
                processed_audio.unsqueeze(0), config_h
            ).squeeze(0)
            original_mel = self.handle_nans(original_mel, "original_mel")
            # List of all possible degradation methods
            degradation_methods = [
                self.generate_non_affected_sample,
                self.generate_effects_sample,
                self.generate_time_masked_effects,
                self.generate_freq_masked_effects,
                self.generate_noise_mixed_sample,
                self.generate_choppy_audio,
                self.generate_ambient_noise_mixed_sample,
                self.generate_4khz_downsample,
                self.generate_5khz_downsample,
            ]
            # Randomly select one degradation method
            chosen_method = random.choice(degradation_methods)
            # Apply the chosen degradation
            result = chosen_method(processed_audio, original_mel)
            if result is None:
                # If the chosen method fails, fall back to the original mel spectrogram
                logger.error("empty result")
                result = {
                    "original_mel": original_mel,
                    "processed_mel": original_mel,
                    "mel_length": original_mel.shape[-1],
                }
            min_length = min(
                result["original_mel"].shape[-1], result["processed_mel"].shape[-1]
            )
            result["original_mel"] = result["original_mel"][:, :min_length]
            result["processed_mel"] = result["processed_mel"][:, :min_length]
            return result
        except Exception as e:
            logger.error(f"Error in __getitem__ at index {index}: {str(e)}")
            # Return a default value instead of None
            return {
                "original_mel": torch.zeros((config_h.num_mels, 1)),
                "processed_mel": torch.zeros((config_h.num_mels, 1)),
                "mel_length": 1,
            }

    def generate_ambient_noise_mixed_sample(self, speech, original_mel):
        try:
            ambient_noise = self.get_ambient_noise()
            mixed_audio = mix_speech_with_ambient_noise(speech, ambient_noise)
            mixed_mel = get_mel_spectrogram(mixed_audio.unsqueeze(0), config_h).squeeze(0)
            mixed_mel = self.handle_nans(mixed_mel, "mixed_mel")
            return {
                "original_mel": original_mel,
                "processed_mel": mixed_mel,
                "mel_length": original_mel.size(-1),
            }
        except Exception as e:
            logger.error(f"Error in generate_ambient_noise_mixed_sample: {str(e)}")
            return None

    def generate_noise_mixed_sample(self, processed_audio, original_mel):
        try:
            audio_duration = len(processed_audio) / self.target_sample_rate
            noise_type = random.choice(
                [
                    "city",
                    "crowd",
                    "speech-like",
                    "nature",
                    "office",
                    "restaurant",
                    "construction",
                    "traffic",
                    "low_quality_mp3",
                    "low_quality_vorbis",
                    "vintage_1940s",
                    "vintage_1960s",
                ]
            )
            noise = generate_acoustic_noise(
                audio_duration, self.target_sample_rate, noise_type
            )
            # Debug: Check the shape of noise
            logger.debug(f"Original noise shape: {noise.shape}")
            # Ensure noise is 1D. If multi-channel, flatten or select a single channel.
            if noise.ndim > 1:
                noise = noise.flatten()
                logger.debug(f"Flattened noise shape: {noise.shape}")
            # Ensure noise and processed_audio have the same length
            processed_length = int(len(processed_audio))
            noise_length = int(len(noise))
            # Log types and values
            logger.debug(
                f"processed_length type: {type(processed_length)}, value: {processed_length}"
            )
            logger.debug(
                f"noise_length type: {type(noise_length)}, value: {noise_length}"
            )
            if noise_length > processed_length:
                noise = noise[:processed_length]
                logger.debug(f"Noise truncated to length: {processed_length}")
            elif noise_length < processed_length:
                pad_width = processed_length - noise_length
                pad_width = int(pad_width)  # Ensure pad_width is an int
                # For 1D noise, pad_width should be a tuple of two integers
                noise = np.pad(noise, (0, pad_width), mode="constant")
                logger.debug(f"Noise padded to length: {len(noise)}")
            # Mix noise with processed audio
            mix_ratio = random.uniform(0.001, 0.5)  # Adjust this range as needed
            mixed_audio = (1 - mix_ratio) * processed_audio.numpy() + mix_ratio * noise
            # Debug: Check mixed_audio shape and type
            logger.debug(
                f"Mixed audio shape: {mixed_audio.shape}, dtype: {mixed_audio.dtype}"
            )
            mixed_audio = torch.from_numpy(mixed_audio).float()
            mixed_audio = torch.clamp(mixed_audio, min=-1.0, max=1.0)
            mixed_audio = self.handle_nans(mixed_audio, "mixed_audio")
            mixed_mel = get_mel_spectrogram(mixed_audio.unsqueeze(0), config_h).squeeze(0)
            mixed_mel = self.handle_nans(mixed_mel, "mixed_mel")
            return {
                "original_mel": original_mel,
                "processed_mel": mixed_mel,
                "mel_length": original_mel.shape[-1],
            }
        except Exception as e:
            logger.error(f"Error in generate_noise_mixed_sample: {str(e)}")
            return None

    def generate_effects_sample(self, processed_audio, original_mel):
        try:
            audio_np = processed_audio.clone().numpy()
            effects_audio = self.effects_function(audio_np, self.target_sample_rate)
            effects_audio = torch.from_numpy(effects_audio)
            effects_audio = torch.clamp(effects_audio, min=-1.0, max=1.0)
            effects_audio = self.handle_nans(effects_audio, "effects_audio")
            effects_mel = get_mel_spectrogram(
                effects_audio.unsqueeze(0), config_h
            ).squeeze(0)
            effects_mel = self.handle_nans(effects_mel, "effects_mel")
            return {
                "original_mel": original_mel,
                "processed_mel": effects_mel,
                "mel_length": original_mel.shape[-1],
            }
        except Exception as e:
            logger.error(f"Error in generate_effects_sample: {str(e)}")
            return None

    def generate_time_masked_effects(self, processed_audio, original_mel):
        try:
            audio_np = processed_audio.clone().numpy()
            effects_audio = self.effects_function(audio_np, self.target_sample_rate)
            effects_audio = torch.from_numpy(effects_audio)
            effects_audio = torch.clamp(effects_audio, min=-1.0, max=1.0)
            effects_audio = self.handle_nans(effects_audio, "effects_audio")
            effects_mel = get_mel_spectrogram(
                effects_audio.unsqueeze(0), config_h
            ).squeeze(0)
            time_masked_effects_mel = self.time_mask(effects_mel.unsqueeze(0)).squeeze(0)
            time_masked_effects_mel = self.handle_nans(
                time_masked_effects_mel, "time_masked_effects_mel"
            )
            return {
                "original_mel": original_mel,
                "processed_mel": time_masked_effects_mel,
                "mel_length": original_mel.shape[-1],
            }
        except Exception as e:
            logger.error(f"Error in generate_time_masked_effects: {str(e)}")
            return None

    def generate_freq_masked_effects(self, processed_audio, original_mel):
        try:
            audio_np = processed_audio.clone().numpy()
            effects_audio = self.effects_function(audio_np, self.target_sample_rate)
            effects_audio = torch.from_numpy(effects_audio)
            effects_audio = torch.clamp(effects_audio, min=-1.0, max=1.0)
            effects_audio = self.handle_nans(effects_audio, "effects_audio")
            effects_mel = get_mel_spectrogram(
                effects_audio.unsqueeze(0), config_h
            ).squeeze(0)
            freq_masked_effects_mel = self.freq_mask(effects_mel.unsqueeze(0)).squeeze(0)
            freq_masked_effects_mel = self.handle_nans(
                freq_masked_effects_mel, "freq_masked_effects_mel"
            )
            return {
                "original_mel": original_mel,
                "processed_mel": freq_masked_effects_mel,
                "mel_length": original_mel.shape[-1],
            }
        except Exception as e:
            logger.error(f"Error in generate_freq_masked_effects: {str(e)}")
            return None

    def generate_choppy_audio(self, processed_audio, original_mel):
        try:
            # Convert to numpy array for easier manipulation
            audio_np = processed_audio.numpy().copy()
            # Define parameters for choppiness
            min_cut_length = int(0.01 * self.target_sample_rate)  # 10ms
            max_cut_length = int(0.05 * self.target_sample_rate)  # 50ms
            max_total_zero = int(len(audio_np) * 0.2)  # Max 20% of audio zeroed
            total_zero = 0
            attempts = 0
            max_attempts = 1000  # To prevent infinite loops
            while total_zero < max_total_zero and attempts < max_attempts:
                cut_length = random.randint(min_cut_length, max_cut_length)
                cut_start = random.randint(0, len(audio_np) - cut_length)
                # Check if this segment is already silent to avoid overlapping
                if np.all(audio_np[cut_start : cut_start + cut_length] == 0.0):
                    attempts += 1
                    continue
                # Apply the cut
                audio_np[cut_start : cut_start + cut_length] = 0.0
                total_zero += cut_length
                attempts += 1
            if total_zero == 0:
                logger.warning("No cuts were applied. The audio remains unchanged.")
            # Convert back to torch tensor
            choppy_audio = torch.from_numpy(audio_np).float()
            choppy_audio = torch.clamp(choppy_audio, min=-1.0, max=1.0)
            choppy_audio = self.handle_nans(choppy_audio, "choppy_audio")
            choppy_mel = get_mel_spectrogram(
                choppy_audio.unsqueeze(0), config_h
            ).squeeze(0)
            choppy_mel = self.handle_nans(choppy_mel, "choppy_mel")
            return {
                "original_mel": original_mel,
                "processed_mel": choppy_mel,
                "mel_length": original_mel.shape[-1],
            }
        except Exception as e:
            logger.error(f"Error in generate_choppy_audio: {str(e)}")
            return None
class VoiceRestoreTrainer:
    def __init__(
        self,
        model,
        optimizer,
        num_warmup_steps=20000,
        grad_accumulation_steps=32,
        checkpoint_path=None,
        log_file="logs.txt",
        max_grad_norm=1.0,
        sample_rate=24000,
        tensorboard_log_dir="runs/e2_tts_experiment",
        accelerate_kwargs=None,
        ema_kwargs=None,
    ):
        self.grad_accumulation_steps = grad_accumulation_steps
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = Accelerator(
            log_with="all",
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=self.grad_accumulation_steps,
            **(accelerate_kwargs or {}),
        )
        self.device = self.accelerator.device
        self.target_sample_rate = sample_rate
        self.model = model.to(self.device)
        if self.is_main:
            self.ema_model = EMA(
                model, include_online_model=False, **(ema_kwargs or {})
            )
            self.ema_model.to(self.device)
            self.ema_model = self.accelerator.prepare(
                self.ema_model
            )  # If EMA supports Accelerator
        self.optimizer = optimizer
        self.num_warmup_steps = num_warmup_steps
        self.checkpoint_path = checkpoint_path or "model.pth"
        self.max_grad_norm = max_grad_norm
        self.writer = SummaryWriter(log_dir=tensorboard_log_dir)
        self.model, self.optimizer = self.accelerator.prepare(
            self.model, self.optimizer
        )

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    def save_checkpoint(self, step, finetune=False):
        self.accelerator.wait_for_everyone()
        if self.is_main:
            checkpoint = dict(
                model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
                optimizer_state_dict=self.optimizer.state_dict(),  # Corrected
                ema_model_state_dict=self.ema_model.state_dict(),
                step=step,
            )
            self.accelerator.save(checkpoint, self.checkpoint_path)

    def load_checkpoint(self):
        if not os.path.exists(self.checkpoint_path):
            return 0
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        self.accelerator.unwrap_model(self.model).load_state_dict(
            checkpoint["model_state_dict"]
        )
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  # Corrected
        if self.is_main:
            self.ema_model.load_state_dict(
                checkpoint["ema_model_state_dict"]
            )  # If EMA supports Accelerator
        return checkpoint["step"]

    def plot_batch_comparison_spectrograms(
        self,
        original_mel: torch.Tensor,
        processed_mel: torch.Tensor,
        predicted_mel: torch.Tensor,
        title: str,
        max_samples: int = 16,  # Limit the number of samples to plot
    ):
        """
        Plots Original, Degraded, and Predicted Spectrograms side-by-side for each sample in the batch.

        Args:
            original_mel (torch.Tensor): Tensor of shape (batch_size, num_mels, time_frames)
            processed_mel (torch.Tensor): Tensor of shape (batch_size, num_mels, time_frames)
            predicted_mel (torch.Tensor): Tensor of shape (batch_size, num_mels, time_frames)
            title (str): Title for the entire figure.
            max_samples (int, optional): Maximum number of samples to plot. Defaults to 16.

        Returns:
            matplotlib.figure.Figure: The generated figure.
        """
        batch_size = original_mel.shape[0]
        num_samples = min(batch_size, max_samples)
        cols = 3  # Original, Degraded, Predicted
        rows = num_samples
        fig, axes = plt.subplots(rows, cols, figsize=(15, 3 * rows))
        fig.suptitle(title, fontsize=16)
        # If there's only one sample, axes might not be a 2D array
        if num_samples == 1:
            axes = np.expand_dims(axes, axis=0)
        for i in range(num_samples):
            # Original Spectrogram
            ax = axes[i, 0]
            mel_spec = original_mel[i].detach().cpu().numpy()
            im = ax.imshow(
                mel_spec.T,
                aspect="auto",
                origin="lower",
                interpolation="nearest",
                cmap="viridis",
                vmin=mel_spec.min(),
                vmax=mel_spec.max(),
            )
            if i == 0:
                ax.set_title("Original", fontsize=12)
            ax.set_xlabel("Time")
            ax.set_ylabel("Mel bins")
            # Degraded Spectrogram
            ax = axes[i, 1]
            mel_spec = processed_mel[i].detach().cpu().numpy()
            im = ax.imshow(
                mel_spec.T,
                aspect="auto",
                origin="lower",
                interpolation="nearest",
                cmap="viridis",
                vmin=mel_spec.min(),
                vmax=mel_spec.max(),
            )
            if i == 0:
                ax.set_title("Degraded", fontsize=12)
            ax.set_xlabel("Time")
            ax.set_ylabel("Mel bins")
            # Predicted Spectrogram
            ax = axes[i, 2]
            mel_spec = predicted_mel[i].detach().cpu().numpy()
            im = ax.imshow(
                mel_spec.T,
                aspect="auto",
                origin="lower",
                interpolation="nearest",
                cmap="viridis",
                vmin=mel_spec.min(),
                vmax=mel_spec.max(),
            )
            if i == 0:
                ax.set_title("Predicted", fontsize=12)
            ax.set_xlabel("Time")
            ax.set_ylabel("Mel bins")
        # Adjust layout and add a single colorbar
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Make room for the suptitle
        # Add a single colorbar to the right of the subplots
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
        fig.colorbar(im, cax=cbar_ax, orientation="vertical", label="Amplitude")
        return fig

    def train(self, train_dataset, epochs, batch_size, num_workers=2, save_step=100):
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
        )
        train_dataloader = self.accelerator.prepare(train_dataloader)
        start_step = self.load_checkpoint()
        global_step = start_step
        for epoch in range(epochs):
            self.model.train()
            progress_bar = tqdm(
                train_dataloader,
                desc=f"Epoch {epoch+1}/{epochs}",
                unit="step",
                disable=not self.accelerator.is_local_main_process,
            )
            epoch_loss = 0.0
            accumulated_loss = 0.0
            for batch in progress_bar:
                if batch is None:
                    global_step += 1
                    continue  # Skip the batch
                with self.accelerator.accumulate(self.model):
                    original_mel = batch["original_mel"].to(self.device)
                    processed_mel = batch["processed_mel"].to(self.device)
                    mask = batch["mask"].to(self.device)
                    original_mel = rearrange(original_mel, "b d n -> b n d")
                    processed_mel = rearrange(processed_mel, "b d n -> b n d")
                    loss, pred = self.model(original_mel, processed_mel, mask)
                    accumulated_loss += loss.item()
                    self.accelerator.backward(loss)
                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.model.parameters(), self.max_grad_norm
                        )
                    if self.accelerator.sync_gradients:
                        avg_loss = accumulated_loss / self.grad_accumulation_steps
                        if self.accelerator.is_main_process:
                            logger.info(f"step {global_step}: loss = {avg_loss:.4f}")
                            self.writer.add_scalar("loss", avg_loss, global_step)
                        accumulated_loss = 0.0
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                if self.is_main:
                    self.ema_model.update()
                global_step += 1
                epoch_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item())
                if global_step % save_step == 0:
                    self.save_checkpoint(global_step)
        self.writer.flush()
        self.writer.close()
---
def forward(self, original, processed, mask):
    times = torch.rand((original.size(0),), device=original.device)
    t = rearrange(times, "b -> b 1 1")
    # Get the maximum length in the batch
    max_length = max(original.size(1), processed.size(1))
    # Pad original and processed to max_length
    original_padded = F.pad(original, (0, 0, 0, max_length - original.size(1)))
    processed_padded = F.pad(processed, (0, 0, 0, max_length - processed.size(1)))
    # Extend mask to max_length (0 for padded elements)
    mask_extended = F.pad(mask, (0, max_length - mask.size(1)))
    x0 = torch.randn_like(processed_padded)
    x1 = original_padded
    # Compute the noisy intermediate representation `w` using time step `t`
    w = (1.0 - t) * x0 + t * x1
    # The flow target is the difference between the clean data `x1` and the noise `x0`
    flow = x1 - x0
    # Predict the vector field with transformer-based architecture
    pred = self.transformer_with_pred_head(
        w, times=times, cond=processed_padded, mask=mask_extended
    )
    # Apply mask to both pred and flow (only compute loss for non-padded elements)
    pred = pred * mask_extended.unsqueeze(-1)
    flow = flow * mask_extended.unsqueeze(-1)
    # Calculate MSE loss on non-padded elements
    loss = F.mse_loss(pred, flow, reduction="none")
    loss = loss * mask_extended.unsqueeze(-1)
    num_unmasked = mask_extended.sum() * pred.size(-1)
    # Avoid division by zero
    if num_unmasked == 0:
        loss = torch.tensor(0.0, device=loss.device)
    else:
        loss = loss.sum() / num_unmasked
    return loss, pred
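
This forward method is an excerpt from the model module (it assumes torch, einops.rearrange, and torch.nn.functional as F are imported there). For reference, it follows the usual flow-matching / rectified-flow recipe, restated here in the code's own variable names; nothing below adds to what the code already does:

w    = (1 - t) * x0 + t * x1        (noisy interpolant between Gaussian noise x0 and the clean mel x1)
flow = x1 - x0                      (target velocity field)
loss = sum(mask * (pred - flow)^2) / (mask.sum() * num_mels)    (masked MSE over non-padded frames)

In the usual setup, inference (not shown in this gist) integrates the predicted velocity field from noise toward the data, conditioned on the degraded mel.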