Running `inference.py` produces the generated image.
You're welcome to try out other quantization techniques from `torchao` and benefit from `torch.compile()`. diffusers-torchao provides a handy reference.
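
As a starting point, here is a minimal sketch of applying a `torchao` quantization recipe together with `torch.compile()` to the Flux transformer. This is not the code from `aot_compile_with_int8_quant.py`; the model id, prompt, and the `int8_weight_only` recipe are illustrative assumptions.

```python
import torch
from diffusers import FluxPipeline
from torchao.quantization import int8_weight_only, quantize_

# Load the pipeline (model id and dtype are assumptions for this sketch).
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Quantize the transformer weights in-place with a torchao recipe,
# then compile it for faster inference.
quantize_(pipe.transformer, int8_weight_only())
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)

image = pipe(
    "a photo of a dog", num_inference_steps=28, guidance_scale=3.5
).images[0]
image.save("flux_int8.png")
```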
Library versions:

- `diffusers`: installed from `main`.
- `torchao`: installed from `main`.
- `torch`: 2.6.0.dev20241027+cu121

Tested on H100.
`serialized_inputs.pt` in `aot_compile_with_int8_quant.py` was obtained by serializing the inputs to `self.transformer` (from here). You can download it from here.
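
A hedged sketch of how such inputs could be captured: a forward pre-hook on `pipe.transformer` records the call arguments during a single denoising step, which are then written out with `torch.save()`. The hook and the one-step prompt below are assumptions, not the code that actually produced `serialized_inputs.pt`.

```python
import torch

# Assumes `pipe` is an already-loaded FluxPipeline.
captured = {}

def _capture_inputs(module, args, kwargs):
    # Record only the first call to the transformer.
    if not captured:
        captured["args"] = args
        captured["kwargs"] = kwargs

handle = pipe.transformer.register_forward_pre_hook(_capture_inputs, with_kwargs=True)
_ = pipe("a photo of a dog", num_inference_steps=1)  # one step is enough to trigger the hook
handle.remove()

torch.save(captured, "serialized_inputs.pt")
```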
Additionally, to perform inference with this AoT-compiled binary through `DiffusionPipeline` as shown in `inference.py`, the following changes are needed in the `pipeline_flux.py` file:
```diff
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 040d935f1..f24cd28c5 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -680,7 +680,7 @@ class FluxPipeline(
         )
 
         # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
+        num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.transformer, torch.nn.Module) else 16
         latents, latent_image_ids = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -714,7 +714,7 @@ class FluxPipeline(
         self._num_timesteps = len(timesteps)
 
         # handle guidance
-        if self.transformer.config.guidance_embeds:
+        if (isinstance(self.transformer, torch.nn.Module) and self.transformer.config.guidance_embeds) or isinstance(self.transformer, Callable):
             guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
             guidance = guidance.expand(latents.shape[0])
         else:
```
The compiled binary file ("bs_1_1024.pt2") used in `inference.py` can be found here.
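
For completeness, here is a rough sketch of how the AoT-compiled binary might be wired into the pipeline. It assumes the `torch._inductor.aoti_load_package()` API available in recent PyTorch builds; the model id, prompt, and the choice to skip loading the original transformer are assumptions and may differ from `inference.py`.

```python
import torch
from diffusers import FluxPipeline

# Skip loading the original bf16 transformer, since the compiled binary replaces it.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=None, torch_dtype=torch.bfloat16
).to("cuda")

# Load the AoT-compiled artifact; it behaves like a plain callable, which is
# why the isinstance() checks in the diff above are needed.
pipe.transformer = torch._inductor.aoti_load_package("bs_1_1024.pt2")

image = pipe(
    "a photo of a dog", num_inference_steps=28, guidance_scale=3.5
).images[0]
image.save("flux_aot_int8.png")
```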
Thanks to the PyTorch folks (especially @jerryzh168) for providing guidance in this thread.