# Skip to content
#
# Instantly share code, notes, and snippets.
#
# Show Gist options
#   • Star 0 You must be signed in to star a gist
#   • Fork 0 You must be signed in to fork a gist
#   • Save RohitDhankar/80b893ebcfcb68e7a58d78a38a1b9482 to your computer and use it in GitHub Desktop.
# Save RohitDhankar/80b893ebcfcb68e7a58d78a38a1b9482 to your computer and use it in GitHub Desktop.
"""
Examples:
(1) python benchmark_distilled_sd.py --pipeline_id CompVis/stable-diffusion-v1-4
(2) python benchmark_distilled_sd.py --pipeline_id CompVis/stable-diffusion-v1-4 --vae_path sayakpaul/taesd-diffusers
(3) python benchmark_distilled_sd.py --pipeline_id nota-ai/bk-sdm-small
(4) python benchmark_distilled_sd.py --pipeline_id nota-ai/bk-sdm-small --vae_path sayakpaul/taesd-diffusers
"""
import argparse
import time
import torch
from diffusers import AutoencoderTiny, DiffusionPipeline
# Benchmark configuration shared by every run.

# How many times run_inference() invokes the pipeline inside the timed loop.
NUM_ITERS_TO_RUN = 3
# Passed to the pipeline call as `num_inference_steps`.
NUM_INFERENCE_STEPS = 25
# Passed to the pipeline call as `num_images_per_prompt`.
NUM_IMAGES_PER_PROMPT = 4
# Text prompt used for every pipeline call.
PROMPT = "a golden vase with different flowers"
# Seed handed to torch.manual_seed() before each pipeline call.
SEED = 0
def load_pipeline(pipeline_id, vae_path=None):
    """Load a half-precision diffusion pipeline onto the CUDA device.

    Args:
        pipeline_id: Identifier passed to ``DiffusionPipeline.from_pretrained``.
        vae_path: Optional identifier of an ``AutoencoderTiny`` checkpoint.
            When given, it replaces the pipeline's ``vae`` attribute.

    Returns:
        The pipeline, moved to ``"cuda"``, with its VAE swapped if requested.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        pipeline_id, torch_dtype=torch.float16
    ).to("cuda")

    # Nothing to swap: keep the VAE that shipped with the pipeline.
    if vae_path is None:
        return pipeline

    tiny_vae = AutoencoderTiny.from_pretrained(vae_path, torch_dtype=torch.float16)
    pipeline.vae = tiny_vae.to("cuda")
    return pipeline
def run_inference(args):
    """Benchmark the selected pipeline and report total time and peak memory.

    The pipeline is called ``NUM_ITERS_TO_RUN`` times with the module-level
    prompt and settings. The reported time is the total across all
    iterations, not a per-iteration average.

    Args:
        args: Object with ``pipeline_id`` and ``vae_path`` attributes (the
            namespace returned by ``parse_args``).

    Returns:
        A dict with ``"pipeline_id"``, ``"total_time (ms)"`` (a string
        formatted to one decimal place), ``"memory (mb)"`` (an int, peak
        allocated CUDA memory in units of 10**6 bytes) and, only when a VAE
        path was given, ``"vae_path"``.
    """
    # Reset before loading so the peak includes the model weights.
    torch.cuda.reset_peak_memory_stats()
    pipe = load_pipeline(args.pipeline_id, args.vae_path)

    # CUDA work is queued asynchronously: drain anything pending before the
    # clock starts so it is not billed to the timed region.
    torch.cuda.synchronize()
    # perf_counter_ns is monotonic; the wall clock (time_ns) can jump when
    # the system time is adjusted, which would corrupt the measurement.
    start = time.perf_counter_ns()
    for _ in range(NUM_ITERS_TO_RUN):
        # Re-seed every iteration so each call starts from the same noise.
        pipe(
            PROMPT,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=torch.manual_seed(SEED),
            num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
        )
    # Make sure every queued kernel has finished before the clock stops.
    torch.cuda.synchronize()
    end = time.perf_counter_ns()

    mem_bytes = torch.cuda.max_memory_allocated()
    mem_mb = int(mem_bytes / (10**6))
    # Nanoseconds -> milliseconds.
    total_time = f"{(end - start) / 1e6:.1f}"

    results = {
        "pipeline_id": args.pipeline_id,
        "total_time (ms)": total_time,
        "memory (mb)": mem_mb,
    }
    if args.vae_path is not None:
        results["vae_path"] = args.vae_path
    return results
def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, as the original did.

    Returns:
        An ``argparse.Namespace`` with ``pipeline_id`` (str) and ``vae_path``
        (str or ``None``).
    """
    parser = argparse.ArgumentParser()
    # The flag used to be both `required=True` and given a default, which
    # made the default unreachable. It is now optional so the default applies.
    parser.add_argument(
        "--pipeline_id",
        type=str,
        default="CompVis/stable-diffusion-v1-4",
    )
    parser.add_argument("--vae_path", type=str, default=None)
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, run the benchmark once,
    # and print the resulting dict to stdout.
    cli_args = parse_args()
    print(run_inference(cli_args))
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment