Created
January 14, 2023 21:11
-
-
Save zaptrem/7be75641411937e61b35dc05a928682d to your computer and use it in GitHub Desktop.
train_diff_mae
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from audio_diffusion_pytorch.unets import ( | |
UNetV0, | |
LTPlugin, | |
) # pip install -U git+https://github.com/archinetai/audio-diffusion-pytorch.git@nightly # v0.0.2 | |
from audio_diffusion_pytorch.models import DiffusionAE | |
from a_unet import TextConditioningPlugin, NumberEmbedder | |
import os | |
import math | |
import torch | |
import torch._dynamo as dynamo | |
import torch.nn.functional as F | |
from torch import Tensor, nn | |
from torch.utils.data import DataLoader | |
import wandb | |
from audio_data_pytorch.utils import fractional_random_split | |
from audio_data_pytorch import MetaDataset, AllTransform | |
from audio_encoders_pytorch import Encoder1d, ME1d, TanhBottleneck | |
from einops import rearrange | |
from tqdm import tqdm | |
import warnings | |
import torch.optim.lr_scheduler as lr_scheduler | |
from archisound import ArchiSound | |
from ema_pytorch import EMA | |
# Linear learning-rate warmup multiplier (for use as an ``lr_lambda``).
def warmup_fn(step, warmup_steps=5000):
    """Return the learning-rate multiplier for optimizer step ``step``.

    The multiplier grows linearly as ``step / warmup_steps`` while
    ``step < warmup_steps`` and is ``1.0`` from then on.

    Args:
        step: Current step count.
        warmup_steps: Number of steps over which the multiplier ramps up.
            Defaults to 5000. A value of zero or less disables the warmup
            (the multiplier is always ``1.0``) instead of dividing by zero.

    Returns:
        The multiplier to apply to the base learning rate.
    """
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps
    return 1.0
class DMAE1d(nn.Module):
    """Thin ``nn.Module`` wrapper around a pre-configured ``DiffusionAE``.

    The wrapped model is built with an ``LTPlugin``-wrapped ``UNetV0`` as the
    network type and an ``ME1d`` encoder ending in a ``TanhBottleneck``.
    ``forward``, ``encode`` and ``decode`` simply delegate to it.
    """

    def __init__(self):
        super().__init__()

        # UNetV0 wrapped by LTPlugin; the plugin arguments are fixed here.
        plugged_unet = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=64,
        )

        # Encoder handed to DiffusionAE: 2 input channels, 32 output channels.
        latent_encoder = ME1d(
            in_channels=2,
            channels=512,
            multipliers=[1, 1, 1],
            factors=[2, 2],
            num_blocks=[4, 8],
            stft_num_fft=1023,
            stft_hop_length=256,
            out_channels=32,
            bottleneck=TanhBottleneck(),
        )

        self.model = DiffusionAE(
            net_t=plugged_unet,
            dim=1,
            in_channels=2,
            channels=[256, 512, 512, 512, 1024, 1024, 1024],
            factors=[1, 2, 2, 2, 2, 2, 2],
            linear_attentions=[1, 1, 1, 1, 1, 1, 1],
            attention_features=64,
            attention_heads=8,
            items=[1, 2, 2, 2, 2, 2, 2],
            encoder=latent_encoder,
            inject_depth=4,
        )

    def forward(self, *args, **kwargs):
        """Call the wrapped model with the given arguments and return its result."""
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to the wrapped model's ``encode`` (gradients stay enabled)."""
        return self.model.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to the wrapped model's ``decode`` with gradients disabled."""
        with torch.no_grad():
            return self.model.decode(*args, **kwargs)
