Sayak's LCM benchmark, modified. The script times SDXL image generation with the LCM-LoRA scheduler and weights (4 steps, guidance scale 1) against standard SDXL sampling (25 steps, guidance scale 5).
import argparse

import torch
import torch.utils.benchmark as benchmark
from diffusers import DiffusionPipeline, LCMScheduler

PROMPT = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
LORA_ID = "lcm-sd/lcm-sdxl-lora-huber-final"


def benchmark_fn(f, *args, **kwargs):
    # Time f(*args, **kwargs) with torch.utils.benchmark; return the mean in microseconds.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


def load_pipeline(standard_sdxl=False):
    pipe = DiffusionPipeline.from_pretrained(MODEL_ID, variant="fp16")
    if not standard_sdxl:
        # Swap in the LCM scheduler and load the LCM-LoRA weights.
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        pipe.load_lora_weights(LORA_ID, weight_name="lcm_sdxl_lora.bin")
    pipe.to(device="cuda", dtype=torch.float16)
    return pipe


def call_pipeline(pipe, batch_size, num_inference_steps, guidance_scale):
    images = pipe(
        prompt=PROMPT,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=batch_size,
        guidance_scale=guidance_scale,
    ).images[0]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--standard_sdxl", action="store_true")
    args = parser.parse_args()

    pipeline = load_pipeline(args.standard_sdxl)

    if args.standard_sdxl:
        num_inference_steps = 25
        guidance_scale = 5
    else:
        # LCM needs very few steps, and a guidance scale of 1 disables classifier-free guidance.
        num_inference_steps = 4
        guidance_scale = 1

    # Warmup run before timing, so one-time setup costs are excluded.
    call_pipeline(pipeline, args.batch_size, num_inference_steps, guidance_scale)

    time = benchmark_fn(call_pipeline, pipeline, args.batch_size, num_inference_steps, guidance_scale)
    print(f"Batch size: {args.batch_size} in {time/1e6:.3f} seconds")