# Gist by @afiaka87, created September 6, 2022
# https://gist.github.com/afiaka87/683cb3ab5115d192e0ef170651c6f834
import inspect
from typing import List, Optional, Union, Tuple
import numpy as np
import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    PNDMScheduler,
    LMSDiscreteScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.safety_checker import cosine_distance
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
def preprocess_init_image(image: Image.Image, width: int, height: int) -> torch.FloatTensor:
    """Resize a PIL image and rescale it from [0, 255] to [-1, 1] as a (1, 3, H, W) tensor."""
    image = image.resize((width, height), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW with a leading batch dim
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0
def preprocess_mask(mask: Image.Image, width: int, height: int) -> torch.FloatTensor:
    """Downsample a mask to latent resolution and broadcast it over the 4 latent channels."""
    mask = mask.convert("L")
    mask = mask.resize((width // 8, height // 8), resample=Image.LANCZOS)
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))  # repeat the single channel across the 4 latent channels
    mask = mask[None]  # add a batch dimension; the original transpose(0, 1, 2, 3) was a no-op
    mask = torch.from_numpy(mask)
    return mask
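
# A minimal shape sanity check for the two helpers above (a sketch, not part of the
# original gist; the 512x512 size and the solid-color test images are assumptions).
if __name__ == "__main__":
    _init = Image.new("RGB", (512, 512), color=(127, 127, 127))
    _mask = Image.new("L", (512, 512), color=255)
    assert preprocess_init_image(_init, 512, 512).shape == (1, 3, 512, 512)
    # The mask lives at latent resolution (height // 8, width // 8), one copy per channel.
    # Per the denoising loop below, white (1.0) keeps the init image; black (0.0) regenerates.
    assert preprocess_mask(_mask, 512, 512).shape == (1, 4, 64, 64)
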
class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
    """
    From https://github.com/huggingface/diffusers/pull/241
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: Optional[torch.FloatTensor],
        mask: Optional[torch.FloatTensor],
        width: int,
        height: int,
        prompt_strength: float = 0.8,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
    ) -> dict:
        if isinstance(prompt, str):
            batch_size = 1
            prompt = [prompt]
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )
        if prompt_strength < 0 or prompt_strength > 1:
            raise ValueError(
                f"The value of prompt_strength should be in [0.0, 1.0] but is {prompt_strength}"
            )
        if mask is not None and init_image is None:
            raise ValueError(
                "If mask is defined, then init_image also needs to be defined"
            )
        if width % 8 != 0 or height % 8 != 0:
            raise ValueError("Width and height must both be divisible by 8")
        # set timesteps
        accepts_offset = "offset" in set(
            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
        )
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1
        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        text_embeddings = self.embed_text(prompt=prompt)
        uncond_embeddings = self.embed_text(
            prompt=[""] * batch_size
        )  # only need to call this once
        if init_image is not None:
            init_latents_orig, latents, init_timestep = self.latents_from_init_image(
                init_image,
                prompt_strength,
                offset,
                num_inference_steps,
                batch_size,
                generator,
            )
        else:
            latents = torch.randn(
                (batch_size, self.unet.in_channels, height // 8, width // 8),
                generator=generator,
                device=self.device,
            )
            init_timestep = num_inference_steps
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            sd_hidden_states = torch.cat([uncond_embeddings, text_embeddings], dim=0)
        else:
            sd_hidden_states = text_embeddings
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        # fixed noise used to re-noise the init latents at each masked step
        mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)
        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]
        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=sd_hidden_states,  # context: uncond embeddings stacked with the text embeddings
            )["sample"]
            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(
                    noise_pred, i, latents, **extra_step_kwargs
                )["prev_sample"]
            else:
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                )["prev_sample"]
            # where the mask is 1, keep the init latents (re-noised to the current timestep);
            # where it is 0, keep the freshly denoised latents
            if mask is not None:
                timesteps = self.scheduler.timesteps[t_start + i]
                timesteps = torch.tensor(
                    [timesteps] * batch_size, dtype=torch.long, device=self.device
                )
                noisy_init_latents = self.scheduler.add_noise(
                    init_latents_orig, mask_noise, timesteps
                )
                latents = noisy_init_latents * mask + latents * (1 - mask)
        # scale and decode the image latents with the vae
        latents = 1 / 0.18215 * latents
        decoded_image = self.vae.decode(latents)
        decoded_image = (decoded_image / 2 + 0.5).clamp(0, 1)
        decoded_image = decoded_image.cpu().permute(0, 2, 3, 1).numpy()
        # embed only the first prompt; scoring assumes every prompt in the batch is the same
        pooled_text_embeddings = self.embed_text(prompt=[prompt[0]], pooled_output=True)
        # get pooled image embeddings for cosine similarity (1, 768)
        clip_image_tensors = self.feature_extractor(
            self.numpy_to_pil(decoded_image), return_tensors="pt"
        ).to(self.device)
        clip_image_tensors = clip_image_tensors.pixel_values
        pooled_image_embeddings = self.image_to_embedding(image_tensors=clip_image_tensors)
        scores = (
            cosine_distance(pooled_image_embeddings, pooled_text_embeddings)
            .detach()
            .cpu()
            .numpy()
        )
        pil_image = self.numpy_to_pil(decoded_image)
        return {
            "sample": pil_image,
            "score": scores,
            # "nsfw_content_detected": has_nsfw_concept,
        }
    # # check safety
    # _, has_nsfw_concept = self.safety_checker(
    #     images=decoded_image, clip_input=clip_image_tensors
    # )  # TODO: check if this is correct
    def latents_from_init_image(
        self,
        init_image: torch.FloatTensor,
        prompt_strength: float,
        offset: int,
        num_inference_steps: int,
        batch_size: int,
        generator: Optional[torch.Generator],
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, int]:
        # encode the init image into latents and scale the latents
        init_latents = self.vae.encode(init_image.to(self.device)).sample()
        init_latents = 0.18215 * init_latents
        init_latents_orig = init_latents
        # replicate the init latents, one copy per prompt in the batch
        init_latents = torch.cat([init_latents] * batch_size)
        # get the original timestep using init_timestep
        init_timestep = int(num_inference_steps * prompt_strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor(
            [timesteps] * batch_size, dtype=torch.long, device=self.device
        )
        # add noise to the latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
        return init_latents_orig, init_latents, init_timestep
    def embed_text(
        self,
        prompt: List[str],
        pooled_output: bool = False,
    ) -> torch.FloatTensor:
        """Embed text using the CLIP text encoder."""
        assert isinstance(
            prompt, list
        ), f"prompt {prompt} must be a list of strings, try wrapping it in a list e.g. [prompt]"
        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeds = self.text_encoder(
            text_input.input_ids.to(self.device)
        )  # (last_hidden_state, pooled_output); last_hidden_state is (batch_size, seq_len, dim)
        if pooled_output:
            return text_embeds[1]
        else:
            return text_embeds[0]
    def image_to_embedding(self, image_tensors: torch.FloatTensor) -> torch.FloatTensor:
        pooled_image_embeddings = self.safety_checker.vision_model(
            pixel_values=image_tensors
        ).pooler_output
        # project the pooled image embeddings into the same space as the text embeddings
        pooled_image_embeddings = self.safety_checker.visual_projection(
            pooled_image_embeddings
        )
        return pooled_image_embeddings
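
# A minimal usage sketch (not part of the original gist). Assumptions: the
# "CompVis/stable-diffusion-v1-4" checkpoint (whose modules match the names expected
# by __init__ above), a Hugging Face auth token, a CUDA device, and a local "init.png":
#
#     pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
#         "CompVis/stable-diffusion-v1-4", use_auth_token=True
#     ).to("cuda")
#     init_image = preprocess_init_image(Image.open("init.png").convert("RGB"), 512, 512)
#     result = pipe(
#         prompt="a photograph of an astronaut riding a horse",
#         init_image=init_image,
#         mask=None,
#         width=512,
#         height=512,
#         prompt_strength=0.8,
#     )
#     result["sample"][0].save("out.png")  # "sample" is a list of PIL images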