train_diff_mae
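# Training script for a 1D diffusion autoencoder (DMAE) on stereo 44.1 kHz
# audio, with fp16 mixed-precision training, EMA weights, wandb logging, and
# periodic checkpointing. It targets the @nightly branch of
# audio-diffusion-pytorch pinned in the import comment below; that API may
# have changed since this was written.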
import os
import math
import warnings

import torch
import torch._dynamo as dynamo
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torch import Tensor, nn
from torch.utils.data import DataLoader

import wandb
from audio_diffusion_pytorch.unets import (
    UNetV0,
    LTPlugin,
)  # pip install -U git+https://github.com/archinetai/audio-diffusion-pytorch.git@nightly # v0.0.2
from audio_diffusion_pytorch.models import DiffusionAE
from a_unet import TextConditioningPlugin, NumberEmbedder
from audio_data_pytorch.utils import fractional_random_split
from audio_data_pytorch import MetaDataset, AllTransform
from audio_encoders_pytorch import Encoder1d, ME1d, TanhBottleneck
from archisound import ArchiSound
from einops import rearrange
from ema_pytorch import EMA
from tqdm import tqdm


# Define the warmup function (for the commented-out LambdaLR scheduler below):
# linear warmup over the first 5000 steps, then a constant factor of 1
def warmup_fn(step):
    if step < 5000:
        return step / 5000
    return 1
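

# DMAE1d wraps DiffusionAE: a spectrogram-domain encoder (ME1d) compresses the
# stereo waveform into a 32-channel latent, which is injected into the
# diffusion UNet at depth 4 (inject_depth=4). LTPlugin adds a learned
# transform around the UNet with window/stride 64, per the library's plugin
# API.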
class DMAE1d(nn.Module):
    def __init__(self):
        super().__init__()
        UNet = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=64,
        )
        self.model = DiffusionAE(
            net_t=UNet,
            dim=1,
            in_channels=2,
            channels=[256, 512, 512, 512, 1024, 1024, 1024],
            factors=[1, 2, 2, 2, 2, 2, 2],
            linear_attentions=[1, 1, 1, 1, 1, 1, 1],
            attention_features=64,
            attention_heads=8,
            items=[1, 2, 2, 2, 2, 2, 2],
            encoder=ME1d(
                in_channels=2,
                channels=512,
                multipliers=[1, 1, 1],
                factors=[2, 2],
                num_blocks=[4, 8],
                stft_num_fft=1023,
                stft_hop_length=256,
                out_channels=32,
                bottleneck=TanhBottleneck(),
            ),
            inject_depth=4,
        )

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        return self.model.encode(*args, **kwargs)

    @torch.no_grad()
    def decode(self, *args, **kwargs):
        return self.model.decode(*args, **kwargs)
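

# Note: relies on the module-level `dataset` and `sampling_rate` defined in
# the training setup below to resolve artist/genre ids and set the audio
# sample rate.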
def log_wandb_audio_batch(
    id: str, samples: Tensor, artists: Tensor, genres: Tensor, caption: str = ""
):
    num_items = samples.shape[0]
    samples = rearrange(samples, "b c t -> b t c").detach().cpu().numpy()
    wandb.log(
        {
            f"sample_{idx}_{id}": wandb.Audio(
                samples[idx],
                caption=(
                    f"{caption}"
                    f" Artist={[dataset.mappings['artists'].inverse[artist_id] for artist_id in artists[idx].tolist() if artist_id != 0]}"
                    f" Genre={[dataset.mappings['genres'].inverse[genre_id] for genre_id in genres[idx].tolist() if genre_id != 0]}"
                    if artists is not None and genres is not None
                    else caption
                ),
                sample_rate=sampling_rate,
            )
            for idx in range(num_items)
        }
    )
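

# Logs ground-truth audio and reconstructions for a single validation batch
# (note the early `break`); the numeric validation-loss computation is
# currently commented out. Reconstructions come from the global `ema` model,
# not the `model` argument.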
def log_val_loss(model, val_dataloader):
    # Set the model to evaluation mode
    model.eval()
    torch.cuda.empty_cache()
    # Initialize the running loss for the epoch
    val_running_loss = 0.0
    # Iterate over the validation data
    for val_data, val_labels in val_dataloader:
        # Move the data and labels to the device
        val_data, val_labels = val_data.to("cuda"), val_labels.to("cuda")
        # print("val_labels", val_labels)
        artists = val_labels[:, 0]
        genres = val_labels[:, 1]
        # Forward pass
        # val_loss = model(val_data)
        torch.cuda.empty_cache()
        # Get start diffusion noise
        # noise = torch.randn(
        #     (3, 2, length), device="cuda"
        # )
        # samples = model.sample(
        #     noise=noise,
        #     num_steps=25,
        # )
        log_wandb_audio_batch(
            id="true",
            samples=val_data,
            artists=artists,
            genres=genres,
        )
        with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
            with torch.no_grad():
                x = ema.ema_model.encode(val_data)
                x = ema.ema_model.decode(x, num_steps=25)
        log_wandb_audio_batch(id="recon", samples=x, artists=artists, genres=genres)
        # Add the loss for the batch to the running loss for the epoch
        # val_running_loss += val_loss.item()
        break
    # Log the loss from the validation set
    # wandb.log({"val_loss": val_running_loss / len(val_dataloader)})
    # Set the model back to training mode
    model.train()


warnings.filterwarnings("ignore", ".*clipped samples in output*")
print("cuda available: ", torch.cuda.is_available())
dynamo.config.verbose = True
if torch.cuda.is_available():
    print("cudnn version is: ", torch.backends.cudnn.version())
    torch.set_float32_matmul_precision("medium")
    print("cudnn benchmark was ", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.benchmark = True
    print("matmul allow_tf32 was ", torch.backends.cuda.matmul.allow_tf32)
    torch.backends.cuda.matmul.allow_tf32 = True
    print("cudnn allow_tf32 was ", torch.backends.cudnn.allow_tf32)
    torch.backends.cudnn.allow_tf32 = True
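# The whole run sits in a try/except so that CUDA out-of-memory errors empty
# the cache and restart training from the latest checkpoint instead of killing
# the process (see the except handler at the bottom of the file).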
while True:
    try:
        torch.cuda.empty_cache()
        sampling_rate = 44100
        length = 131072 * 4  # 131072 * 4 # - 128*3
        channels = 2
        dataset = MetaDataset(
            path="/mnt/wsl/PhysicalDrive2/Playlist",
            metadata_mapping_path="mappings.json",
            recursive=True,
            sample_rate=sampling_rate,
            random_crop_size=length,
            transforms=AllTransform(stereo=True),
        )
        split = [1.0 - 0.01, 0.01]
        train_dataset, val_dataset = fractional_random_split(dataset, split)
        # Set the device to use for training
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load 1st stage AE
        model = DMAE1d()
        model.train()
        # Initialize the model and move it to the device
        model = model.to(device)
        # Optimize model using Torch Dynamo
        # model = torch.compile(model)
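        # Exponential moving average of the weights (ema-pytorch); the EMA
        # copy is what gets sampled during validation. beta/power control the
        # decay schedule.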
        ema = EMA(
            model,
            beta=0.995,
            power=0.7,
        )
        # Set the optimizer and loss function
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=1e-4,
            betas=(0.95, 0.999),
            eps=1e-6,
            weight_decay=1e-3,
        )
        # Create the mixed-precision gradient scaler.
        scaler = torch.cuda.amp.GradScaler()
        # Create the learning rate scheduler
        # scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn)
        # Set the directory to save checkpoints
        checkpoint_dir = "checkpoints"
        # Create the checkpoint directory if it doesn't exist
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        # Path to the latest checkpoint
        latest_checkpoint_path = os.path.join(
            checkpoint_dir, "maefinetune_1_latest_checkpoint.pth"
        )
        if os.path.exists(latest_checkpoint_path):
            # Load the latest checkpoint
            checkpoint = torch.load(latest_checkpoint_path)
            # Load the model and optimizer state from the checkpoint
            model.load_state_dict(checkpoint["model_state_dict"], strict=False)
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            scaler.load_state_dict(checkpoint["scaler_state_dict"])
            ema.load_state_dict(checkpoint["ema_state_dict"])
            # Set the starting epoch to the epoch from the checkpoint
            epoch = checkpoint["epoch"] + 1
            step = checkpoint["step"]
        else:
            # Set the starting epoch to 0 if there is no checkpoint
            epoch = 0
            step = 0
        # Initialize WandB
        wandb.init(project="maefinetune-3", resume=False)
        # wandb.init(project="sanity-check")
        # Create PyTorch data loaders for the train/val sets
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=12,
            shuffle=True,
            pin_memory=True,
            num_workers=14,
            prefetch_factor=2,
        )
        val_dataloader = DataLoader(
            val_dataset, batch_size=3, shuffle=True, pin_memory=True, num_workers=4
        )
        # Gradient accumulation steps
        accumulation_steps = 1
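        # Note: with accumulation_steps == 1 the accumulation branch in the
        # training loop always fires; values > 1 would currently skip the
        # backward pass entirely on non-accumulation steps (gradients are
        # zeroed every iteration), so this is effectively a placeholder.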
        # Load the next batch while processing the current one
        async_load = False
        # Save every X epochs
        save_every_epochs = 1
        # Initialize the running loss
        running_loss = 0.0
        # Training loop
        while True:
            print("epoch: ", epoch)
            torch.cuda.empty_cache()
            # Set the model to training mode
            model.train()
            # Loss accumulation
            # loss = None
            # data_iter = iter(train_dataloader)
            # next_batch = data_iter.__next__()  # start loading the first batch
            # next_batch = [_.cuda(non_blocking=True) for _ in next_batch]  # with pin_memory=True and non_blocking=True, this copies data to the GPU without blocking
            # Wrap the for loop in a tqdm progress bar
            with tqdm(
                total=len(train_dataloader),
                desc=f"Epoch {epoch}",
                unit="it",
                leave=False,
            ) as pbar:
                for data, labels in train_dataloader:
                    # for i in range(len(train_dataloader) - 1):
                    #     data, labels = next_batch
                    #     if i + 2 != len(train_dataloader) - 1:
                    #         # start copying data of next batch
                    #         next_batch = data_iter.__next__()
                    #         next_batch = [_.cuda(non_blocking=True) for _ in next_batch]
                    # testing
                    # data = torch.zeros((3, 32, 4096), device=device)
                    # labels = torch.zeros((3, 2, 4), device=device, dtype=int)
                    # Zero out the gradients from the previous iteration
                    optimizer.zero_grad(set_to_none=True)
                    data, labels = data.to(device, non_blocking=True), labels.to(
                        device, non_blocking=True
                    )
                    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                        # Forward pass
                        loss = model(data, sdstft_loss=False)
                    # If we have reached an accumulation step, update the weights
                    if (step + 1) % accumulation_steps == 0:
                        # print("loss: ", loss)
                        numerical_loss = loss.item()
                        if step > 10000 and (
                            math.isnan(numerical_loss) or numerical_loss > 0.5
                        ):
                            print(
                                "Loss went NaN/crazy! Skipping example. Loss:",
                                numerical_loss,
                            )
                            del data, labels, loss
                            if not math.isnan(numerical_loss):
                                running_loss += numerical_loss
                            torch.cuda.empty_cache()
                        else:
                            # Backward pass and optimization step
                            # Scales the loss, and calls backward()
                            # to create scaled gradients
                            loss = loss  # / accumulation_steps
                            scaler.scale(loss).backward()
                            scaler.step(optimizer)
                            # Updates the scale for next iteration
                            scaler.update()
                            # Updates the EMA model
                            ema.update()
                            # Add the loss for the batch to the running loss for the epoch
                            pbar.set_postfix_str(
                                f"loss: {numerical_loss:.4f}, step: {step}"
                            )
                            pbar.update()
                            running_loss += numerical_loss
                            # Log the loss to Weights & Biases
                            wandb.log(
                                {"loss": numerical_loss, "step": step, "epoch": epoch}
                            )
                            # Reset loss accumulation
                            del loss
                    # Increment the step count
                    step += 1
                    # Weird hack that saves VRAM at ends of epochs.
                    if step % (len(train_dataloader) - 1) == 0:
                        break
                    # Log audio from the validation set every 1100 steps
                    if step % 1100 == 0:
                        print("validation")
                        del data, labels
                        if async_load:
                            next_batch[0].cpu()
                        log_val_loss(model, val_dataloader)
                        if async_load:
                            next_batch[0].to(device, non_blocking=True)
                    # Save a checkpoint every 2600 steps
                    if step % 2600 == 0:
                        print("attempting checkpoint")
                        assert not math.isnan(numerical_loss)
                        wandb.log(
                            {
                                "running_loss": running_loss / 2600,
                                "step": step,
                                "epoch": epoch,
                            }
                        )
                        running_loss = 0
                        # Save a checkpoint of the model to a temporary file,
                        # then swap it in, so a crash mid-write can't corrupt
                        # the latest checkpoint
                        temp_checkpoint_path = os.path.join(
                            checkpoint_dir, "maefinetune-1_temp_checkpoint.pth"
                        )
                        torch.save(
                            {
                                "epoch": epoch,
                                "step": step,
                                "model_state_dict": model.state_dict(),
                                "optimizer_state_dict": optimizer.state_dict(),
                                "scaler_state_dict": scaler.state_dict(),
                                "ema_state_dict": ema.state_dict(),
                            },
                            temp_checkpoint_path,
                        )
                        # Check if the latest checkpoint file exists
                        if os.path.exists(latest_checkpoint_path):
                            # Sanity check: the checkpoint size should be
                            # identical from save to save; a mismatch suggests
                            # a bad write
                            temp_checkpoint_size = os.stat(temp_checkpoint_path).st_size
                            latest_checkpoint_size = os.stat(latest_checkpoint_path).st_size
                            assert temp_checkpoint_size == latest_checkpoint_size
                            try:
                                # Delete the old checkpoint
                                os.remove(latest_checkpoint_path)
                            except OSError:
                                print("No checkpoint file to delete")
                        # Rename the temporary checkpoint
                        os.rename(temp_checkpoint_path, latest_checkpoint_path)
            # Print the epoch loss
            # print(f"Epoch {epoch} loss: {running_loss / len(train_dataloader)}")
            # model.cpu()
            # torch.cuda.empty_cache()
            epoch += 1
    except Exception as e:
        print("error type:", type(e))
        # Only recover from out-of-memory errors
        if type(e) != torch.cuda.OutOfMemoryError:
            raise e
        torch.cuda.empty_cache()
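# Usage (assumed invocation, matching the paths hard-coded above):
#   python train_diff_mae.py
# Expects audio under /mnt/wsl/PhysicalDrive2/Playlist plus a mappings.json in
# the working directory, and writes checkpoints to ./checkpoints/.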