# Gist by @rovo79 (6308f980747f2155927aa4f010d8d396), created March 12, 2024.
import torch
from PIL import Image
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler
# from diffusers import UNet2DModel
# Local checkpoint directory; every pipeline component is read from its own
# subfolder underneath it.
# NOTE(review): the directory name suggests an SDXL checkpoint, yet only one
# tokenizer / text-encoder pair is loaded here -- confirm that the UNet in this
# checkpoint can be conditioned on a single CLIP text encoder.
repo_id = "/Volumes/Acasis1TB/machine_learning/stable-diffusion-xl-base-1.0"
torch_device = "mps"

# Components without safetensors weights (noise schedule and tokenizer).
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")

# Networks whose weights are loaded from safetensors files.
vae = AutoencoderKL.from_pretrained(
    repo_id, subfolder="vae", use_safetensors=True
)
text_encoder = CLIPTextModel.from_pretrained(
    repo_id, subfolder="text_encoder", use_safetensors=True
)
model = UNet2DConditionModel.from_pretrained(
    repo_id, subfolder="unet", use_safetensors=True
)

# Move the three networks onto the compute device, in the same order as before
# (VAE, text encoder, UNet).
for _network in (vae, text_encoder, model):
    _network.to(torch_device)

# Dump the UNet configuration for inspection.
print(model.config)
# Generation settings.
prompt = ["a photograph of an astronaut riding a horse"]
height = 1024  # height of the generated image
width = 1024  # width of the generated image
num_inference_steps = 15  # number of denoising steps
guidance_scale = 7.5  # scale for classifier-free guidance
batch_size = len(prompt)

# Seeded generator for the initial latent noise.
# torch.mps.manual_seed() only seeds the global MPS generator and returns None,
# so assigning its result left `generator` as None. Build an explicit Generator
# on the sampling device instead, so that passing `generator=generator` to
# torch.randn actually controls the noise.
generator = torch.Generator(device=torch_device).manual_seed(0)
# Tokenize the prompt, padded and truncated to the tokenizer's maximum length.
text_input = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
max_length = text_input.input_ids.shape[-1]

# One empty prompt per batch entry, padded to the same length, provides the
# unconditional branch for classifier-free guidance.
uncond_input = tokenizer(
    [""] * batch_size,
    padding="max_length",
    max_length=max_length,
    return_tensors="pt",
)

# Both encoder passes are inference-only; run them together under no_grad so
# neither one records an autograd graph.
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

# Stack along the batch axis: unconditional embeddings first, prompt
# embeddings second.
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
from tqdm.auto import tqdm

# Build the inference schedule BEFORE reading init_noise_sigma: the scheduler
# derives that value from its sigma table, and set_timesteps() replaces the
# table with the one for the requested number of steps. Reading it first would
# scale the noise for a schedule that is never actually run.
scheduler.set_timesteps(num_inference_steps)

# Initial Gaussian noise; the latent grid is 1/8 of the image size per side.
latents = torch.randn(
    (batch_size, model.config.in_channels, height // 8, width // 8),
    generator=generator,
    device=torch_device,
)
# Scale the unit-variance noise to the scheduler's starting noise level.
latents = latents * scheduler.init_noise_sigma
print(latents)  # debug output of the starting latents
# Denoising loop: one UNet evaluation and one scheduler update per timestep.
for timestep in tqdm(scheduler.timesteps):
    # Duplicate the latents so the unconditional and the text-conditioned
    # predictions both come out of a single forward pass.
    model_input = torch.cat((latents, latents))
    model_input = scheduler.scale_model_input(model_input, timestep=timestep)

    # Predict the noise residual.
    # NOTE(review): only encoder_hidden_states is passed. If the checkpoint is
    # SDXL, its UNet presumably also expects added_cond_kwargs -- confirm.
    with torch.no_grad():
        unet_output = model(
            model_input, timestep, encoder_hidden_states=text_embeddings
        )
    eps_uncond, eps_text = unet_output.sample.chunk(2)

    # Classifier-free guidance: start from the unconditional prediction and
    # move towards the text-conditioned one by guidance_scale.
    noise_pred = eps_uncond + guidance_scale * (eps_text - eps_uncond)

    # Compute the previous noisy sample x_t -> x_t-1.
    latents = scheduler.step(noise_pred, timestep, latents).prev_sample
# Undo the latent scaling and decode the latents with the VAE. The factor is
# read from the loaded VAE's config instead of hard-coding 0.18215, so it
# always matches the checkpoint that was actually loaded.
latents = latents / vae.config.scaling_factor
with torch.no_grad():
    image = vae.decode(latents).sample

# Map the decoder output from [-1, 1] to [0, 1] and drop the batch axis
# (this script generates a single image).
image = (image / 2 + 0.5).clamp(0, 1).squeeze(0)

# Convert (channels, height, width) floats to (height, width, channels) uint8.
# round() avoids the downward bias of a plain truncating cast, and cpu() is
# required because a tensor on an accelerator device cannot be converted to
# NumPy directly.
image = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()

image = Image.fromarray(image)
image.save("generated_image.png")
# Notes on the UNet forward arguments:
# sample (torch.FloatTensor) — The noisy input tensor with shape (batch, channel, height, width).
# timestep (torch.FloatTensor or float or int) — The timestep at which the input is denoised.
# encoder_hidden_states (torch.FloatTensor) — The encoder hidden states with shape (batch, sequence_length, feature_dim).
# image = model(sample, timestep, encoder_hidden_states).images[0]
# image = model(num_inference_steps=20).images[0]
# image.save("generated_image.png")
# mlmodel = ct.convert(model)
# (GitHub page footer: sign up or sign in to comment on this gist.)