import os
from typing import List
import numpy as np
import torch
from cog import BasePredictor, Input, Path
from diffusers import (
    AutoencoderKL,
    LMSDiscreteScheduler,
    UNet2DConditionModel,
)
from PIL import Image
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
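
# Stable Diffusion pipeline, in brief: CLIP encodes the prompt into text embeddings,
# a UNet iteratively denoises a latent tensor conditioned on those embeddings, and
# the VAE decodes the final latents back into pixel space.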


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.output_dir = Path("cog_output")  # TODO
        self.output_dir.mkdir(exist_ok=True)
        cache_dir = "model_cache"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # The autoencoder compresses images into a compact latent representation
        # and decodes latents back into pixels.
        self.vae = AutoencoderKL.from_pretrained(
            cache_dir,
            subfolder="vae",
        )
        self.vae.to(self.device)
        print("Loaded autoencoder")

        # Words are first tokenized into integer ids that CLIP understands.
        # The text encoder then maps those tokens to a sequence of 768-dimensional
        # embedding vectors.
        self.tokenizer = CLIPTokenizer.from_pretrained(
            cache_dir,
            subfolder="tokenizer",
        )  # the tokenizer is small enough to run on CPU
        self.text_encoder = CLIPTextModel.from_pretrained(
            cache_dir,
            subfolder="text_encoder",
        )
        self.text_encoder.to(self.device)
        print("Loaded CLIP text encoder.")

        # A denoising diffusion UNet generates autoencoder latents conditioned on
        # the text embeddings; the autoencoder then decodes those latents into
        # image space.
        self.unet = UNet2DConditionModel.from_pretrained(
            cache_dir,
            subfolder="unet",
            revision="fp16",
            torch_dtype=torch.float16,
        )
        self.unet.to(self.device)
        print("Loaded unet")
        self.scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
        )

    @torch.inference_mode()  # disables autograd tracking for faster inference
    @torch.cuda.amp.autocast()  # automatically casts eligible ops to fp16
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        image_prompt: Path = Input(description="Input image prompt", default=None),
        num_outputs: int = Input(
            description="Number of images to output", choices=[1, 4, 16], default=1
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=100
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        height: int = Input(
            description="Height of output images",
            default=512,
            choices=[256, 384, 512, 640, 768, 1024],
        ),
        width: int = Input(
            description="Width of output images",
            default=512,
            choices=[256, 384, 512, 640, 768, 1024],
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
        image_prompt_strength: float = Input(
            description="Strength of image prompt", ge=0, le=1, default=0.5
        ),
    ) -> List[Path]:
"""Run a single prediction on the model"""
if seed is None:
seed = int.from_bytes(os.urandom(2), "big")
generator = torch.manual_seed(seed)
print(f"Using seed: {seed}")
self.scheduler.set_timesteps(num_inference_steps)
prompt = [prompt] * num_outputs
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
max_length = text_input.input_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * num_outputs,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
text_embeddings = torch.cat(
[
uncond_embeddings.to(self.device),
text_embeddings.to(self.device),
]
)
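
        # The concatenated batch is ordered [unconditional, conditional]; the UNet's
        # output is later split with .chunk(2) in the same order so the two halves can
        # be recombined for classifier-free guidance.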

        # Start from random Gaussian noise in the autoencoder's latent space
        # (the VAE downsamples height and width by a factor of 8).
        latents = torch.randn(
            (num_outputs, self.unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
        if image_prompt is not None:
            init_image = Image.open(image_prompt).convert("RGB")
            init_image = init_image.resize((int(width), int(height)), Image.LANCZOS)
            init_image = (
                TF.to_tensor(init_image).to(self.device).unsqueeze(0).clamp(0, 1)
            )
            latents = (
                self.vae.encode(init_image.to(self.device) * 2 - 1).sample() * 0.18215
            )
        latents = latents.to(self.device)
        # TODO: to honor the image prompt properly, noise the init latents to the
        # correct timestep and start the scheduler partway through the schedule
        # (image_prompt_strength is currently unused).
        print(f"Using {num_inference_steps} inference steps")
        # LMSDiscreteScheduler expects the initial noise scaled by its largest sigma.
        latents = latents * self.scheduler.sigmas[0]

        # generate image latents from the noise
        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
            # expand the latents so classifier-free guidance needs only one forward pass
            latent_model_input = torch.cat([latents] * 2)
            sigma = self.scheduler.sigmas[i]
            latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            )["sample"]

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, i, latents)["prev_sample"]
        # scale and decode the image latents with the vae
        latents = 1 / 0.18215 * latents
        images = self.vae.decode(latents)

        # convert to uint8 numpy images and save them
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (images * 255).round().astype("uint8")
        prediction_paths = []
        for idx, image in enumerate(images):
            image_path = self.output_dir / f"{idx:03d}.png"
            Image.fromarray(image).save(image_path)
            prediction_paths.append(image_path)
        return prediction_paths
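
# Example invocation with the Cog CLI (illustrative; exact flags depend on your Cog setup):
#   cog predict -i prompt="an astronaut riding a horse" -i num_inference_steps=50 -i seed=1234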