-
-
Save attila-dusnoki-htec/4142c8028429a70930bbc0d87e66ef06 to your computer and use it in GitHub Desktop.
Example script to use Stable Diffusion with Torch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pip install diffusers transformers
import torch
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

# Every component is loaded from its own subfolder of the same checkpoint.
model_id = "stabilityai/stable-diffusion-2-1"

scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

# The three neural networks are moved to the selected device after loading.
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
text_encoder = text_encoder.to(torch_device)

vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
vae = vae.to(torch_device)

unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
unet = unet.to(torch_device)
# Generation parameters
prompt = ["a photograph of an astronaut riding a horse"]

# Output image size in pixels.
# NOTE(review): confirm 512 matches the resolution the chosen checkpoint
# was trained for.
height = 512
width = 512

num_inference_steps = 20  # Number of denoising steps
guidance_scale = 7.0  # Scale for classifier-free guidance

# Seed the generator used to create the initial latent noise so that runs
# are reproducible; the seed is also embedded in the output file names.
seed = 13
generator = torch.manual_seed(seed)

# One image per prompt. Deriving the batch size from the prompt list keeps the
# latent batch and the text-embedding batch in agreement when prompts are added.
batch_size = len(prompt)
# Prompt embeddings for classifier-free guidance

# Conditional: tokenize the prompt, padded/truncated to the tokenizer's
# maximum sequence length.
text_input = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)

# Unconditional: empty prompts padded to the same sequence length as the
# conditional input so both embeddings can be concatenated.
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
    [""] * batch_size,
    padding="max_length",
    max_length=max_length,
    return_tensors="pt",
)

# Encode both token batches without tracking gradients.
with torch.no_grad():
    cond_ids = text_input.input_ids.to(torch_device)
    uncond_ids = uncond_input.input_ids.to(torch_device)
    text_embeddings = text_encoder(cond_ids)[0]
    uncond_embeddings = text_encoder(uncond_ids)[0]

# Concat embeddings: unconditional first, conditional second.
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# Random input
# Draw the starting noise with the seeded generator, then move it to the
# compute device. The latent tensor has 4 channels and 1/8 of the image
# height and width.
latent_shape = (batch_size, 4, height // 8, width // 8)
latents = torch.randn(latent_shape, generator=generator).to(torch_device)

# Use sigmas
# The schedule is configured first; init_noise_sigma is read afterwards to
# scale the starting noise.
scheduler.set_timesteps(num_inference_steps)
latents = latents * scheduler.init_noise_sigma
# Unet
# Nothing in the denoising loop needs gradients, so the whole loop runs
# under a single no_grad context.
with torch.no_grad():
    for timestep in tqdm(scheduler.timesteps):
        # Duplicate the latents so a single UNet forward pass covers both
        # halves of classifier-free guidance.
        model_input = torch.cat((latents, latents))
        model_input = scheduler.scale_model_input(model_input, timestep)

        # Predict the noise residual for both halves at once.
        unet_output = unet(
            model_input,
            timestep,
            encoder_hidden_states=text_embeddings,
        )
        eps_uncond, eps_text = unet_output.sample.chunk(2)

        # Perform guidance: start from the unconditional prediction and move
        # towards the text-conditioned one by guidance_scale.
        guided_eps = eps_uncond + guidance_scale * (eps_text - eps_uncond)

        # Compute the previous noisy sample x_t -> x_t-1.
        latents = scheduler.step(guided_eps, timestep, latents).prev_sample
# Scale and decode the image latents with the VAE.
# The scaling factor is read from the VAE's own configuration instead of being
# hard-coded; 0.18215 (the value previously used here) is the fallback when
# the loaded config does not expose one.
vae_scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)
latents = 1 / vae_scaling_factor * latents

# Decoder
with torch.no_grad():
    image = vae.decode(latents).sample
# Save images
# Rescale the decoder output with x / 2 + 0.5 and clamp it into [0, 1].
image = image / 2 + 0.5
image = image.clamp(0, 1)

# Move to host memory and reorder axes (0, 2, 3, 1) before converting to
# NumPy -- presumably NCHW -> NHWC, the layout PIL expects.
image = image.detach().cpu()
image = image.permute(0, 2, 3, 1).numpy()

# Convert to 8-bit pixel values.
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(pixels) for pixels in images]

# One PNG per batch entry, named after the seed and the batch index.
for idx, img in enumerate(pil_images):
    filename = f"output_{seed}_{idx}.png"
    img.save(filename, format="png")
    print(f"Image saved to {filename}")

print("done")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment