Skip to content

Instantly share code, notes, and snippets.

@carson-katri
Created December 2, 2022 20:46
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save carson-katri/f51532b9d5162928d5cacbaee081a799 to your computer and use it in GitHub Desktop.
Save carson-katri/f51532b9d5162928d5cacbaee081a799 to your computer and use it in GitHub Desktop.
# --- User configuration: replace the placeholder strings before running. ---

# Location of the converted v2 depth model handed to `from_pretrained`.
CONVERTED_MODEL_PATH = "path to converted v2 depth model"
# Device the pipeline is moved to.
DEVICE = "cuda"
# Text prompt for the generation. The script at the bottom of the file reads
# `PROMPT`; without this definition it fails with a NameError.
PROMPT = "text prompt describing the image to generate"
# Optional init image for img2img; set to None (the value, not a string) to
# generate from the depth map alone.
INIT_IMAGE = "path to init image or None"
# Depth map image; its size also determines the output width and height.
DEPTH_IMAGE = "path to black and white depth map"
# Where the first generated image is written.
OUTPUT_PATH = "path to save the result"
from typing import Union, List, Optional, Callable

import diffusers
import numpy as np
import PIL.Image
import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import deprecate
# This pipeline is mostly copied from StableDiffusionInpaintPipeline and StableDiffusionImg2ImgPipeline.
def prepare_depth(depth):
if isinstance(depth, PIL.Image.Image):
depth = np.array(depth.convert("L"))
depth = depth.astype(np.float32) / 255.0
depth = depth[None, None]
depth[depth < 0.5] = 0
depth[depth >= 0.5] = 1
depth = torch.from_numpy(depth)
return depth
class StableDiffusionDepthPipeline(diffusers.StableDiffusionInpaintPipeline):
    """Depth-conditioned Stable Diffusion pipeline with optional img2img start.

    Subclasses the inpainting pipeline to reuse its helpers (``check_inputs``,
    ``_encode_prompt``, ``prepare_latents``, ``decode_latents``, ...), but
    concatenates a depth map to the latents in the channel dimension instead
    of a mask and masked-image latents.
    """

    def prepare_depth_latents(
        self, depth, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
    ):
        """Resize the depth map to latent resolution and tile it over the batch.

        Args:
            depth: Depth tensor as produced by ``prepare_depth``.
            batch_size: Number of samples to generate; the depth map is
                repeated this many times along the first dimension.
            height: Output image height in pixels.
            width: Output image width in pixels.
            dtype: Target dtype of the returned tensor.
            device: Target device of the returned tensor.
            generator: Unused; kept for signature compatibility.
            do_classifier_free_guidance: When true the result is duplicated so
                it lines up with the doubled latent batch.

        Returns:
            The depth tensor at ``(height, width) // vae_scale_factor``
            resolution, on ``device`` with ``dtype``.
        """
        # Resize to the latents' spatial shape because the depth map is
        # concatenated to the latents. Do it before converting to dtype to
        # avoid breaking when cpu_offload and half precision are used.
        depth = torch.nn.functional.interpolate(
            depth, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
        )
        depth = depth.to(device=device, dtype=dtype)

        # Duplicate the depth map for each generation per prompt, using an
        # mps-friendly method.
        depth = depth.repeat(batch_size, 1, 1, 1)
        depth = torch.cat([depth] * 2) if do_classifier_free_guidance else depth
        return depth

    def prepare_img2img_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        """Encode an init image to latents and noise them to ``timestep``.

        Args:
            image: Preprocessed init image tensor.
            timestep: Timestep(s) passed to ``scheduler.add_noise``.
            batch_size: Number of text prompts.
            num_images_per_prompt: Number of images generated per prompt.
            dtype: dtype used for the image and the sampled noise.
            device: Device used for the image and the sampled noise.
            generator: Optional ``torch.Generator`` for reproducible sampling.

        Returns:
            Noised init latents, duplicated to cover every prompt and every
            image per prompt.

        Raises:
            ValueError: If there are more prompts than images and the prompt
                count is not a multiple of the image count.
        """
        image = image.to(device=device, dtype=dtype)
        init_latent_dist = self.vae.encode(image).latent_dist
        init_latents = init_latent_dist.sample(generator=generator)
        # NOTE(review): 0.18215 is presumably the Stable Diffusion VAE latent
        # scaling factor - confirm against the model config.
        init_latents = 0.18215 * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents
        return latents

    def get_timesteps(self, num_inference_steps, strength, device):
        """Select the tail of the scheduler's timesteps used for img2img.

        Args:
            num_inference_steps: Number of steps the scheduler was set up with.
            strength: Fraction of the schedule to run.
            device: Unused; kept for signature compatibility.

        Returns:
            A tuple ``(timesteps, remaining_steps)``: the scheduler timesteps
            from the computed start index onwards and the number of steps
            left after skipping to that index.
        """
        # get the original timestep using init_timestep
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        timesteps = self.scheduler.timesteps[t_start:]
        return timesteps, num_inference_steps - t_start

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        depth_image: Union[torch.FloatTensor, PIL.Image.Image],
        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
        strength: float = 0.8,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        """Generate images from a prompt, a depth map and an optional init image.

        Args:
            prompt: One prompt or a list of prompts.
            depth_image: Depth map, passed through ``prepare_depth``.
            image: Optional init image. When given, generation starts from its
                noised latents (img2img); PIL images are preprocessed with the
                img2img pipeline's ``preprocess``.
            strength: Only used when ``image`` is given; forwarded to
                ``get_timesteps`` to choose where in the schedule to start.
            height: Output height; defaults to
                ``unet.config.sample_size * vae_scale_factor``.
            width: Output width; same default as ``height``.
            num_inference_steps: Number of scheduler steps to set up.
            guidance_scale: Classifier-free guidance weight; guidance is
                applied only when this is greater than 1.0.
            negative_prompt: Optional negative prompt(s) for prompt encoding.
            num_images_per_prompt: Images to generate per prompt.
            eta: Forwarded to ``prepare_extra_step_kwargs``.
            generator: Optional ``torch.Generator`` for reproducibility.
            latents: Optional pre-made initial latents; used only when
                ``image`` is None.
            output_type: ``"pil"`` converts the result with ``numpy_to_pil``;
                any other value returns what ``decode_latents`` produced.
            return_dict: When false a plain tuple is returned.
            callback: Optional ``callback(i, t, latents)`` invoked during
                denoising.
            callback_steps: The callback fires when ``i % callback_steps == 0``.

        Returns:
            ``StableDiffusionPipelineOutput`` or, if ``return_dict`` is false,
            the tuple ``(images, nsfw_content_detected)``.

        Raises:
            ValueError: If the init-image latents do not match the requested
                output size, or if latent plus depth channels do not match
                ``unet.config.in_channels``.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare the depth image (and the init image, if any)
        depth = prepare_depth(depth_image)
        if image is not None and isinstance(image, PIL.Image.Image):
            image = diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess(image)

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        if image is not None:
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        # 6. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        if image is not None:
            latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
            latents = self.prepare_img2img_latents(
                image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
            )
            # The init latents were already noised by `scheduler.add_noise`, so
            # they must not go through `prepare_latents` (which is meant for
            # fresh initial noise). Only the shape is validated here so that a
            # size mismatch still surfaces as a ValueError.
            expected_shape = (
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height // self.vae_scale_factor,
                width // self.vae_scale_factor,
            )
            if tuple(latents.shape) != expected_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {expected_shape}")
        else:
            latents = self.prepare_latents(
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height,
                width,
                text_embeddings.dtype,
                device,
                generator,
                latents,
            )

        # 7. Prepare depth latent variables
        depth = self.prepare_depth_latents(
            depth,
            batch_size * num_images_per_prompt,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

        # 8. Check that the channel counts of depth and latents match the unet
        num_channels_depth = depth.shape[1]
        if num_channels_latents + num_channels_depth != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_depth`: {num_channels_depth}"
                f" = {num_channels_latents+num_channels_depth}. Please verify the config of"
                " `pipeline.unet` or your `depth_image` or `image` input."
            )

        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 10. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                # scale the latents, then concat the depth map in the channel dimension
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = torch.cat([latent_model_input, depth], dim=1)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 11. Post-processing
        image = self.decode_latents(latents)

        # 12. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 13. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# Load the converted depth model and move it to the target device.
pipe: StableDiffusionDepthPipeline = StableDiffusionDepthPipeline.from_pretrained(CONVERTED_MODEL_PATH)
pipe = pipe.to(DEVICE)
pipe.enable_attention_slicing()

image = PIL.Image.open(INIT_IMAGE) if INIT_IMAGE is not None else None
depth = PIL.Image.open(DEPTH_IMAGE)
if image is not None:
    # Only resize when an init image was actually supplied; calling
    # `.resize` on None raised AttributeError in the depth-only case.
    # Converting to RGB keeps grayscale / RGBA files from reaching the
    # pipeline's image preprocessing with an unexpected channel count.
    image = image.convert("RGB").resize(depth.size)

# The output resolution follows the depth map's size.
output = pipe(
    prompt=PROMPT,
    image=image,
    depth_image=depth,
    width=depth.size[0],
    height=depth.size[1],
    num_inference_steps=50,
)
output.images[0].save(OUTPUT_PATH)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment