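"""Hand-rolled Stable Diffusion XL inference loop.

Loads the SDXL components (two text encoders, VAE, UNet, scheduler) one by one
with diffusers and runs the classifier-free-guidance denoising loop manually
instead of going through StableDiffusionXLPipeline.
"""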
import torch
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler
repo_id = "/Volumes/Acasis1TB/machine_learning/stable-diffusion-xl-base-1.0"
torch_device = "mps"
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", use_safetensors=True)

# SDXL conditions on *two* CLIP text encoders; the second one also provides the
# pooled "text_embeds" that the UNet expects through added_cond_kwargs.
tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer_2")
text_encoder = CLIPTextModel.from_pretrained(
    repo_id, subfolder="text_encoder", use_safetensors=True
)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
    repo_id, subfolder="text_encoder_2", use_safetensors=True
)
model = UNet2DConditionModel.from_pretrained(
    repo_id, subfolder="unet", use_safetensors=True
)

vae.to(torch_device)
text_encoder.to(torch_device)
text_encoder_2.to(torch_device)
model.to(torch_device)
prompt = ["a photograph of an astronaut riding a horse"]
height = 1024 # default height of Stable Diffusion
width = 1024 # default width of Stable Diffusion
num_inference_steps = 15 # Number of denoising steps
guidance_scale = 7.5 # Scale for classifier-free guidance
generator = torch.mps.manual_seed(
0
) # Seed generator to create the initial latent noise
batch_size = len(prompt)
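# Classifier-free guidance needs two embeddings per image: one for the prompt
# and one for an empty "unconditional" prompt; both are built below and stacked.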
def encode_prompt(prompts):
    """Encode a list of prompts with both SDXL text encoders.

    Returns the concatenated per-token hidden states (the UNet cross-attends
    to 768 + 1280 = 2048 feature dims) and the pooled projection from the
    second encoder. SDXL uses the penultimate hidden layer of each encoder.
    """
    per_token = []
    with torch.no_grad():
        for tok, enc in ((tokenizer, text_encoder), (tokenizer_2, text_encoder_2)):
            text_input = tok(
                prompts,
                padding="max_length",
                max_length=tok.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            out = enc(text_input.input_ids.to(torch_device), output_hidden_states=True)
            per_token.append(out.hidden_states[-2])
        pooled = out[0]  # text_embeds of CLIPTextModelWithProjection
    return torch.cat(per_token, dim=-1), pooled


cond_embeddings, cond_pooled = encode_prompt(prompt)
uncond_embeddings, uncond_pooled = encode_prompt([""] * batch_size)

# stack unconditional + conditional along the batch dim for guidance
text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
pooled_embeds = torch.cat([uncond_pooled, cond_pooled])
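# Shapes at this point: text_embeddings is (2 * batch, 77, 2048) and
# pooled_embeds is (2 * batch, 1280), matching the SDXL base UNet.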
scheduler.set_timesteps(num_inference_steps)

# Start from Gaussian noise in the VAE latent space (1/8 of the pixel
# resolution), scaled by the scheduler's initial sigma. The noise is drawn
# on CPU with the seeded generator and then moved to the device.
latents = torch.randn(
    (batch_size, model.config.in_channels, height // 8, width // 8),
    generator=generator,
).to(torch_device)
latents = latents * scheduler.init_noise_sigma
# SDXL's added conditioning: the pooled text embedding plus "time ids" packing
# (original_size, crops_coords_top_left, target_size), one row per sample in
# the doubled classifier-free-guidance batch.
add_time_ids = torch.tensor(
    [[height, width, 0, 0, height, width]],
    device=torch_device,
    dtype=text_embeddings.dtype,
).repeat(batch_size * 2, 1)
added_cond_kwargs = {"text_embeds": pooled_embeds, "time_ids": add_time_ids}

for t in tqdm(scheduler.timesteps):
    # expand the latents so a single forward pass covers both the unconditional
    # and the text-conditioned branch of classifier-free guidance
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
    # predict the noise residual
    with torch.no_grad():
        noise_pred = model(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

    # perform classifier-free guidance
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (
        noise_pred_text - noise_pred_uncond
    )

    # compute the previous noisy sample x_t -> x_{t-1}
    latents = scheduler.step(noise_pred, t, latents).prev_sample
# scale and decode the image latents with the VAE; SDXL's scaling factor is
# 0.13025 (0.18215 is the SD 1.x value), so read it from the VAE config
latents = latents / vae.config.scaling_factor
with torch.no_grad():
    image = vae.decode(latents).sample

# map from [-1, 1] to [0, 255] uint8 and save (squeeze assumes batch_size == 1)
image = (image / 2 + 0.5).clamp(0, 1).squeeze()
image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
image = Image.fromarray(image)
image.save("generated_image.png")
# Reference, UNet2DConditionModel.forward:
#   sample (torch.FloatTensor) - noisy input tensor, shape (batch, channel, height, width)
#   timestep (torch.Tensor, float, or int) - the current denoising timestep
#   encoder_hidden_states (torch.FloatTensor) - shape (batch, sequence_length, feature_dim)