def log_wandb_audio_batch(
    id: str, samples: Tensor, artists: Tensor, genres: Tensor, caption: str = ""
):
    """Log every item of an audio batch to Weights & Biases in one call.

    Args:
        id: Suffix for the logged keys; item ``i`` is logged as
            ``sample_{i}_{id}``.
        samples: Batch laid out as ``(b, c, t)``; it is rearranged to
            ``(b, t, c)`` and converted to NumPy before logging.
        artists: Per-item artist ids, or ``None``.
        genres: Per-item genre ids, or ``None``.
        caption: Caption prefix. When both ``artists`` and ``genres`` are
            given, the resolved artist and genre names are appended to it.

    Relies on the module-level ``dataset`` (for ``mappings``) and
    ``sampling_rate`` globals.
    """
    clips = rearrange(samples, "b c t -> b t c").detach().cpu().numpy()
    labelled = artists is not None and genres is not None

    payload = {}
    for idx in range(samples.shape[0]):
        if labelled:
            # Id 0 is skipped — presumably a padding value; confirm against
            # the dataset's mapping file.
            artist_names = [
                dataset.mappings["artists"].inverse[artist_id]
                for artist_id in artists[idx].tolist()
                if artist_id != 0
            ]
            genre_names = [
                dataset.mappings["genres"].inverse[genre_id]
                for genre_id in genres[idx].tolist()
                if genre_id != 0
            ]
            clip_caption = f"{caption} Artist={artist_names} Genre={genre_names}"
        else:
            clip_caption = caption

        payload[f"sample_{idx}_{id}"] = wandb.Audio(
            clips[idx],
            caption=clip_caption,
            sample_rate=sampling_rate,
        )

    wandb.log(payload)
def log_val_loss(model, val_dataloader):
    """Log one validation batch and its EMA reconstruction to W&B.

    Despite the name, no validation loss is computed: only the first batch
    yielded by ``val_dataloader`` is used. The original audio is logged under
    the id ``"true"`` and the encode/decode round trip under ``"recon"``.

    Args:
        model: The training model. It is switched to eval mode for the
            duration of the call and always switched back to train mode
            afterwards, even if logging or decoding raises.
        val_dataloader: Iterable yielding ``(audio, labels)`` batches, where
            ``labels[:, 0]`` holds artist ids and ``labels[:, 1]`` genre ids.

    NOTE(review): the reconstruction is produced by the module-level
    ``ema.ema_model``, not by ``model``; ``model.eval()``/``model.train()``
    only toggle the online model. The device is hard-coded to ``"cuda"``.
    """
    model.eval()
    torch.cuda.empty_cache()

    try:
        for val_data, val_labels in val_dataloader:
            val_data, val_labels = val_data.to("cuda"), val_labels.to("cuda")
            artists = val_labels[:, 0]
            genres = val_labels[:, 1]

            torch.cuda.empty_cache()

            # Ground-truth audio for side-by-side comparison.
            log_wandb_audio_batch(
                id="true",
                samples=val_data,
                artists=artists,
                genres=genres,
            )

            # Encode and re-synthesise with the EMA weights in fp16 autocast.
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                with torch.no_grad():
                    latent = ema.ema_model.encode(val_data)
                    recon = ema.ema_model.decode(latent, num_steps=25)

            log_wandb_audio_batch(
                id="recon", samples=recon, artists=artists, genres=genres
            )

            # Only the first batch is logged.
            break
    finally:
        # Hand the model back in training mode no matter what happened above.
        model.train()
# Suppress the "clipped samples in output" warning (regex-matched message).
warnings.filterwarnings("ignore", ".*clipped samples in output*")

# Report the CUDA / cuDNN environment.
_cuda_available = torch.cuda.is_available()
print("cuda available: ", _cuda_available)
dynamo.config.verbose = True
if _cuda_available:
    print("cudnn version is: ", torch.backends.cudnn.version())

# Trade float32 matmul precision for speed.
# NOTE(review): the later `matmul.allow_tf32 = True` may override the
# "medium" setting chosen here — confirm against the PyTorch version in use.
torch.set_float32_matmul_precision("medium")

# Let cuDNN auto-tune convolution algorithms; print the previous value first.
print("cudnn benchmark was ", torch.backends.cudnn.benchmark)
torch.backends.cudnn.benchmark = True

