@sayakpaul
Last active November 30, 2023 11:10
Benchmarking script for Stable Diffusion on TensorRT
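
The script below benchmarks Stable Diffusion v1.4; an SDXL variant, identical except for the checkpoint, follows it.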
# Benchmarks Stable Diffusion v1.4 in three configurations: eager PyTorch,
# torch.compile, and torch.compile with the TensorRT backend.
import argparse

import torch
import torch.utils.benchmark as benchmark
import torch_tensorrt  # noqa: F401 (import registers the "torch_tensorrt" compile backend)
from diffusers import DiffusionPipeline

CKPT = "CompVis/stable-diffusion-v1-4"
PROMPT = "a majestic castle in the clouds"


def load_pipeline(run_compile=False, with_tensorrt=False):
    pipe = DiffusionPipeline.from_pretrained(
        CKPT, torch_dtype=torch.float16, use_safetensors=True
    )
    pipe = pipe.to("cuda")

    if run_compile and not with_tensorrt:
        print("Run torch compile")
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    elif run_compile and with_tensorrt:
        print("Run torch compile with TensorRT backend")
        pipe.unet = torch.compile(
            pipe.unet,
            backend="torch_tensorrt",
            options={
                "truncate_long_and_double": True,
                "precision": torch.float16,
            },
            dynamic=False,
        )

    pipe.set_progress_bar_config(disable=True)
    return pipe


def run_inference(pipe, batch_size=1):
    # Generate `batch_size` images per call so the --batch_size flag takes effect.
    _ = pipe(PROMPT, num_images_per_prompt=batch_size)


def benchmark_fn(f, *args, **kwargs):
    # Mean wall-clock time per call, in seconds.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--run_compile", action="store_true")
    parser.add_argument("--with_tensorrt", action="store_true")
    args = parser.parse_args()

    pipeline = load_pipeline(
        run_compile=args.run_compile, with_tensorrt=args.with_tensorrt
    )
    time = benchmark_fn(run_inference, pipeline, args.batch_size)
    print(
        f"With compilation: {args.run_compile}, and TensorRT: {args.with_tensorrt} in {time:.3f} seconds"
    )
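
One timing caveat: torch.compile (and TensorRT engine building) happens lazily on the first call, so the first pipeline invocation pays a large one-off compilation cost. benchmark.Timer.blocked_autorange amortizes this over repeated runs, but an explicit warmup keeps it out of the measurement entirely. A minimal sketch of what that could look like; the warmup helper below is my addition, not part of the gist:

def warmup(pipe, rounds=2):
    # Hypothetical helper: trigger compilation / engine building before timing.
    # A reduced step count keeps warmup cheap; the UNet input shapes match the
    # timed runs, so the compiled artifacts are reused.
    for _ in range(rounds):
        _ = pipe(PROMPT, num_inference_steps=5)

Calling warmup(pipeline) in __main__ before benchmark_fn(...) would then measure only steady-state latency.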
The second script in the gist is the SDXL variant; the only functional difference is the checkpoint:

CKPT = "stabilityai/stable-diffusion-xl-base-1.0"
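
Assuming either script is saved as benchmark.py (the gist does not name the files), the three configurations reported below correspond to:

python benchmark.py
python benchmark.py --run_compile
python benchmark.py --run_compile --with_tensorrt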

Benchmarking was done inside the ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.1 container with the following dependencies:

  1. accelerate
  2. transformers==4.33.2
  3. diffusers==0.21.4
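
To confirm the environment matches before comparing numbers, a quick sanity check could look like this (a minimal sketch, not part of the original gist):

import diffusers
import torch
import torch_tensorrt
import transformers

# Print the installed version of each relevant package.
for mod in (torch, torch_tensorrt, transformers, diffusers):
    print(f"{mod.__name__}: {mod.__version__}")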

Timing for Stable Diffusion v1.4:

With compilation: False, and TensorRT: False in 3.767 seconds
With compilation: True, and TensorRT: False in 3.045 seconds
With compilation: True, and TensorRT: True in 1.157 seconds
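
(Relative to the eager baseline, torch.compile alone is about a 1.24x speedup here, and the TensorRT backend about 3.3x.)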

Timing for SDXL:

With compilation: False, and TensorRT: False in 6.713 seconds
With compilation: True, and TensorRT: False in 6.417 seconds
With compilation: True, and TensorRT: True in 5.537 seconds
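
(For SDXL the gains are smaller: roughly 1.05x for torch.compile and 1.21x for the TensorRT backend over eager.)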