Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Last active August 24, 2023 10:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sayakpaul/af6428ece8e2a95493d228bd9b324557 to your computer and use it in GitHub Desktop.
Save sayakpaul/af6428ece8e2a95493d228bd9b324557 to your computer and use it in GitHub Desktop.
import argparse
import time
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
from diffusers.utils import load_image
# --- Model sources ---------------------------------------------------------
# Base SDXL repository id; also used as the repository the LoRA file is read from.
PIPELINE_ID: str = "stabilityai/stable-diffusion-xl-base-1.0"
# VAE checkpoint loaded in float16 and passed to the pipeline in place of its own.
VAE_PATH: str = "madebyollin/sdxl-vae-fp16-fix"
# LoRA weight file name inside the PIPELINE_ID repository.
LORA_ID: str = "sd_xl_offset_example-lora_1.0.safetensors"

# --- Benchmark workload ----------------------------------------------------
NUM_ITERS_TO_RUN: int = 3  # timed pipeline calls per benchmark run
NUM_INFERENCE_STEPS: int = 25  # denoising steps per call
NUM_IMAGES_PER_PROMPT: int = 4  # images generated per call
PROMPT: str = "beautiful scenery nature glass bottle landscape, , purple galaxy bottle"
def load_pipeline(fuse=False):
    """Build the float16 SDXL pipeline on CUDA with the LoRA from LORA_ID attached.

    Args:
        fuse: When true, ``fuse_lora()`` is called on the pipeline's UNet after
            the LoRA weights are loaded, so fused and unfused inference can be
            compared.

    Returns:
        The ``StableDiffusionXLPipeline`` instance, already moved to ``"cuda"``.
    """
    # Load the VAE from VAE_PATH in float16 and hand it to the pipeline below,
    # instead of the pipeline's own VAE.
    fp16_vae = AutoencoderKL.from_pretrained(VAE_PATH, torch_dtype=torch.float16)
    fp16_vae = fp16_vae.to("cuda")

    pipeline_kwargs = dict(
        vae=fp16_vae,
        variant="fp16",
        use_safetensors=True,
        torch_dtype=torch.float16,
    )
    sdxl = StableDiffusionXLPipeline.from_pretrained(PIPELINE_ID, **pipeline_kwargs)
    sdxl = sdxl.to("cuda")

    # The LoRA file is read from the same repository id as the base pipeline.
    sdxl.load_lora_weights(PIPELINE_ID, weight_name=LORA_ID)
    if not fuse:
        return sdxl

    sdxl.unet.fuse_lora()
    return sdxl
def run_inference(args):
    """Benchmark repeated SDXL generation and report elapsed time and peak GPU memory.

    Loads the pipeline (LoRA fused or not, per ``args.fuse``), then calls it
    ``NUM_ITERS_TO_RUN`` times on ``PROMPT``. Each call uses
    ``NUM_INFERENCE_STEPS`` denoising steps and produces ``NUM_IMAGES_PER_PROMPT``
    images.

    NOTE: there is no warm-up iteration, so first-call overheads are included
    in the timed loop.

    Args:
        args: Parsed CLI namespace; only ``args.fuse`` is read.

    Returns:
        dict with keys:
          - ``"fuse"``: the flag value.
          - ``"total_time (ms)"``: total loop time as a string with one decimal.
          - ``"memory (mb)"``: peak ``torch.cuda.max_memory_allocated()`` in
            units of 10**6 bytes, truncated to int.
    """
    # The peak is reset before loading, so the reported figure includes the model weights.
    torch.cuda.reset_peak_memory_stats()
    pipe = load_pipeline(args.fuse)

    # Drain any GPU work queued during loading so it is not charged to the timed loop.
    torch.cuda.synchronize()
    # perf_counter_ns is monotonic; time.time_ns follows the system clock and can jump.
    start = time.perf_counter_ns()
    for _ in range(NUM_ITERS_TO_RUN):
        pipe(
            PROMPT,
            num_inference_steps=NUM_INFERENCE_STEPS,
            num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
        )
    # Ensure every kernel launched by the loop has finished before stopping the clock.
    torch.cuda.synchronize()
    end = time.perf_counter_ns()

    mem_bytes = torch.cuda.max_memory_allocated()
    mem_MB = int(mem_bytes / (10**6))
    total_time = f"{(end - start) / 1e6:.1f}"
    return {
        "fuse": args.fuse,
        "total_time (ms)": total_time,
        "memory (mb)": mem_MB,
    }
def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the original behavior.

    Returns:
        argparse.Namespace with a boolean ``fuse`` attribute, which is False
        unless ``--fuse`` is given.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark SDXL inference with a LoRA, optionally fused into the UNet."
    )
    # store_true already defaults to False, so no explicit default is needed.
    parser.add_argument(
        "--fuse",
        action="store_true",
        help="Fuse the LoRA weights into the UNet before running inference.",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read CLI flags, run the benchmark, print the summary dict.
    print(run_inference(parse_args()))
@sayakpaul
Copy link
Author

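Results on a V100 GPU:

{'fuse': False, 'total_time (ms)': '95874.1', 'memory (mb)': 13572}

{'fuse': True, 'total_time (ms)': '83744.8', 'memory (mb)': 13543}

Fusing the LoRA reduced total time by about 12.7% (≈95.9 s → ≈83.7 s), while peak memory stayed essentially unchanged (13572 MB vs. 13543 MB).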

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment