@sayakpaul
Created July 2, 2024 08:59
Benchmarks the "ptx0/pixart-900m-1024-ft" pipeline with `torch.compile()`.
import time

import torch
from diffusers import DiffusionPipeline

# Allow TF32 matmuls for faster float32 operations on supported GPUs.
torch.set_float32_matmul_precision("high")

pipeline_id = "ptx0/pixart-900m-1024-ft"
pipeline = DiffusionPipeline.from_pretrained(
    pipeline_id, torch_dtype=torch.float16
).to("cuda")
pipeline.set_progress_bar_config(disable=True)

# Inductor tuning knobs commonly used for diffusion workloads.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

# Switch the heavy components to channels-last memory format and compile them.
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "a photo of a cat"

# Warmup runs: the first calls trigger compilation and autotuning.
for _ in range(3):
    _ = pipeline(
        prompt=prompt,
        num_inference_steps=30,
        guidance_scale=7.5,
        generator=torch.manual_seed(1),
    )

# Timed runs.
start = time.time()
for _ in range(10):
    _ = pipeline(
        prompt=prompt,
        num_inference_steps=30,
        guidance_scale=7.5,
        generator=torch.manual_seed(1),
    )
end = time.time()

avg_inference_time = (end - start) / 10
print(f"Average inference time: {avg_inference_time:.3f} seconds.")