# Bootstrapped from Hugging Face diffusers' code.
import gc
import math
import os
import warnings
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")
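
# For reference (my annotation, not part of the original gist): with the SDXL
# base checkpoint this helper resolves the two text encoders as follows, e.g.:
#
#   cls_one = import_model_class_from_model_name_or_path(
#       "stabilityai/stable-diffusion-xl-base-0.9", revision=None
#   )  # -> transformers.CLIPTextModel
#   cls_two = import_model_class_from_model_name_or_path(
#       "stabilityai/stable-diffusion-xl-base-0.9", revision=None,
#       subfolder="text_encoder_2",
#   )  # -> transformers.CLIPTextModelWithProjection
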
class Text2ImageDataset(Dataset):
    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        size=1024,
        center_crop=False,
        instance_prompt_hidden_states=None,
        instance_unet_added_conditions=None,
    ):
        self.size = size
        self.center_crop = center_crop
        self.instance_prompt_hidden_states = instance_prompt_hidden_states
        self.instance_unet_added_conditions = instance_unet_added_conditions

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                ),
                transforms.CenterCrop(size)
                if center_crop
                else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        instance_image = exif_transpose(instance_image)
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        # Prompt embeddings are precomputed once and shared across all examples.
        example["instance_prompt_ids"] = self.instance_prompt_hidden_states
        example["instance_added_cond_kwargs"] = self.instance_unet_added_conditions
        return example
def collate_fn(examples):
    has_attention_mask = "instance_attention_mask" in examples[0]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]
    add_text_embeds = [
        example["instance_added_cond_kwargs"]["text_embeds"] for example in examples
    ]
    add_time_ids = [
        example["instance_added_cond_kwargs"]["time_ids"] for example in examples
    ]
    if has_attention_mask:
        attention_mask = [example["instance_attention_mask"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = torch.cat(input_ids, dim=0)
    add_text_embeds = torch.cat(add_text_embeds, dim=0)
    add_time_ids = torch.cat(add_time_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
        "unet_added_conditions": {
            "text_embeds": add_text_embeds,
            "time_ids": add_time_ids,
        },
    }

    if has_attention_mask:
        # Batch the masks as well; the original left this as a Python list.
        batch["attention_mask"] = torch.cat(attention_mask, dim=0)
    return batch
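
# For reference (my annotation, not part of the original gist): with the
# precomputed SDXL embeddings, a collated batch looks roughly like
#   batch["input_ids"]                             -> (B, 77, 2048) prompt hidden states
#   batch["pixel_values"]                          -> (B, 3, size, size)
#   batch["unet_added_conditions"]["text_embeds"]  -> (B, 1280)
#   batch["unet_added_conditions"]["time_ids"]     -> (B, 6)
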
def encode_prompt(text_encoders, tokenizers, prompt):
    prompt_embeds_list = []

    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        prompt_embeds = text_encoder(
            text_input_ids.to(text_encoder.device),
            output_hidden_states=True,
        )

        # Only the pooled output of the final text encoder is kept; this value
        # is overwritten on each iteration, so the last encoder wins.
        pooled_prompt_embeds = prompt_embeds[0]
        # Use the penultimate hidden state, as the SDXL pipeline does.
        prompt_embeds = prompt_embeds.hidden_states[-2]
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
        prompt_embeds_list.append(prompt_embeds)

    # Concatenate the per-encoder embeddings along the feature dimension.
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
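
# For reference (my annotation, not part of the original gist): SDXL's two
# encoders produce per-token hidden states of width 768 and 1280 respectively,
# so the concatenation above yields roughly (batch, 77, 2048), while
# pooled_prompt_embeds comes from the projection encoder as (batch, 1280).
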
def main(
    pretrained_model_name_or_path: Optional[str] = "stabilityai/stable-diffusion-xl-base-0.9",
    revision: Optional[str] = None,
    instance_data_dir: Optional[str] = "dataset",
    instance_prompt: Optional[str] = "a 3d render of bhks ninja",
    validation_prompt: Optional[str] = None,
    num_validation_images: int = 4,
    validation_epochs: int = 40,
    output_dir: str = "ft",
    seed: Optional[int] = 42,
    resolution: int = 512,
    crops_coords_top_left_h: int = 0,
    crops_coords_top_left_w: int = 0,
    center_crop: bool = False,
    train_text_encoder: bool = False,
    train_batch_size: int = 2,
    sample_batch_size: int = 4,
    num_train_epochs: int = 200,
    max_train_steps: Optional[int] = None,
    checkpointing_steps: int = 500,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = False,
    learning_rate: float = 1e-5,
    scale_lr: bool = False,
    lr_scheduler: str = "constant",
    lr_warmup_steps: int = 500,
    lr_num_cycles: int = 1,
    lr_power: float = 1.0,
    dataloader_num_workers: int = 0,
    max_grad_norm: float = 1.0,
    allow_tf32: bool = False,
    mixed_precision: Optional[str] = "bf16",
    device: str = "cuda:0",
    tracker: str = "wandb",
) -> None:
    validation_prompt = validation_prompt or instance_prompt

    if tracker == "wandb":
        import wandb

        wandb.init(project="diffusion", name=output_dir.split("/")[-1])

    if train_text_encoder:
        raise NotImplementedError("Text encoder training not yet supported.")

    tokenizer_one = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=revision,
        use_fast=False,
    )
    tokenizer_two = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer_2",
        revision=revision,
        use_fast=False,
    )

    # Import the correct text encoder classes.
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, revision
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, revision, subfolder="text_encoder_2"
    )

    # Load scheduler and models.
    noise_scheduler = DDPMScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder="scheduler"
    )
    text_encoder_one = text_encoder_cls_one.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    text_encoder_two = text_encoder_cls_two.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", revision=revision
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", revision=revision
    )

    # Only the UNet is trained; freeze everything else.
    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)

    weight_dtype = torch.float32
    if mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    unet.to(device, dtype=weight_dtype)
    # Keep the VAE in fp32 for numerically stable encoding.
    vae.to(device, dtype=torch.float32)
    text_encoder_one.to(device, dtype=weight_dtype)
    text_encoder_two.to(device, dtype=weight_dtype)

    if scale_lr:
        learning_rate = learning_rate * gradient_accumulation_steps * train_batch_size

    optimizer_class = torch.optim.AdamW

    tokenizers = [tokenizer_one, tokenizer_two]
    text_encoders = [text_encoder_one, text_encoder_two]
    @torch.no_grad()
    def compute_embeddings(prompt, text_encoders, tokenizers):
        original_size = (resolution, resolution)
        target_size = (resolution, resolution)
        crops_coords_top_left = (crops_coords_top_left_h, crops_coords_top_left_w)

        prompt_embeds, pooled_prompt_embeds = encode_prompt(
            text_encoders, tokenizers, prompt
        )
        add_text_embeds = pooled_prompt_embeds

        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids])

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device, dtype=prompt_embeds.dtype)
        unet_added_cond_kwargs = {
            # Note: the pooled text embedding is zeroed out before it reaches the UNet.
            "text_embeds": add_text_embeds * 0.0,
            "time_ids": add_time_ids,
        }
        return prompt_embeds, unet_added_cond_kwargs
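
    # For reference (my annotation, not part of the original gist): SDXL's
    # micro-conditioning packs six scalars into add_time_ids; with
    # resolution=512 and zero crop offsets, the helper above produces
    #   (512, 512) + (0, 0) + (512, 512) -> [512, 512, 0, 0, 512, 512]
    # i.e. original size, crop top-left, and target size, in that order.
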
    instance_prompt_hidden_states, instance_unet_added_conditions = compute_embeddings(
        instance_prompt, text_encoders, tokenizers
    )
    del tokenizers, text_encoders

    # A learnable additive offset on the prompt embedding, currently disabled:
    # prompt_aux_param = nn.Parameter(torch.zeros_like(instance_prompt_hidden_states))
    # prompt_aux_param.requires_grad_(True)
    # prompt_aux_param.to(device)
    prompt_aux_param = 0.0

    # Fine-tune only the attention (transformer block) weights.
    param_to_optimize = []
    for name, param in unet.named_parameters():
        if "transformer" in name and "weight" in name:
            param.requires_grad_(True)
            param_to_optimize.append(param)
            print(name)
        else:
            param.requires_grad_(False)

    # Optimizer creation
    params_to_optimize = [
        # {
        #     "params": prompt_aux_param,
        #     "lr": 4e-4,
        #     "weight_decay": 0.1,
        # },
        {
            "params": param_to_optimize,
            "lr": learning_rate,
        },
    ]

    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-2
    adam_epsilon: float = 1e-8
    optimizer = optimizer_class(
        params_to_optimize,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    gc.collect()
    torch.cuda.empty_cache()

    train_dataset = Text2ImageDataset(
        instance_data_root=instance_data_dir,
        instance_prompt=instance_prompt,
        size=resolution,
        center_crop=center_crop,
        instance_prompt_hidden_states=instance_prompt_hidden_states,
        instance_unet_added_conditions=instance_unet_added_conditions,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=dataloader_num_workers,
    )

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if max_train_steps is None:
        max_train_steps = num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
        num_cycles=lr_num_cycles,
        power=lr_power,
    )

    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
    total_batch_size = train_batch_size * gradient_accumulation_steps

    print("***** Running training *****")
    print(f" Num examples = {len(train_dataset)}")
    print(f" Num batches each epoch = {len(train_dataloader)}")
    print(f" Num Epochs = {num_train_epochs}")
    print(f" Instantaneous batch size per device = {train_batch_size}")
    print(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    print(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
    print(f" Total optimization steps = {max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Progress bar over the total number of optimization steps.
    progress_bar = tqdm(range(global_step, max_train_steps))
    progress_bar.set_description("Steps")
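
    # For reference (my annotation, not part of the original gist): in the loop
    # below, DDPMScheduler.add_noise forms
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    # and, with the scheduler's default epsilon prediction, the training loss is
    #   E[ || eps - eps_theta(x_t, t, c) ||^2 ],
    # which is exactly the MSE against `noise` computed in the loop.
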
    for epoch in range(first_epoch, num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            progress_bar.update(1)
            global_step += 1
            print(f"step: {global_step}, epoch: {epoch}")

            # Convert images to latent space (the VAE runs in fp32).
            model_input = vae.encode(
                batch["pixel_values"].to(device)
            ).latent_dist.sample()
            model_input = model_input * vae.config.scaling_factor
            model_input = model_input.to(weight_dtype)

            # Sample noise that we'll add to the latents.
            noise = torch.randn_like(model_input)
            bsz = model_input.shape[0]
            # Sample a random timestep for each image.
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bsz,),
                device=model_input.device,
            )
            timesteps = timesteps.long()
            noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

            # Predict the noise residual. Despite its name, "input_ids" holds
            # the precomputed prompt hidden states, not token ids.
            model_pred = unet(
                noisy_model_input,
                timesteps,
                batch["input_ids"] + prompt_aux_param,
                added_cond_kwargs=batch["unet_added_conditions"],
            ).sample

            target = noise
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            loss.backward()
            # Clip gradients of the tuned attention weights.
            torch.nn.utils.clip_grad_norm_(param_to_optimize, max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            if global_step % checkpointing_steps == 0:
                unet.save_pretrained(f"{output_dir}/unet/checkpoint-{global_step}")

        if epoch % validation_epochs == 0:
            with torch.no_grad():
                pipeline = DiffusionPipeline.from_pretrained(
                    pretrained_model_name_or_path,
                    unet=unet,
                    torch_dtype=weight_dtype,
                    safety_checker=None,
                    revision=revision,
                )
                pipeline = pipeline.to(device)
                pipeline.set_progress_bar_config(disable=True)
                print(prompt_aux_param)

                # Run inference with the precomputed prompt embeddings.
                generator = (
                    torch.Generator(device=device).manual_seed(seed) if seed else None
                )
                pipeline_args = {
                    "prompt_embeds": instance_prompt_hidden_states + prompt_aux_param,
                    "pooled_prompt_embeds": instance_unet_added_conditions[
                        "text_embeds"
                    ],
                    "height": 1024,
                    "width": 1024,
                }
                images = [
                    pipeline(**pipeline_args, generator=generator).images[0]
                    for _ in range(num_validation_images)
                ]

                if tracker == "wandb":
                    import wandb

                    wandb.log(
                        {
                            "validation": [
                                wandb.Image(image, caption=f"{i}: {validation_prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )
                del pipeline
                torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
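
# Usage note (my addition, not part of the original gist): the script runs with
# the defaults baked into main()'s signature. A minimal sketch for overriding
# them from the command line, assuming the third-party `fire` package and a
# hypothetical filename train_sdxl.py:
#
#   import fire
#   fire.Fire(main)  # in place of the bare main() call above
#
#   $ python train_sdxl.py --instance_data_dir=./dataset --resolution=1024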