# Allow TF32 in cuBLAS matmuls and cuDNN convolutions.
print("matmul allow_tf32 was ", torch.backends.cuda.matmul.allow_tf32)
torch.backends.cuda.matmul.allow_tf32 = True
print("cudnn allow_tf32 was ", torch.backends.cudnn.allow_tf32)
torch.backends.cudnn.allow_tf32 = True
# How often (in steps) reconstructions are logged and checkpoints are written.
VALIDATE_EVERY_STEPS = 1100
CHECKPOINT_EVERY_STEPS = 2600

# Outer loop: build the complete training setup and train forever. If a CUDA
# out-of-memory error escapes, the setup is rebuilt from scratch (resuming
# from the latest checkpoint when one exists); any other error is re-raised.
while True:
    try:
        torch.cuda.empty_cache()

        # ----- Data -----------------------------------------------------
        # `sampling_rate`, `length` and `dataset` are module-level names;
        # `log_wandb_audio_batch` reads `sampling_rate` and `dataset`.
        sampling_rate = 44100
        length = 131072 * 4
        channels = 2
        dataset = MetaDataset(
            path="/mnt/wsl/PhysicalDrive2/Playlist",
            metadata_mapping_path="mappings.json",
            recursive=True,
            sample_rate=sampling_rate,
            random_crop_size=length,
            transforms=AllTransform(stereo=True),
        )
        split = [1.0 - 0.01, 0.01]
        train_dataset, val_dataset = fractional_random_split(dataset, split)

        # ----- Model, EMA, optimizer -------------------------------------
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model = DMAE1d()
        model.train()
        model = model.to(device)
        # model = torch.compile(model)

        # `ema` is also read as a global by `log_val_loss`.
        ema = EMA(
            model,
            beta=0.995,
            power=0.7,
        )

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=1e-4,
            betas=(0.95, 0.999),
            eps=1e-6,
            weight_decay=1e-3,
        )

        # Mixed-precision gradient scaler.
        scaler = torch.cuda.amp.GradScaler()

        # scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn)

        # ----- Checkpoint resume -----------------------------------------
        checkpoint_dir = "checkpoints"
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(checkpoint_dir, exist_ok=True)

        latest_checkpoint_path = os.path.join(
            checkpoint_dir, "maefinetune_1_latest_checkpoint.pth"
        )

        if os.path.exists(latest_checkpoint_path):
            checkpoint = torch.load(latest_checkpoint_path)

            model.load_state_dict(checkpoint["model_state_dict"], strict=False)
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            scaler.load_state_dict(checkpoint["scaler_state_dict"])
            ema.load_state_dict(checkpoint["ema_state_dict"])

            # Resume at the epoch after the one stored in the checkpoint.
            epoch = checkpoint["epoch"] + 1
            step = checkpoint["step"]
        else:
            epoch = 0
            step = 0

        wandb.init(project="maefinetune-3", resume=False)

        # ----- Data loaders ----------------------------------------------
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=12,
            shuffle=True,
            pin_memory=True,
            num_workers=14,
            prefetch_factor=2,
        )
        val_dataloader = DataLoader(
            val_dataset, batch_size=3, shuffle=True, pin_memory=True, num_workers=4
        )

        # Gradient accumulation steps (1 = update on every batch).
        accumulation_steps = 1

        # Manual next-batch prefetching is disabled. NOTE: the `async_load`
        # branches below reference `next_batch`, which only exists if that
        # prefetching code is restored.
        async_load = False

        save_every_epochs = 1

        running_loss = 0.0

        # "Weird hack that saves VRAM at ends of epochs": the batch loop is
        # left one batch early. Clamp to >= 1 so a one-batch loader does not
        # cause a modulo-by-zero.
        epoch_break_interval = max(len(train_dataloader) - 1, 1)

        # ----- Training loop (one iteration per epoch) --------------------
        while True:
            print("epoch: ", epoch)
            torch.cuda.empty_cache()
            model.train()

            with tqdm(
                total=len(train_dataloader),
                desc=f"Epoch {epoch}",
                unit="it",
                leave=False,
            ) as pbar:
                for data, labels in train_dataloader:
                    optimizer.zero_grad(set_to_none=True)

                    data, labels = data.to(device, non_blocking=True), labels.to(
                        device, non_blocking=True
                    )

                    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                        loss = model(data, sdstft_loss=False)

                    # If we have reached an accumulation step, update weights.
                    if (step + 1) % accumulation_steps == 0:
                        numerical_loss = loss.item()

                        if step > 10000 and (
                            math.isnan(numerical_loss) or numerical_loss > 0.5
                        ):
                            # Late in training, a NaN or very large loss is
                            # treated as a bad batch: no backward pass.
                            print(
                                "Loss went NaN/crazy! Skipping example. Loss:",
                                numerical_loss,
                            )
                            # Drop the references (rather than `del`) so the
                            # names stay bound for the validation branch below.
                            data = labels = loss = None
                            if not math.isnan(numerical_loss):
                                running_loss += numerical_loss
                            torch.cuda.empty_cache()
                        else:
                            # Scaled backward pass, optimizer step, scale update.
                            scaler.scale(loss).backward()
                            scaler.step(optimizer)
                            scaler.update()

                            # Track the new weights in the EMA model.
                            ema.update()

                            pbar.set_postfix_str(
                                f"loss: {numerical_loss:.4f}, step: {step}"
                            )
                            pbar.update()

                            running_loss += numerical_loss

                            wandb.log(
                                {"loss": numerical_loss, "step": step, "epoch": epoch}
                            )

                            loss = None

                    step += 1

                    # End the epoch one batch early (see hack note above).
                    if step % epoch_break_interval == 0:
                        break

                    # Log a validation batch and its reconstruction.
                    if step % VALIDATE_EVERY_STEPS == 0:
                        print("validation")
                        # Release the training batch before validation runs.
                        data = labels = None
                        if async_load:
                            next_batch[0].cpu()
                        log_val_loss(model, val_dataloader)
                        if async_load:
                            next_batch[0].to(device, non_blocking=True)

                    # Save checkpoint.
                    if step % CHECKPOINT_EVERY_STEPS == 0:
                        print("attempting checkpoint")
                        # Refuse to checkpoint right after a NaN loss.
                        assert not math.isnan(numerical_loss)
                        wandb.log(
                            {
                                "running_loss": running_loss / CHECKPOINT_EVERY_STEPS,
                                "step": step,
                                "epoch": epoch,
                            }
                        )
                        running_loss = 0.0

                        # Write to a temporary file first so a crash while
                        # saving cannot corrupt the latest checkpoint.
                        temp_checkpoint_path = os.path.join(
                            checkpoint_dir, "maefinetune-1_temp_checkpoint.pth"
                        )
                        torch.save(
                            {
                                "epoch": epoch,
                                "step": step,
                                "model_state_dict": model.state_dict(),
                                "optimizer_state_dict": optimizer.state_dict(),
                                "scaler_state_dict": scaler.state_dict(),
                                "ema_state_dict": ema.state_dict(),
                            },
                            temp_checkpoint_path,
                        )

                        # Sanity check: a new checkpoint must be exactly as
                        # large as the one it replaces.
                        if os.path.exists(latest_checkpoint_path):
                            temp_checkpoint_size = os.stat(temp_checkpoint_path).st_size
                            latest_checkpoint_size = os.stat(
                                latest_checkpoint_path
                            ).st_size
                            assert temp_checkpoint_size == latest_checkpoint_size

                        # Atomically swap the new checkpoint into place; this
                        # overwrites any previous one, so there is never a
                        # moment without a checkpoint on disk.
                        os.replace(temp_checkpoint_path, latest_checkpoint_path)

            epoch += 1
    except Exception as e:
        print("error type:", type(e))
        # Only recover from out-of-memory errors; re-raise everything else
        # with its original traceback.
        if not isinstance(e, torch.cuda.OutOfMemoryError):
            raise
        torch.cuda.empty_cache()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment