@KokeCacao
Created July 9, 2023 23:31
Simple 100-line implementation of 2D DreamFusion
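This is Score Distillation Sampling (SDS) from DreamFusion, applied in 2D: instead of rendering a NeRF, the script optimizes the pixels of a single 512x512 image so that its Stable Diffusion 2.1 latents match the text prompt. Each step noises the encoded image at a random timestep, asks the frozen UNet for a classifier-free-guided noise prediction, and injects the DreamFusion gradient

    \nabla_x L_{SDS} = E_{t,\epsilon}[ w(t) (\epsilon_\phi(z_t; y, t) - \epsilon) \partial z_t / \partial x ]

directly at the noisy latents, so it flows back through the noising step and the VAE encoder into the pixels without ever differentiating through the UNet. (No explicit w(t) is applied here; a sqrt(alpha_bar_t) factor arises implicitly from injecting the gradient at z_t rather than at z_0.)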
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from datetime import datetime
from utils.utils import save_image  # author's local helper for writing image tensors to disk
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import logging as transformers_logging
transformers_logging.set_verbosity_error()  # silence transformers' weight-loading warnings
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import DDIMScheduler
class SpecifyGradient(torch.autograd.Function):
    """Autograd trick: forward returns a dummy scalar loss; backward injects a precomputed gradient."""

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # Return a dummy value 1, which will be scaled by AMP's scaler so we recover the scale in backward.
        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        gt_grad, = ctx.saved_tensors
        gt_grad = gt_grad * grad_scale
        return gt_grad, None
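# A minimal sanity check of the trick above (hypothetical usage, not part of the
# original script): the forward pass returns a dummy scalar, and backward()
# routes gt_grad, scaled by whatever gradient arrives (1.0 without AMP, or the
# scaler's factor with it), straight into the input's .grad:
#
#   x = torch.randn(4, requires_grad=True)
#   g = 2.0 * torch.ones_like(x)
#   SpecifyGradient.apply(x, g).backward()
#   assert torch.allclose(x.grad, g)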
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32
    model_path = "stabilityai/stable-diffusion-2-1-base"
    prompt = "a photograph of an astronaut riding a horse"
    lr = 0.03
    start_time: datetime = datetime.now()

    # Load the Stable Diffusion components (the tokenizer and scheduler hold no weights, so no torch_dtype).
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", torch_dtype=dtype)
    scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)
    # Freeze everything; only the image pixels ("particles") are optimized.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    # Move the scheduler's coefficient tensors to the device so add_noise can
    # combine them with CUDA tensors.
    scheduler.betas = scheduler.betas.to(device)
    scheduler.alphas = scheduler.alphas.to(device)
    scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device)
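    # Encode the prompt and an empty negative prompt for classifier-free guidance.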
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings_vsd = torch.cat([uncond_embeddings[0:1], text_embeddings[0:1]])
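    # NOTE: the [uncond, cond] ordering above must match the chunk(2) unpacking
    # in the guidance step below.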
    num_train_timesteps = len(scheduler.betas)
    scheduler.set_timesteps(num_train_timesteps)

    # The "particle": a single 512x512 RGB image, optimized directly in pixel space.
    particles = torch.nn.Parameter(torch.randn(1, 3, 512, 512, device=device, dtype=dtype))
    optimizer = torch.optim.Adam([particles], lr=lr)

    # 1000 optimization steps, each with a uniformly sampled diffusion timestep.
    pbar = tqdm(np.random.choice(list(range(num_train_timesteps)), 1000, replace=True))
    ### 2D SDS: optimize the image against the diffusion prior
    for step, chosen_t in enumerate(pbar):
        rgb_BCHW_512 = F.interpolate(particles, (512, 512), mode="bilinear", align_corners=False)  # no-op at this resolution, kept for generality
        # encode the image into latents with the vae
        latents = 0.18215 * vae.encode(rgb_BCHW_512).latent_dist.sample()
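        # 0.18215 is Stable Diffusion's latent scaling factor (vae.config.scaling_factor);
        # the UNet expects latents scaled by it.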
        # latents = get_latents(particles, args.rgb_as_latents, use_mlp_particle=args.use_mlp_particle)
        t = torch.tensor([chosen_t]).to(device)

        ######## q sample #########
        # random sample particle_num_vsd particles from latents
        # indices = torch.randperm(latents.size(0))
        latents_vsd = latents[0:1]
        noise = torch.randn_like(latents_vsd)
        noisy_latents = scheduler.add_noise(latents_vsd, noise, t)
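        # add_noise implements the DDPM forward process:
        #   noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise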
        ######## do the gradient step for the latents ########
        optimizer.zero_grad()
        with torch.no_grad():
            # Classifier-free guidance: one UNet pass over the [uncond, cond] pair.
            latent_model_input = torch.cat([noisy_latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings_vsd, cross_attention_kwargs={}).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
        # noise_pred = predict_noise0_diffuser(unet, noisy_latents, text_embeddings_vsd, t, guidance_scale=7.5, cross_attention_kwargs={}, scheduler=scheduler)
        # SDS gradient: the residual between the predicted and the injected noise.
        grad = noise_pred - noise
        # noise_pred_phi = noise.detach().clone()  # unused leftover from the VSD variant of this script
        loss = SpecifyGradient.apply(noisy_latents, grad)
        noise_pred = noise_pred.detach().clone()
        loss.backward()
        optimizer.step()
        pbar.set_description(f'Loss: {loss.item():.6f}, sampled t: {t.item()}')  # the "loss" is the dummy scalar, always 1
        optimizer.zero_grad()

        if step % 50 == 0:
            tmp_latents = 1 / 0.18215 * latents_vsd.clone().detach()
            # pred_latents = scheduler.step(noise_pred, t, noisy_latents).pred_original_sample.to(dtype).clone().detach()  # unused
            with torch.no_grad():
                image_ = vae.decode(tmp_latents).sample.to(torch.float32)
            save_image(image_[0], f"out/{start_time.strftime('%Y%m%d_%H%M%S')}/z_image_{step}.png")
if __name__ == "__main__":
    main()