Batch runner for apple/ml-stable-diffusion

@ChristopherMWood
Last active July 29, 2023 21:53
This is a script modification for the apple/ml-stable-diffusion repo that enables it to run batches of image generations instead of just one at a time. With the original repo, loading the libraries and models took considerably longer than actually generating a single image (measured on a first-generation M1 Mac with 8GB of RAM). This version loads the libraries once as usual, then takes as many prompts as you want from a CSV and runs through them all without reloading the libraries for each image.

How to use

Step 1: Setup apple/ml-stable-diffusion

Get the original apple/ml-stable-diffusion running on your machine. I used this guide. Make sure everything works and that you can generate an image before moving on to step 2.
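For reference, the setup boils down to roughly the following (based on the apple/ml-stable-diffusion README at the time; double-check the repo for current commands, and note that downloading the weights needs a Hugging Face login because the script calls from_pretrained with use_auth_token=True):

pip install -e .    # run inside the cloned ml-stable-diffusion repo, ideally in a fresh Python environment
huggingface-cli login
python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --convert-text-encoder --convert-vae-decoder --convert-safety-checker -o models
python -m python_coreml_stable_diffusion.pipeline --prompt "a test image" -i models -o outputs --compute-unit ALL --seed 93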

Step 2: Replace pipeline.py

Replace the contents of your pipeline.py file with the version below.
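If you set things up by cloning the repo, the file you are replacing is python_coreml_stable_diffusion/pipeline.py (the module run in step 4). It can be worth keeping a copy of the original first, for example:

cp python_coreml_stable_diffusion/pipeline.py python_coreml_stable_diffusion/pipeline.py.bak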

Step 3: Create an inputs.csv file

Create a CSV file named inputs.csv that follows the column format below, and place it in the directory you run the command from (the script opens inputs.csv by relative path). You can add as many rows as you want.

Column                      Type    Description
Column 1: The prompt        String  The prompt to be used by Stable Diffusion. Must be in double quotes so the prompt itself can contain commas.
Column 2: # of iterations   Int     The number of images to generate for the prompt.
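For example, a two-row inputs.csv (the prompts here are just placeholders) looks like this; note the double quotes around each prompt and the plain integer in the second column:

"a watercolor of a lighthouse at dusk, soft light",4
"a robot reading a newspaper, studio lighting",2

A complete sample inputs.csv is also included further below, just before pipeline.py.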

Step 4: Run

The original command required a prompt and an optional seed argument. Since the prompts now come from the CSV and a random seed is generated for every image, you can omit those arguments and use this simplified command.

python -m python_coreml_stable_diffusion.pipeline -i models -o outputs
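The rest of the original arguments (--compute-unit, --scheduler, --num-inference-steps, --guidance-scale, --model-version) still work and are applied to every prompt in the CSV; they are defined in the argument parser at the bottom of pipeline.py. For example:

python -m python_coreml_stable_diffusion.pipeline -i models -o outputs --compute-unit ALL --num-inference-steps 30 --scheduler DPMSolverMultistep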

NOTE: I found the NSFW filter a bit too aggressive (it blocked a lot of harmless prompts), so it has been disabled in this script. As stated in the apple/ml-stable-diffusion documentation, ensure that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered images in anything public-facing.
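If you would rather keep the filter, the relevant calls are only commented out, not removed. To restore the stock behavior, uncomment these two lines in pipeline.py (one in CoreMLStableDiffusionPipeline.__init__, one in __call__ step 9) and drop the hard-coded has_nsfw_concept = False just above the second one:

self.safety_checker = safety_checker
image, has_nsfw_concept = self.run_safety_checker(image)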

inputs.csv (example)

"a renaissance painting of a person creating a github gist on a laptop,stylized,classy",5
"an illustration of a person spending too much time on their computer,colorful,moody",3

pipeline.py

#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
import argparse
import csv
import random
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin
import gc
import inspect
import logging
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
import numpy as np
import os
from python_coreml_stable_diffusion.coreml_model import (
    CoreMLModel,
    _load_mlpackage,
    get_available_compute_units,
)
import time
import torch  # Only used for `torch.from_numpy` in `pipe.scheduler.step()`
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from typing import List, Optional, Union

# Simple container for one row of inputs.csv plus the settings used to generate it
class GenerationModel:
    def __init__(self):
        self.prompt = 0
        self.variations = 1
        self.seed = 0
        self.o = 0
        self.compute_unit = 0
        self.model_version = 0
        self.scheduler = 0
        self.num_inference_steps = 0


class CoreMLStableDiffusionPipeline(DiffusionPipeline):
    """ Core ML version of
    `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline`
    """

    def __init__(
        self,
        text_encoder: CoreMLModel,
        unet: CoreMLModel,
        vae_decoder: CoreMLModel,
        feature_extractor: CLIPFeatureExtractor,
        safety_checker: Optional[CoreMLModel],
        scheduler: Union[DDIMScheduler,
                         DPMSolverMultistepScheduler,
                         EulerAncestralDiscreteScheduler,
                         EulerDiscreteScheduler,
                         LMSDiscreteScheduler,
                         PNDMScheduler],
        tokenizer: CLIPTokenizer,
    ):
        super().__init__()

        # Register non-Core ML components of the pipeline similar to the original pipeline
        self.register_modules(
            tokenizer=tokenizer,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        if safety_checker is None:
            # Reproduce original warning:
            # https://github.com/huggingface/diffusers/blob/v0.9.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L119
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # Register Core ML components of the pipeline
        # (the safety checker registration is intentionally disabled in this batch-runner version)
        # self.safety_checker = safety_checker
        self.text_encoder = text_encoder
        self.unet = unet
        self.unet.in_channels = self.unet.expected_inputs["sample"]["shape"][1]
        self.vae_decoder = vae_decoder

        VAE_DECODER_UPSAMPLE_FACTOR = 8

        # In PyTorch, users can determine the tensor shapes dynamically by default
        # In CoreML, tensors have static shapes unless flexible shapes were used during export
        # See https://coremltools.readme.io/docs/flexible-inputs
        latent_h, latent_w = self.unet.expected_inputs["sample"]["shape"][2:]
        self.height = latent_h * VAE_DECODER_UPSAMPLE_FACTOR
        self.width = latent_w * VAE_DECODER_UPSAMPLE_FACTOR

        logger.info(
            f"Stable Diffusion configured to generate {self.height}x{self.width} images"
        )

    def _encode_prompt(self, prompt, num_images_per_prompt,
                       do_classifier_free_guidance, negative_prompt):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="np",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(
                text_input_ids[:, self.tokenizer.model_max_length:])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}")
            text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]

        text_embeddings = self.text_encoder(
            input_ids=text_input_ids.astype(np.float32))["last_hidden_state"]

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}.")
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`.")
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="np",
            )
            uncond_embeddings = self.text_encoder(
                input_ids=uncond_input.input_ids.astype(
                    np.float32))["last_hidden_state"]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = np.concatenate(
                [uncond_embeddings, text_embeddings])

        text_embeddings = text_embeddings.transpose(0, 2, 1)[:, :, None, :]

        return text_embeddings

    def run_safety_checker(self, image):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image),
                return_tensors="np",
            )
            safety_checker_outputs = self.safety_checker(
                clip_input=safety_checker_input.pixel_values.astype(
                    np.float16),
                images=image.astype(np.float16),
                adjustment=np.array([0.]).astype(
                    np.float16),  # defaults to 0 in original pipeline
            )

            # Unpack dict
            has_nsfw_concept = safety_checker_outputs["has_nsfw_concepts"]
            image = safety_checker_outputs["filtered_images"]
            concept_scores = safety_checker_outputs["concept_scores"]

            logger.info(
                f"Generated image has nsfw concept={has_nsfw_concept.any()}")
        else:
            has_nsfw_concept = None

        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae_decoder(z=latents.astype(np.float16))["image"]
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))
        return image

    def prepare_latents(self,
                        batch_size,
                        num_channels_latents,
                        height,
                        width,
                        latents=None):
        latents_shape = (batch_size, num_channels_latents, self.height // 8,
                         self.width // 8)
        if latents is None:
            latents = np.random.randn(*latents_shape).astype(np.float16)
        elif latents.shape != latents_shape:
            raise ValueError(
                f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}"
            )

        latents = latents * self.scheduler.init_noise_sigma

        return latents

    def check_inputs(self, prompt, height, width, callback_steps):
        if height != self.height or width != self.width:
            logger.warning(
                "`height` and `width` dimensions (of the output image tensor) are fixed when exporting the Core ML models "
                "unless flexible shapes are used during export (https://coremltools.readme.io/docs/flexible-inputs). "
                f"This pipeline was provided with Core ML models that generate {self.height}x{self.width} images (user requested {height}x{width})"
            )

        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        if (callback_steps is None) or (callback_steps is not None and
                                        (not isinstance(callback_steps, int)
                                         or callback_steps <= 0)):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}.")

    def prepare_extra_step_kwargs(self, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        return extra_step_kwargs

    def __call__(
        self,
        prompt,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        negative_prompt=None,
        num_images_per_prompt=1,
        eta=0.0,
        latents=None,
        output_type="pil",
        return_dict=True,
        callback=None,
        callback_steps=1,
        **kwargs,
    ):
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        if batch_size > 1 or num_images_per_prompt > 1:
            raise NotImplementedError(
                "For batched generation of multiple images and/or multiple prompts, please refer to the Swift package."
            )

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(eta)

        # 7. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate(
                [latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                sample=latent_model_input.astype(np.float16),
                timestep=np.array([t, t], np.float16),
                encoder_hidden_states=text_embeddings.astype(np.float16),
            )["noise_pred"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                torch.from_numpy(noise_pred),
                t,
                torch.from_numpy(latents),
                **extra_step_kwargs,
            ).prev_sample.numpy()

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker (disabled in this batch-runner version)
        has_nsfw_concept = False
        # image, has_nsfw_concept = self.run_safety_checker(image)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept)


def get_available_schedulers():
    schedulers = {}
    for scheduler in [DDIMScheduler,
                      DPMSolverMultistepScheduler,
                      EulerAncestralDiscreteScheduler,
                      EulerDiscreteScheduler,
                      LMSDiscreteScheduler,
                      PNDMScheduler]:
        schedulers[scheduler().__class__.__name__.replace("Scheduler", "")] = scheduler
    return schedulers


SCHEDULER_MAP = get_available_schedulers()


def get_coreml_pipe(pytorch_pipe,
                    mlpackages_dir,
                    model_version,
                    compute_unit,
                    delete_original_pipe=True,
                    scheduler_override=None):
    """ Initializes and returns a `CoreMLStableDiffusionPipeline` from an original
    diffusers PyTorch pipeline
    """
    # Ensure `scheduler_override` object is of correct type if specified
    if scheduler_override is not None:
        assert isinstance(scheduler_override, SchedulerMixin)
        logger.warning(
            "Overriding scheduler in pipeline: "
            f"Default={pytorch_pipe.scheduler}, Override={scheduler_override}")

    # Gather configured tokenizer and scheduler attributes from the original pipe
    coreml_pipe_kwargs = {
        "tokenizer": pytorch_pipe.tokenizer,
        "scheduler": pytorch_pipe.scheduler if scheduler_override is None else scheduler_override,
        "feature_extractor": pytorch_pipe.feature_extractor,
    }

    model_names_to_load = ["text_encoder", "unet", "vae_decoder"]
    if getattr(pytorch_pipe, "safety_checker", None) is not None:
        model_names_to_load.append("safety_checker")
    else:
        logger.warning(
            f"Original diffusers pipeline for {model_version} does not have a safety_checker, "
            "Core ML pipeline will mirror this behavior.")
        coreml_pipe_kwargs["safety_checker"] = None

    if delete_original_pipe:
        del pytorch_pipe
        gc.collect()
        logger.info("Removed PyTorch pipe to reduce peak memory consumption")

    # Load Core ML models
    logger.info(f"Loading Core ML models in memory from {mlpackages_dir}")
    coreml_pipe_kwargs.update({
        model_name: _load_mlpackage(
            model_name,
            mlpackages_dir,
            model_version,
            compute_unit,
        )
        for model_name in model_names_to_load
    })
    logger.info("Done.")

    logger.info("Initializing Core ML pipe for image generation")
    coreml_pipe = CoreMLStableDiffusionPipeline(**coreml_pipe_kwargs)
    logger.info("Done.")

    return coreml_pipe


def get_image_path(args, **override_kwargs):
    """ mkdir output folder and encode metadata in the filename
    """
    out_folder = os.path.join(args.o, "_".join(args.prompt.replace("/", "_").rsplit(" ")))
    os.makedirs(out_folder, exist_ok=True)

    out_fname = f"randomSeed_{override_kwargs.get('seed', None) or args.seed}"
    out_fname += f"_computeUnit_{override_kwargs.get('compute_unit', None) or args.compute_unit}"
    out_fname += f"_modelVersion_{override_kwargs.get('model_version', None) or args.model_version.replace('/', '_')}"

    if args.scheduler is not None:
        out_fname += f"_customScheduler_{override_kwargs.get('scheduler', None) or args.scheduler}"

    out_fname += f"_numInferenceSteps{override_kwargs.get('num_inference_steps', None) or args.num_inference_steps}"

    return os.path.join(out_folder, out_fname + ".png")


# Function to open and read csv file to get prompts
def getPromptGenerationModels(inputsFilename):
    genModels = []

    with open(inputsFilename, 'r') as file:
        reader = csv.reader(file)
        for row in reader:
            genModel = GenerationModel()
            genModel.prompt = row[0]

            variations = int(row[1])
            if variations > 0:
                genModel.variations = variations

            genModels.append(genModel)

    return genModels


def main(args):
    logger.info("Initializing PyTorch pipe for reference configuration")
    from diffusers import StableDiffusionPipeline
    pytorch_pipe = StableDiffusionPipeline.from_pretrained(args.model_version,
                                                           use_auth_token=True)

    user_specified_scheduler = None
    if args.scheduler is not None:
        user_specified_scheduler = SCHEDULER_MAP[
            args.scheduler].from_config(pytorch_pipe.scheduler.config)

    # Loads the libraries and models into memory [Takes the most time so doing this only once for all images now]
    coreml_pipe = get_coreml_pipe(pytorch_pipe=pytorch_pipe,
                                  mlpackages_dir=args.i,
                                  model_version=args.model_version,
                                  compute_unit=args.compute_unit,
                                  scheduler_override=user_specified_scheduler)

    genModels = getPromptGenerationModels("inputs.csv")

    # BEGIN
    for genModel in genModels:
        logger.info(f"Beginning image generation for prompt: {genModel.prompt}")

        for variation in range(genModel.variations):
            logger.info(f"Variation #{variation}")

            # Set seed
            randomSeed = random.randint(0, 1000000000)
            np.random.seed(randomSeed)

            genModel.seed = randomSeed
            genModel.o = args.o
            genModel.compute_unit = args.compute_unit
            genModel.model_version = args.model_version
            genModel.scheduler = args.scheduler
            genModel.num_inference_steps = args.num_inference_steps

            image = coreml_pipe(
                prompt=genModel.prompt,
                height=coreml_pipe.height,
                width=coreml_pipe.width,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale
            )

            out_path = get_image_path(genModel)
            logger.info(f"Saving generated image to {out_path}")
            image["images"][0].save(out_path)
    # END


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        required=False,
        help="The text prompt to be used for text-to-image generation.")
    parser.add_argument(
        "-i",
        required=True,
        help=("Path to input directory with the .mlpackage files generated by "
              "python_coreml_stable_diffusion.torch2coreml"))
    parser.add_argument("-o", required=True)
    parser.add_argument("--seed",
                        "-s",
                        default=93,
                        type=int,
                        help="Random seed to be able to reproduce results")
    parser.add_argument(
        "--model-version",
        default="CompVis/stable-diffusion-v1-4",
        help=("The pre-trained model checkpoint and configuration to restore. "
              "For available versions: https://huggingface.co/models?search=stable-diffusion"))
    parser.add_argument(
        "--compute-unit",
        choices=get_available_compute_units(),
        default="ALL",
        help=("The compute units to be used when executing Core ML models. "
              f"Options: {get_available_compute_units()}"))
    parser.add_argument(
        "--scheduler",
        choices=tuple(SCHEDULER_MAP.keys()),
        default=None,
        help=("The scheduler to use for running the reverse diffusion process. "
              "If not specified, the default scheduler from the diffusers pipeline is utilized"))
    parser.add_argument(
        "--num-inference-steps",
        default=50,
        type=int,
        help="The number of iterations the unet model will be executed throughout the reverse diffusion process")
    parser.add_argument(
        "--guidance-scale",
        default=7.5,
        type=float,
        help="Controls the influence of the text prompt on sampling process (0=random images)")

    args = parser.parse_args()
    main(args)