Skip to content

Instantly share code, notes, and snippets.

@ovshake
Created September 4, 2022 17:07
Show Gist options
  • Save ovshake/362591dcb38bd7471df2de55f715cf6e to your computer and use it in GitHub Desktop.
Save ovshake/362591dcb38bd7471df2de55f715cf6e to your computer and use it in GitHub Desktop.
from diffusers import UNet2DModel, UNet2DConditionModel
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import clip
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from dataclasses import dataclass
from accelerate import Accelerator
import os
from tqdm import tqdm
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
from glob import glob
import numpy as np
class StableDataset(Dataset):
    """Pair cloth images with precomputed CLIP text embeddings.

    Layout expected under ``root_dir``:

    * ``cloth/<name>.jpg`` -- one image per sample;
    * ``clip_txt_embeddings/<name>.np.gz`` -- the matching embedding, stored as
      text readable by ``np.loadtxt`` (which decompresses ``.gz`` files
      transparently).

    Each item is a dict ``{"images": <transformed image>, "np_embedding":
    <float32 tensor with a leading dim of size 1>}``.
    """

    def __init__(self, root_dir, transforms):
        """Index every ``*.jpg`` under ``root_dir/cloth``.

        Args:
            root_dir: dataset root containing ``cloth`` and
                ``clip_txt_embeddings`` sub-directories.
            transforms: callable applied to each RGB ``PIL.Image``.
        """
        self.img_dir = os.path.join(root_dir, "cloth")
        self.np_embedding_dir = os.path.join(root_dir, "clip_txt_embeddings")
        self.transforms = transforms
        # Only *.jpg files are samples: stray files in the directory
        # (e.g. .DS_Store, notes) must not become dataset entries. Sorting
        # gives a deterministic index -> sample mapping, which listdir/glob
        # do not guarantee.
        self.paths = sorted(glob(os.path.join(self.img_dir, "*.jpg")))
        self.names = [os.path.splitext(os.path.basename(p))[0] for p in self.paths]

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        """Load, convert and transform sample ``index`` plus its embedding."""
        name = self.names[index]
        img_path = os.path.join(self.img_dir, f"{name}.jpg")
        np_embedding_path = os.path.join(self.np_embedding_dir, f"{name}.np.gz")
        # Close the file handle promptly. convert("RGB") forces a full load,
        # so the result stays valid after the file is closed. It also
        # guarantees 3 channels (grayscale/CMYK JPEGs would otherwise not
        # match the 3-channel UNet input).
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        img = self.transforms(img)
        np_embedding = np.loadtxt(np_embedding_path, dtype=np.float32)
        # NOTE(review): presumably a 1-D vector of cross_attention_dim values.
        # unsqueeze(0) makes it (1, D), so a batch becomes (B, 1, D)
        # encoder_hidden_states -- confirm against the embedding files.
        np_embedding = torch.from_numpy(np_embedding).unsqueeze(0)
        return {"images": img, "np_embedding": np_embedding}
# Device string for the training script.
# NOTE(review): nothing visible in this file reads `device`; the script calls
# `.cuda()` directly instead -- consider wiring this through or removing it.
device: str = "cuda"
@dataclass(eq=False)
class TrainingConfig:
    """Hyper-parameters and output settings for the diffusion training run.

    Declared as a dataclass so any value can be overridden at construction,
    e.g. ``TrainingConfig(num_epochs=10)``. ``TrainingConfig()`` keeps the same
    defaults, and they remain readable as class attributes.
    ``eq=False`` keeps the identity-based equality and hashability that the
    original plain class had.
    """

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 1
    eval_batch_size: int = 16  # how many images to sample during evaluation
    num_epochs: int = 50
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 10  # sample images every N epochs (and after the last)
    save_model_epochs: int = 30  # save the model every N epochs (and after the last)
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "ddpm-fashion-128-v0"  # the model name locally and on the HF Hub
    push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0
config = TrainingConfig()

# Square resize to the training resolution, random mirroring, then map pixel
# values from [0, 1] (ToTensor) to [-1, 1] (Normalize: (x - 0.5) / 0.5).
preprocess = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Text-conditioned UNet predicting 3-channel noise for 3-channel images.
# cross_attention_dim must equal the width of the precomputed text embeddings
# (768 -- presumably CLIP text features; confirm against the embedding files).
model = UNet2DConditionModel(
    sample_size=config.image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    cross_attention_dim=768,
).cuda()

