@sayakpaul
Last active June 17, 2024 07:47
The script shows how to run SD3 with `torch.compile()`
import torch
torch.set_float32_matmul_precision("high")
from diffusers import StableDiffusion3Pipeline
import time
id = "stabilityai/stable-diffusion-3-medium-diffusers"
pipeline = StableDiffusion3Pipeline.from_pretrained(
    id,
    torch_dtype=torch.float16
).to("cuda")
pipeline.set_progress_bar_config(disable=True)
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)
prompt = "a photo of a cat"
# Warm-up runs so that compilation and autotuning are excluded from the timing.
for _ in range(3):
    _ = pipeline(
        prompt=prompt,
        num_inference_steps=50,
        guidance_scale=5.0,
        generator=torch.manual_seed(1),
    )
start = time.time()
for _ in range(10):
    _ = pipeline(
        prompt=prompt,
        num_inference_steps=50,
        guidance_scale=5.0,
        generator=torch.manual_seed(1),
    )
end = time.time()
avg_inference_time = (end - start) / 10
print(f"Average inference time: {avg_inference_time:.3f} seconds.")
image = pipeline(
    prompt=prompt,
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=torch.manual_seed(1),
).images[0]
filename = "_".join(prompt.split(" "))
image.save(f"diffusers_{filename}.png")
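The warm-up-then-average pattern in the script can be factored into a small helper. This is only a sketch: `benchmark` and its parameters are hypothetical names, not part of diffusers or PyTorch, and it assumes the callable blocks until its result is available on the host (which appears to be the case when the pipeline returns PIL images).

```python
import time


def benchmark(fn, warmup=3, runs=10):
    """Call fn `warmup` times untimed, then return the mean time of `runs` calls."""
    # Untimed calls: for a compiled model these absorb compilation and autotuning.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs


# Hypothetical usage with the pipeline from the script above:
# avg = benchmark(lambda: pipeline(prompt=prompt, num_inference_steps=50,
#                                  guidance_scale=5.0,
#                                  generator=torch.manual_seed(1)))
# print(f"Average inference time: {avg:.3f} seconds.")
```

`time.perf_counter()` is used here instead of `time.time()` because it is a monotonic, high-resolution clock intended for measuring intervals.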
@gkalstn000

Yes, I'm creating a text-to-image API using FastAPI with SD3, so there are HTTP-related logs.

I created an Ubuntu x86 L4 instance on GCP and installed the NVIDIA driver.

I also installed GCC-related libraries.

sudo apt-get update
sudo apt-get install build-essential
export CC=gcc
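
Since `torch.compile` with the inductor backend needs a working C compiler, it can help to confirm that one is visible from the Python process itself. The following is a minimal sketch using only the standard library; `find_c_compiler` is a hypothetical helper name, and it assumes `CC` holds a bare executable name such as `gcc` rather than a command with flags.

```python
import os
import shutil


def find_c_compiler(env=None):
    """Return the full path of the compiler named by CC (default: gcc), or None."""
    env = os.environ if env is None else env
    name = env.get("CC", "gcc")
    # shutil.which returns None when the executable is not found on PATH.
    return shutil.which(name)


if __name__ == "__main__":
    path = find_c_compiler()
    print(path or "No C compiler found on PATH")
```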

Then, I installed PyTorch using the official Conda installation code.

conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia

I installed diffusers[torch], xformers, transformers, etc., without pinning specific versions.

Then, I initialized the SD3 pipeline with the following code:

import torch
from diffusers import StableDiffusion3Pipeline

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16
).to("cuda")

pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)

# pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

When generating the image as shown below, the above error occurred:

image = pipeline(prompt=request_data.prompt,
                 negative_prompt='worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch',
                 height=1024,
                 width=1024,
                 num_inference_steps=20,
                 guidance_scale=7,
                 ).images[0]

Could you please check if I installed anything incorrectly?

@gkalstn000

I solved the issue by installing peft:
pip install peft

I'm not sure what the main problem was exactly, but the error was caused here:

diffusers/models/transformers/transformer_sd3.py", line 285, in forward
logger.warning(
    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

That logger.warning call caused the error when torch.compile was applied to the transformer.
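
One way to catch this before compiling is to check whether `peft` is importable. This is a sketch under the assumption that diffusers takes its PEFT code path (and so skips the warning) when a compatible `peft` is installed; `module_available` is a hypothetical helper, and the check does not validate version requirements.

```python
import importlib.util


def module_available(name):
    """Return True if a top-level module called `name` can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    if not module_available("peft"):
        print("peft is not installed; try `pip install peft` before compiling the transformer.")
```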

Compilation reduces generation time at 1024x1024 resolution by 19.8%:
baseline : 12.2532 sec
compile : 9.82578 sec
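
The percentage can be reproduced from the two timings above. The sketch below computes both the relative reduction in time and the equivalent speedup factor; the function names are illustrative only.

```python
def time_reduction(baseline, compiled):
    """Fraction of the baseline time saved by the compiled run."""
    return (baseline - compiled) / baseline


def speedup_factor(baseline, compiled):
    """How many times faster the compiled run is than the baseline."""
    return baseline / compiled


baseline_sec = 12.2532   # reported baseline time
compiled_sec = 9.82578   # reported time with torch.compile

print(f"Time reduction: {time_reduction(baseline_sec, compiled_sec):.1%}")
print(f"Speedup: {speedup_factor(baseline_sec, compiled_sec):.2f}x")
```

A 19.8% reduction in time corresponds to roughly a 1.25x speedup, so the two figures should not be used interchangeably.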
