Skip to content

Instantly share code, notes, and snippets.

@mattdesl
Created November 2, 2023 20:29
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattdesl/3909a47507c5c219eed928bee4f297ea to your computer and use it in GitHub Desktop.
Save mattdesl/3909a47507c5c219eed928bee4f297ea to your computer and use it in GitHub Desktop.
Fork of the latent consistency model with a couple of small perf tweaks
import os
import torch
import time
from diffusers import DiffusionPipeline, AutoencoderTiny
from collections import namedtuple
# Record yielded by Predictor._do_predict for a sampling step.
#   latents       -- the `denoised` output of the scheduler step
#   step          -- zero-based index of the step within the sampling loop
#   steps         -- the `num_inference_steps` value used for the run
#   prompt_embeds -- the prompt embeddings the UNet was conditioned on
PredictionResult = namedtuple(
    "PredictionResult",
    ("latents", "step", "steps", "prompt_embeds"),
)
class Predictor:
    """Text-to-image generator built on the LCM Dreamshaper v7 pipeline.

    Wraps a diffusers ``DiffusionPipeline`` loaded with the
    ``latent_consistency_txt2img`` custom pipeline and exposes:

    * ``generate`` -- a generator yielding a ``PredictionResult`` per step;
    * ``run`` -- returns only the final ``PredictionResult``;
    * ``run_generate`` -- decodes and saves an image for every yielded step;
    * ``latent_to_image`` / ``save_image`` -- decode latents and write PNGs.
    """

    def __init__(self, with_fast = True):
        """Load the pipeline and, optionally, the tiny "fast" VAE.

        Args:
            with_fast: when true, also load ``madebyollin/taesd`` and attach
                it to the pipeline as ``fast_vae`` (used by
                ``latent_to_image(..., fast=True)``).
        """
        self.pipe = self._load_model()
        self.device = self.pipe._execution_device
        self.pipe.set_progress_bar_config(disable=True)
        if with_fast:
            self.pipe.fast_vae = AutoencoderTiny.from_pretrained(
                "madebyollin/taesd", torch_dtype=torch.float32, use_safetensors=True
            ).to(self.device)
        # NOTE(review): the original indentation was lost, so it is unclear
        # whether this line belonged inside the `if with_fast` branch; it is
        # kept unconditional so the attribute always exists.
        # `generate` calls the uncompiled `_do_predict`; the compiled wrapper
        # is only exposed here as `do_predict`.
        self.do_predict = torch.compile(self._do_predict)

    def _load_model(self):
        """Download/load the LCM pipeline and move it to the ``mps:0`` device."""
        model = DiffusionPipeline.from_pretrained(
            "SimianLuo/LCM_Dreamshaper_v7",
            custom_pipeline="latent_consistency_txt2img",
            custom_revision="main",
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False
        )
        model.to(torch_device="cpu", torch_dtype=torch.float32).to('mps:0')
        return model

    @staticmethod
    def _resolve_seed(seed):
        """Return *seed* unchanged, or a random 16-bit seed when it is None.

        An explicit ``is None`` test is used (rather than ``seed or ...``) so
        that a seed of ``0`` is honoured instead of being re-randomised.
        """
        if seed is None:
            return int.from_bytes(os.urandom(2), "big")
        return seed

    def run_generate (self, prompt, seed=None, **kwargs):
        """Generate images for *prompt*, saving one PNG per yielded step.

        Args:
            prompt: text prompt forwarded to ``generate``.
            seed: RNG seed; a random one is drawn when None.
            **kwargs: forwarded unchanged to ``generate``.
        """
        seed = self._resolve_seed(seed)
        for data in self.generate(prompt=prompt, seed=seed, **kwargs):
            image = self.latent_to_image(data.latents)[0]
            # save_image(result, seed, steps, i): total step count, then the
            # index of this step, so each intermediate gets its own filename.
            output_path = self.save_image(image, seed, data.steps, data.step)
            print(f"{data.step+1} of {data.steps} image saved to: {output_path}")

    def run (self, seed=None, **kwargs):
        """Run sampling to completion and return the final ``PredictionResult``.

        Intermediate results are suppressed (``intermediate_steps=False``).
        Returns None if the generator finishes without yielding a last step.
        """
        seed = self._resolve_seed(seed)
        for data in self.generate(seed=seed, **kwargs, intermediate_steps=False):
            if data.step == data.steps - 1:
                return data
        return None

    def generate (self,
                  prompt = None,
                  seed = None,
                  steps = 4,
                  prompt_embeds = None,
                  **kwargs):
        """Yield a ``PredictionResult`` for each sampling step.

        Args:
            prompt: text prompt; encoded here unless *prompt_embeds* is given.
            seed: value passed to ``torch.manual_seed``; random when None.
            steps: number of inference steps.
            prompt_embeds: pre-computed prompt embeddings (skips encoding).
            **kwargs: extra keyword arguments for ``_do_predict``.
                ``lcm_origin_steps`` defaults to 50 but may be overridden.

        Raises:
            ValueError: if neither *prompt* nor *prompt_embeds* is supplied.
        """
        if prompt is None and prompt_embeds is None:
            raise ValueError("either `prompt` or `prompt_embeds` must be provided")
        seed = self._resolve_seed(seed)
        print(f"Using seed: {seed}")
        torch.manual_seed(seed)
        if prompt_embeds is None:
            prompt_embeds = self.encode_prompt(prompt)
        # A default rather than a hard-coded keyword, so callers passing their
        # own lcm_origin_steps no longer hit a duplicate-keyword TypeError.
        kwargs.setdefault("lcm_origin_steps", 50)
        yield from self._do_predict(
            prompt_embeds=prompt_embeds,
            **kwargs,
            num_inference_steps=steps
        )

    @torch.no_grad()
    def encode_prompt (self, prompt):
        """Encode *prompt* with the pipeline's prompt encoder (one image per prompt)."""
        device = self.device
        return self.pipe._encode_prompt(
            prompt, device, 1, prompt_embeds=None
        )

    @torch.no_grad()
    def _do_predict (self,
                     prompt = None,
                     height = 512,
                     width = 512,
                     guidance_scale = 7.5,
                     num_images_per_prompt = 1,
                     latents = None,
                     num_inference_steps = 4,
                     lcm_origin_steps = 50,
                     prompt_embeds = None,
                     cross_attention_kwargs = None,
                     intermediate_steps = True
                     ):
        """Run the LCM multi-step sampling loop as a generator.

        Yields a ``PredictionResult`` carrying the scheduler's ``denoised``
        output after every step when *intermediate_steps* is true, otherwise
        only after the last step.

        Args:
            prompt: optional str or list of str; only used to derive the batch
                size and passed through to the pipeline's prompt encoder.
            height, width: output size; a falsy value falls back to the UNet
                sample size times the VAE scale factor.
            guidance_scale: value embedded and fed to the UNet as
                ``timestep_cond``.
            num_images_per_prompt: multiplier applied to the batch size.
            latents: optional starting latents handed to ``prepare_latents``.
            num_inference_steps, lcm_origin_steps: forwarded to the
                scheduler's ``set_timesteps``.
            prompt_embeds: pre-computed prompt embeddings.
            cross_attention_kwargs: forwarded to the UNet call.
            intermediate_steps: whether to yield before the final step.
        """
        pipe = self.pipe
        device = self.device

        # 0. Default height and width to unet
        height = height or pipe.unet.config.sample_size * pipe.vae_scale_factor
        width = width or pipe.unet.config.sample_size * pipe.vae_scale_factor

        # 1. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 2. Encode input prompt
        prompt_embeds = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            prompt_embeds=prompt_embeds,
        )

        # 3. Prepare timesteps
        pipe.scheduler.set_timesteps(num_inference_steps, lcm_origin_steps)
        timesteps = pipe.scheduler.timesteps

        # 4. Prepare latent variable
        bs = batch_size * num_images_per_prompt
        num_channels_latents = pipe.unet.config.in_channels
        latents = pipe.prepare_latents(
            bs,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            latents,
        )

        # 5. Get Guidance Scale Embedding
        w = torch.tensor(guidance_scale).repeat(bs)
        w_embedding = pipe.get_w_embedding(w, embedding_dim=256).to(device)

        # 6. LCM MultiStep Sampling Loop:
        last_step = num_inference_steps - 1
        for i, t in enumerate(timesteps):
            ts = torch.full((bs,), t, device=device, dtype=torch.long)

            # model prediction (v-prediction, eps, x)
            model_pred = pipe.unet(
                latents,
                ts,
                timestep_cond=w_embedding,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False)[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents, denoised = pipe.scheduler.step(model_pred, i, t, latents, return_dict=False)

            if intermediate_steps or i == last_step:
                yield PredictionResult(denoised, i, num_inference_steps, prompt_embeds)

    @torch.no_grad()
    def latent_to_image (self, latent, fast=False):
        """Decode *latent* into a list of PIL images.

        Args:
            latent: latents to decode; divided by the VAE's ``scaling_factor``
                before decoding.
            fast: use ``pipe.fast_vae`` (only attached when the predictor was
                created with ``with_fast=True``) instead of the pipeline VAE.
        """
        pipe = self.pipe
        vae = pipe.fast_vae if fast else pipe.vae
        image = vae.decode(latent / vae.config.scaling_factor, return_dict=False)[0]
        # Every image in the batch is denormalized, then re-stacked.
        image = torch.stack(
            [pipe.image_processor.denormalize(single) for single in image]
        )
        image = pipe.image_processor.pt_to_numpy(image.detach())
        return pipe.image_processor.numpy_to_pil(image)

    def save_image(self, result, seed, steps, i=0):
        """Save *result* as a timestamped PNG under ``output/``.

        Args:
            result: object with a ``save(path)`` method (e.g. a PIL image).
            seed, steps, i: values embedded in the file name.

        Returns:
            The path the image was written to.
        """
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        output_dir = "output"
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f"{timestamp}-seed-{seed}-steps-{steps}-i-{i}.png")
        result.save(output_path)
        print(f"Output image saved to: {output_path}")
        return output_path
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment