Simple ~100-line implementation of 2D DreamFusion: Score Distillation Sampling (SDS) against Stable Diffusion 2.1, applied directly to a single 2D image instead of a NeRF.
import os
join = os.path.join
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from datetime import datetime
from utils.utils import save_image
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import logging as transformers_logging
transformers_logging.set_verbosity_error()  # disable warnings
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import DDIMScheduler
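
# `utils.utils.save_image` is a project-local helper, not shipped with this
# gist. A minimal stand-in (a sketch; it assumes a CHW float tensor in
# [-1, 1], the range the SD VAE decoder produces) could be:
#
#     from PIL import Image
#     def save_image(image, path):
#         os.makedirs(os.path.dirname(path), exist_ok=True)
#         chw = ((image.clamp(-1, 1) + 1) * 127.5).round().byte()
#         Image.fromarray(chw.permute(1, 2, 0).cpu().numpy()).save(path)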

class SpecifyGradient(torch.autograd.Function):
    """Autograd trick: forward returns a dummy scalar loss; backward injects a
    precomputed gradient (gt_grad) into input_tensor."""

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # Return a dummy value of 1. If amp's GradScaler is active, it scales
        # this loss, so the scale factor arrives as grad_scale in backward.
        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        gt_grad, = ctx.saved_tensors
        gt_grad = gt_grad * grad_scale
        return gt_grad, None
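
# A scaler-free alternative (a sketch, not used below): for a precomputed
# gradient g, d/dx [(g.detach() * x).sum()] == g, so an ordinary scalar loss
# deposits the same gradient into x without the dummy-loss trick above.
def sds_loss_alternative(x, g):
    return (g.detach() * x).sum()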

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32
    model_path = "stabilityai/stable-diffusion-2-1-base"
    prompt = "a photograph of an astronaut riding a horse"
    lr = 0.03
    start_time: datetime = datetime.now()

    # Load the Stable Diffusion components.
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", torch_dtype=dtype)
    scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)
    # The diffusion model is a frozen critic; only the image pixels are optimized.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    scheduler.betas = scheduler.betas.to(device)
    scheduler.alphas = scheduler.alphas.to(device)
    scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device)

    # Encode the prompt and an empty prompt for classifier-free guidance.
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings_vsd = torch.cat([uncond_embeddings[0:1], text_embeddings[0:1]])
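
    # Classifier-free guidance combines the two branches as
    #     eps_hat = eps_uncond + s * (eps_text - eps_uncond),   with s = 7.5 below.
    # The [uncond, cond] concatenation order above must match the chunk(2)
    # unpacking inside the optimization loop.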

    num_train_timesteps = len(scheduler.betas)
    scheduler.set_timesteps(num_train_timesteps)

    # The "particle": a single 512x512 RGB image optimized directly in pixel space.
    particles = torch.nn.Parameter(torch.randn(1, 3, 512, 512, device=device, dtype=dtype))
    optimizer = torch.optim.Adam([particles], lr=lr)
    pbar = tqdm(np.random.choice(list(range(num_train_timesteps)), 1000, replace=True))
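
    # Score Distillation Sampling (DreamFusion): at each step, noise the current
    # image's latents to a random timestep t, have the frozen UNet predict that
    # noise, and push
    #     grad = w(t) * (eps_hat - eps)          (w(t) = 1 in this gist)
    # back through the VAE encoder into the pixels. No UNet weights are updated.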
    ### SDS optimization loop (not regular text-to-image sampling)
    for step, chosen_t in enumerate(pbar):
        rgb_BCHW_512 = F.interpolate(particles, (512, 512), mode="bilinear", align_corners=False)
        # Encode the image into latents with the VAE (0.18215 is SD's latent scaling factor).
        latents = 0.18215 * vae.encode(rgb_BCHW_512).latent_dist.sample()
        t = torch.tensor([chosen_t]).to(device)

        ######## q sample #########
        # Take the first (and only) particle.
        latents_vsd = latents[0:1]
        noise = torch.randn_like(latents_vsd)
        noisy_latents = scheduler.add_noise(latents_vsd, noise, t)
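        # add_noise implements the diffusion forward process:
        #     z_t = sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * eps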

        ######## Compute the SDS gradient for the latents #########
        optimizer.zero_grad()
        with torch.no_grad():
            # One UNet pass for both the unconditional and the conditional branch.
            latent_model_input = torch.cat([noisy_latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings_vsd, cross_attention_kwargs={}).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
        grad = noise_pred - noise
        # Inject grad at noisy_latents; backward flows through add_noise and the
        # VAE encoder into the pixel particles.
        loss = SpecifyGradient.apply(noisy_latents, grad)
        noise_pred = noise_pred.detach().clone()  # kept for the snapshot below
        loss.backward()
        optimizer.step()
        # The dummy loss from SpecifyGradient is always 1.0, so report the
        # gradient magnitude instead.
        pbar.set_description(f'grad norm: {grad.norm().item():.6f}, sampled t: {t.item()}')

        # Every 50 steps, decode and save the current particle.
        if step % 50 == 0:
            tmp_latents = 1 / 0.18215 * latents_vsd.clone().detach()
            # The model's one-step x0 prediction; computed for inspection, not saved.
            pred_latents = scheduler.step(noise_pred, t, noisy_latents).pred_original_sample.to(dtype).clone().detach()
            with torch.no_grad():
                image_ = vae.decode(tmp_latents).sample.to(torch.float32)
            save_image(image_[0], f"out/{start_time.strftime('%Y%m%d_%H%M%S')}/z_image_{step}.png")

if __name__ == "__main__":
    main()
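
# Running this script optimizes the random image toward the prompt over 1000
# randomly-drawn timesteps and writes a decoded snapshot every 50 steps to
# out/<start-time>/z_image_<step>.png. It falls back to CPU but in practice
# expects a CUDA GPU that can hold SD 2.1-base in float32.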