@sayakpaul · Created April 24, 2023 08:19
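# Benchmark: Stable Diffusion v1.5 inference latency for the vanilla pipeline,
# the pipeline patched with ToMe (token merging via `tomesd`), and ToMe combined
# with xformers memory-efficient attention.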
import tomesd
import torch
import torch.utils.benchmark as benchmark
from diffusers import StableDiffusionPipeline
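# Assumed environment (not pinned in the gist): a CUDA GPU with `diffusers`,
# `transformers`, `tomesd`, and `xformers` installed, e.g.
# `pip install diffusers transformers tomesd xformers`.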

def benchmark_torch_function(f, *args, **kwargs):
    # Time `f` with torch.utils.benchmark and return the mean latency in seconds.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return round(t0.blocked_autorange(min_run_time=1).mean, 2)

# Benchmark configuration.
model_id = "runwayml/stable-diffusion-v1-5"
prompt = "a photo of an astronaut riding a horse on mars"
steps = 25
num_images_per_prompt = 1
dtype = torch.float16
resolution = 1024

pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=dtype, safety_checker=None
).to("cuda")
pipe.set_progress_bar_config(disable=True)

# Vanilla
print("Running benchmark with vanilla pipeline...")
f = lambda: pipe(
    prompt,
    height=resolution,
    width=resolution,
    num_inference_steps=steps,
    num_images_per_prompt=num_images_per_prompt,
).images
time_vanilla = benchmark_torch_function(f)

# With ToMe
print("Running benchmark with ToMe patched pipeline...")
tomesd.apply_patch(pipe, ratio=0.5)
f = lambda: pipe(
    prompt,
    height=resolution,
    width=resolution,
    num_inference_steps=steps,
    num_images_per_prompt=num_images_per_prompt,
).images
time_tome = benchmark_torch_function(f)

# With ToMe + xformers
print("Running benchmark with ToMe patched + xformers enabled pipeline...")
tomesd.remove_patch(pipe)
pipe.enable_xformers_memory_efficient_attention()
tomesd.apply_patch(pipe, ratio=0.5)
f = lambda: pipe(
    prompt,
    height=resolution,
    width=resolution,
    num_inference_steps=steps,
    num_images_per_prompt=num_images_per_prompt,
).images
time_tome_xformers = benchmark_torch_function(f)

print(
    f"Model: {model_id}, dtype: {dtype}, steps: {steps}, num_images_per_prompt: {num_images_per_prompt}, resolution: {resolution} x {resolution}"
)
print(f"Vanilla : {time_vanilla} s")
print(f"ToMe : {time_tome} s")
print(f"ToMe + xformers: {time_tome_xformers} s")