# Gist by @cloneofsimo, created July 12, 2023.
# Bootstrapped from Huggingface diffuser's code.
import gc
import math
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from typing import Optional
import os
import warnings


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")
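
# For SDXL checkpoints, the first text encoder ("text_encoder") typically resolves
# to CLIPTextModel and the second ("text_encoder_2") to CLIPTextModelWithProjection,
# whose pooled output supplies the text_embeds added-conditioning used further below.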


class Text2ImageDataset(Dataset):
    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        size=1024,
        center_crop=False,
        instance_prompt_hidden_states=None,
        instance_unet_added_conditions=None,
    ):
        self.size = size
        self.center_crop = center_crop
        self.instance_prompt_hidden_states = instance_prompt_hidden_states
        self.instance_unet_added_conditions = instance_unet_added_conditions

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                ),
                transforms.CenterCrop(size)
                if center_crop
                else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        instance_image = exif_transpose(instance_image)

        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.instance_prompt_hidden_states
        example["instance_added_cond_kwargs"] = self.instance_unet_added_conditions
        return example
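
# Note: prompt embeddings are precomputed once for the single instance prompt and
# stored on the dataset, so the text encoders never need to run inside the training loop.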


def collate_fn(examples):
    # The dataset above never sets "instance_attention_mask", so the attention-mask
    # branches below are unused in this script.
    has_attention_mask = "instance_attention_mask" in examples[0]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]
    add_text_embeds = [
        example["instance_added_cond_kwargs"]["text_embeds"] for example in examples
    ]
    add_time_ids = [
        example["instance_added_cond_kwargs"]["time_ids"] for example in examples
    ]

    if has_attention_mask:
        attention_mask = [example["instance_attention_mask"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)
    add_text_embeds = torch.cat(add_text_embeds, dim=0)
    add_time_ids = torch.cat(add_time_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
        "unet_added_conditions": {
            "text_embeds": add_text_embeds,
            "time_ids": add_time_ids,
        },
    }

    if has_attention_mask:
        batch["attention_mask"] = attention_mask

    return batch


def encode_prompt(text_encoders, tokenizers, prompt):
    prompt_embeds_list = []

    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        prompt_embeds = text_encoder(
            text_input_ids.to(text_encoder.device),
            output_hidden_states=True,
        )

        # Only the pooled output of the final text encoder is kept
        # (it is overwritten on every iteration of this loop).
        pooled_prompt_embeds = prompt_embeds[0]
        prompt_embeds = prompt_embeds.hidden_states[-2]
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
        prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
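
# For the SDXL base encoders this concatenates the penultimate hidden states of both
# text encoders along the feature dimension (roughly 768 + 1280 = 2048 features per
# token); only the second encoder's pooled embedding is returned.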


def main(
    pretrained_model_name_or_path: Optional[
        str
    ] = "stabilityai/stable-diffusion-xl-base-0.9",
    revision: Optional[str] = None,
    instance_data_dir: Optional[str] = "dataset",
    instance_prompt: Optional[str] = "a 3d render of bhks ninja",
    validation_prompt: Optional[str] = None,
    num_validation_images: int = 4,
    validation_epochs: int = 40,
    output_dir: str = "ft",
    seed: Optional[int] = 42,
    resolution: int = 512,
    crops_coords_top_left_h: int = 0,
    crops_coords_top_left_w: int = 0,
    center_crop: bool = False,
    train_text_encoder: bool = False,
    train_batch_size: int = 2,
    sample_batch_size: int = 4,
    num_train_epochs: int = 200,
    max_train_steps: Optional[int] = None,
    checkpointing_steps: int = 500,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = False,
    learning_rate: float = 1e-5,
    scale_lr: bool = False,
    lr_scheduler: str = "constant",
    lr_warmup_steps: int = 500,
    lr_num_cycles: int = 1,
    lr_power: float = 1.0,
    dataloader_num_workers: int = 0,
    max_grad_norm: float = 1.0,
    allow_tf32: bool = False,
    mixed_precision: Optional[str] = "bf16",
    device: str = "cuda:0",
    tracker: str = "wandb",
) -> None:
    validation_prompt = validation_prompt or instance_prompt

    if tracker == "wandb":
        import wandb

        wandb.init(project="diffusion", name=output_dir.split("/")[-1])

    if train_text_encoder:
        raise NotImplementedError("Text encoder training not yet supported.")

    tokenizer_one = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=revision,
        use_fast=False,
    )
    tokenizer_two = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer_2",
        revision=revision,
        use_fast=False,
    )

    # import correct text encoder classes
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, revision
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, revision, subfolder="text_encoder_2"
    )

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder="scheduler"
    )
    text_encoder_one = text_encoder_cls_one.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    text_encoder_two = text_encoder_cls_two.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", revision=revision
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", revision=revision
    )

    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)

    weight_dtype = torch.float32
    if mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    unet.to(device, dtype=weight_dtype)
    vae.to(device, dtype=torch.float32)
    text_encoder_one.to(device, dtype=weight_dtype)
    text_encoder_two.to(device, dtype=weight_dtype)

    if scale_lr:
        learning_rate = learning_rate * gradient_accumulation_steps * train_batch_size

    optimizer_class = torch.optim.AdamW

    tokenizers = [tokenizer_one, tokenizer_two]
    text_encoders = [text_encoder_one, text_encoder_two]

    @torch.no_grad()
    def compute_embeddings(prompt, text_encoders, tokenizers):
        original_size = (resolution, resolution)
        target_size = (resolution, resolution)
        crops_coords_top_left = (crops_coords_top_left_h, crops_coords_top_left_w)

        prompt_embeds, pooled_prompt_embeds = encode_prompt(
            text_encoders, tokenizers, prompt
        )
        add_text_embeds = pooled_prompt_embeds

        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids])

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device, dtype=prompt_embeds.dtype)
        unet_added_cond_kwargs = {
            # Note: the pooled text embeddings are zeroed out here, so only the
            # per-token prompt embeddings and the time_ids carry the conditioning.
            "text_embeds": add_text_embeds * 0.0,
            "time_ids": add_time_ids,
        }
        return prompt_embeds, unet_added_cond_kwargs
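
    # SDXL micro-conditioning: time_ids packs original_size + crop_coords + target_size
    # into six integers; with the defaults above it is (512, 512, 0, 0, 512, 512).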
    instance_prompt_hidden_states, instance_unet_added_conditions = compute_embeddings(
        instance_prompt, text_encoders, tokenizers
    )

    del tokenizers, text_encoders

    # prompt_aux_param = nn.Parameter(torch.zeros_like(instance_prompt_hidden_states))
    # prompt_aux_param.requires_grad_(True)
    # prompt_aux_param.to(device)
    prompt_aux_param = 0.0

    # Fine-tune only the weight tensors inside the UNet's transformer blocks;
    # everything else stays frozen.
    param_to_optimize = []
    for name, param in unet.named_parameters():
        if "transformer" in name and "weight" in name:
            param.requires_grad_(True)
            param_to_optimize.append(param)
            print(name)
        else:
            param.requires_grad_(False)

    # Optimizer creation
    params_to_optimize = [
        # {
        #     "params": prompt_aux_param,
        #     "lr": 4e-4,
        #     "weight_decay": 0.1,
        # },
        {
            "params": param_to_optimize,
            "lr": learning_rate,
        },
    ]

    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-2
    adam_epsilon: float = 1e-8

    optimizer = optimizer_class(
        params_to_optimize,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    gc.collect()
    torch.cuda.empty_cache()

    train_dataset = Text2ImageDataset(
        instance_data_root=instance_data_dir,
        instance_prompt=instance_prompt,
        size=resolution,
        center_crop=center_crop,
        instance_prompt_hidden_states=instance_prompt_hidden_states,
        instance_unet_added_conditions=instance_unet_added_conditions,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=lambda examples: collate_fn(examples),
        num_workers=dataloader_num_workers,
    )

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if max_train_steps is None:
        max_train_steps = num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
        num_cycles=lr_num_cycles,
        power=lr_power,
    )

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
    total_batch_size = train_batch_size * gradient_accumulation_steps

    print("***** Running training *****")
    print(f" Num examples = {len(train_dataset)}")
    print(f" Num batches each epoch = {len(train_dataloader)}")
    print(f" Num Epochs = {num_train_epochs}")
    print(f" Instantaneous batch size per device = {train_batch_size}")
    print(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    print(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
    print(f" Total optimization steps = {max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps))
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            progress_bar.update(1)
            global_step += 1
            print(f"step: {global_step}, epoch: {epoch}")

            # Convert images to latent space
            model_input = vae.encode(
                batch["pixel_values"].to(device)
            ).latent_dist.sample()
            model_input = model_input * vae.config.scaling_factor
            model_input = model_input.to(weight_dtype)

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(model_input)
            bsz = model_input.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bsz,),
                device=model_input.device,
            )
            timesteps = timesteps.long()

            noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
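            # The scheduler's forward process forms
            #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
            # so the UNet below is trained to predict the added noise (epsilon).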

            # Predict the noise residual. Note that batch["input_ids"] actually holds
            # the precomputed prompt embeddings (encoder hidden states), not token ids.
            model_pred = unet(
                noisy_model_input,
                timesteps,
                batch["input_ids"] + prompt_aux_param,
                added_cond_kwargs=batch["unet_added_conditions"],
            ).sample

            target = noise
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            if global_step % checkpointing_steps == 0:
                unet.save_pretrained(f"{output_dir}/unet/checkpoint-{global_step}")

        if epoch % validation_epochs == 0:
            with torch.no_grad():
                pipeline = DiffusionPipeline.from_pretrained(
                    pretrained_model_name_or_path,
                    unet=unet,
                    torch_dtype=weight_dtype,
                    safety_checker=None,
                    revision=revision,
                )
                pipeline = pipeline.to(device)
                pipeline.set_progress_bar_config(disable=True)
                print(prompt_aux_param)

                # run inference
                generator = (
                    torch.Generator(device=device).manual_seed(seed) if seed else None
                )
                pipeline_args = {
                    "prompt_embeds": instance_prompt_hidden_states + prompt_aux_param,
                    "pooled_prompt_embeds": instance_unet_added_conditions[
                        "text_embeds"
                    ],
                    "height": 1024,
                    "width": 1024,
                }
                images = [
                    pipeline(**pipeline_args, generator=generator).images[0]
                    for _ in range(num_validation_images)
                ]

                if tracker == "wandb":
                    import wandb

                    wandb.log(
                        {
                            "validation": [
                                wandb.Image(image, caption=f"{i}: {validation_prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )

                del pipeline
                torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
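
# Hypothetical usage sketch (module name and argument values are illustrative, not
# part of the gist): import main from this file and override the defaults, e.g.
#
#   from train_sdxl_full import main
#   main(
#       instance_data_dir="dataset",
#       instance_prompt="a 3d render of bhks ninja",
#       resolution=1024,
#       max_train_steps=2000,
#   )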