stable_dataset = StableDataset("/data/dataset/VITON-hD/train/", transforms=preprocess)
train_dataloader = DataLoader(
    dataset=stable_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# NOTE(review): `tensor_format` is a keyword of early diffusers releases;
# newer versions dropped it -- pin diffusers accordingly.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config.learning_rate)

# Cosine decay after a linear warm-up, spanning every batch of every epoch.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=config.num_epochs * len(train_dataloader),
)
def make_grid(images, rows, cols):
    """Tile equally sized PIL images, row-major, into one RGB image.

    Args:
        images: non-empty sequence of PIL images; every image is assumed to
            have the size of the first one.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new ``PIL.Image`` of size ``(cols * width, rows * height)``.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("make_grid() needs at least one image")
    width, height = images[0].size
    grid = Image.new("RGB", size=(cols * width, rows * height))
    for i, image in enumerate(images):
        # Row-major placement: column = i % cols, row = i // cols.
        grid.paste(image, box=(i % cols * width, i // cols * height))
    return grid


def _grid_shape(num_images):
    """Return ``(rows, cols)`` of a near-square grid holding ``num_images``.

    Examples: 16 -> (4, 4), 12 -> (3, 4), 5 -> (2, 3), 0 -> (0, 0).
    """
    import math

    if num_images <= 0:
        return 0, 0
    cols = math.ceil(math.sqrt(num_images))
    rows = -(-num_images // cols)  # ceiling division
    return rows, cols


def evaluate(config, epoch, pipeline):
    """Sample a batch from ``pipeline`` and save it as one PNG grid.

    The grid is written to ``<config.output_dir>/samples/<epoch:04d>.png``.

    Args:
        config: reads ``eval_batch_size``, ``seed`` and ``output_dir``.
        epoch: epoch index used in the output file name.
        pipeline: a diffusers pipeline whose call result has a ``"sample"``
            entry holding PIL images (the default output type). It is called
            with only ``batch_size`` and ``generator``; no conditioning is
            passed.
    """
    # Run the backward (denoising) diffusion process from random noise. The
    # fixed seed makes samples comparable across epochs. torch.manual_seed
    # also reseeds the global RNG.
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.manual_seed(config.seed),
    )["sample"]
    # Size the grid from the actual sample count instead of a fixed 4x4, so
    # eval_batch_size values other than 16 are tiled completely.
    rows, cols = _grid_shape(len(images))
    image_grid = make_grid(images, rows=rows, cols=cols)
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(os.path.join(test_dir, f"{epoch:04d}.png"))
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    """Train ``model`` to predict the noise added by ``noise_scheduler``.

    Each step samples Gaussian noise and a random timestep per image, noises
    the clean images with ``noise_scheduler.add_noise``, and runs the UNet on
    the noisy images with the batch's text embeddings as
    ``encoder_hidden_states``. It then minimises the MSE between predicted
    and true noise. After each epoch the main process may save (or push) a
    ``DDPMPipeline`` and sample images via ``evaluate``.

    Args:
        config: a ``TrainingConfig``-like object (mixed_precision,
            gradient_accumulation_steps, output_dir, num_epochs,
            save_image_epochs, save_model_epochs, push_to_hub, ...).
        model: the conditional UNet to train.
        noise_scheduler: DDPM scheduler used for the forward (noising) process.
        optimizer: optimizer over ``model``'s parameters.
        train_dataloader: yields dicts with ``"images"`` and ``"np_embedding"``.
        lr_scheduler: learning-rate schedule, stepped after each optimizer step.
    """
    # DDPMPipeline is used below but was never imported at module level.
    # Import it here so the end-of-epoch step does not raise NameError.
    from diffusers import DDPMPipeline

    if config.push_to_hub:
        # NOTE(review): hub helpers as exposed by diffusers releases of this
        # era (`diffusers.hub_utils`). They were never imported in this file.
        # Confirm the module path against the installed diffusers version.
        from diffusers.hub_utils import init_git_repo, push_to_hub

    # Initialize accelerator and tensorboard logging.
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        logging_dir=os.path.join(config.output_dir, "logs"),
    )
    repo = None  # only set (and only used) when pushing to the Hub
    if accelerator.is_main_process:
        if config.push_to_hub:
            repo = init_git_repo(config, at_init=True)
        accelerator.init_trackers("train_example")

    # prepare() returns the wrapped objects in the order they were passed.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    # The weights deliberately stay float32 (no model.half()). With
    # mixed_precision="fp16", Accelerate applies autocast and loss scaling to
    # the prepared model, and its gradient scaler cannot unscale fp16
    # gradients.

    global_step = 0
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for batch in train_dataloader:
            # The prepared dataloader already places batches on
            # accelerator.device; .to() is then a no-op kept for clarity.
            clean_images = batch["images"].to(accelerator.device)
            embeddings = batch["np_embedding"].to(accelerator.device)

            # Forward diffusion: noise each image at a random timestep.
            noise = torch.randn_like(clean_images)
            bs = clean_images.shape[0]
            timesteps = torch.randint(
                0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
            ).long()
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            # A single context manager. The former
            # `accumulate(model) and autocast(...) and no_grad()` evaluated to
            # its last operand, so only no_grad() was entered: no autograd
            # graph was built and accelerator.backward() could not run.
            with accelerator.accumulate(model):
                # Predict the noise residual, conditioned on the text embeddings.
                noise_pred = model(noisy_images, encoder_hidden_states=embeddings, timestep=timesteps)["sample"]
                # Compute the loss in float32 even when the forward ran in fp16.
                loss = F.mse_loss(noise_pred.float(), noise.float())
                accelerator.backward(loss)
                # Clip only when gradients are synchronised, i.e. on the step
                # that completes an accumulation window.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1
        progress_bar.close()

        # After each epoch, optionally save the model and sample demo images.
        if accelerator.is_main_process:
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
            is_last_epoch = epoch == config.num_epochs - 1
            # Save before sampling, so a failure while sampling cannot lose a
            # checkpoint that was due this epoch.
            if (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch:
                if config.push_to_hub:
                    push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
                else:
                    pipeline.save_pretrained(config.output_dir)
            if (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch:
                # NOTE(review): DDPMPipeline is an unconditional sampler and
                # presumably calls the UNet without encoder_hidden_states,
                # which UNet2DConditionModel requires. Confirm evaluate() works
                # with this model before relying on it.
                evaluate(config, epoch, pipeline)

    # Flush and close the tensorboard tracker started by init_trackers().
    accelerator.end_training()
if __name__ == "__main__":
    # Script entry point: train with the module-level objects constructed
    # above (config, model, scheduler, optimizer, dataloader, LR schedule).
    train_loop(
        config=config,
        model=model,
        noise_scheduler=noise_scheduler,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        lr_scheduler=lr_scheduler,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment