Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Last active November 3, 2024 00:22
Show Gist options
  • Save sayakpaul/de0eeeb6d08ba30a37dcf0bc9dacc5c5 to your computer and use it in GitHub Desktop.
Save sayakpaul/de0eeeb6d08ba30a37dcf0bc9dacc5c5 to your computer and use it in GitHub Desktop.
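Shows how to ahead-of-time (AoT) compile the FLUX.1-dev transformer with int8 weight-only quantization and run inference with it. The gist has two scripts: aot_compile_with_int8_quant.py (quantize, export, compile, and benchmark the transformer) and inference.py (use the compiled package inside DiffusionPipeline).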
import torch
from diffusers import FluxTransformer2DModel
import torch.utils.benchmark as benchmark
from torchao.quantization import quantize_, int8_weight_only
from torchao.utils import unwrap_tensor_subclass
import torch._inductor
# Global Inductor setting applied before any compilation in this script.
# NOTE(review): presumably this makes Inductor pick Triton kernels for
# mixed-dtype matmuls (e.g. int8 weights x bf16 activations from
# int8_weight_only quantization). Confirm against torch._inductor.config for
# the installed torch version, since this option may not exist in every release.
torch._inductor.config.mixed_mm_choice = "triton"
def get_example_inputs(path="serialized_inputs.pt", device="cuda"):
    """Load the serialized transformer inputs and prepare them for export.

    Args:
        path: File written with ``torch.save`` that holds a dict of keyword
            arguments for the transformer's forward call. Defaults to
            ``"serialized_inputs.pt"`` in the working directory.
        device: Device that every tensor value is moved to (default ``"cuda"``).

    Returns:
        A dict of keyword arguments. Tensor values are on ``device``.
        Non-tensor values (e.g. ``None``) pass through unchanged instead of
        failing on ``.to``. ``joint_attention_kwargs`` is always set to
        ``None`` and ``return_dict`` to ``False``, overriding any stored values.
    """
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so loading a downloaded file cannot execute arbitrary code.
    raw_inputs = torch.load(path, weights_only=True)
    example_inputs = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in raw_inputs.items()
    }
    # Presumably return_dict=False makes the model return a tuple, which the
    # caller indexes with [0].
    example_inputs.update({"joint_attention_kwargs": None, "return_dict": False})
    return example_inputs
def benchmark_fn(f, *args, **kwargs):
    """Measure ``f(*args, **kwargs)`` and report its mean latency.

    Timing uses ``torch.utils.benchmark.Timer.blocked_autorange`` with the
    current torch thread count. The mean time per call, in seconds, is
    returned as a string rounded to three decimal places.
    """
    timer_globals = dict(f=f, args=args, kwargs=kwargs)
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals=timer_globals,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.blocked_autorange()
    return "{:.3f}".format(measurement.mean)
@torch.no_grad()
def load_model():
    """Load the FLUX.1-dev transformer in bfloat16 and place it on CUDA.

    Weights come from the ``transformer`` subfolder of the
    ``black-forest-labs/FLUX.1-dev`` repository via ``from_pretrained``.
    """
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    return transformer.to("cuda")
def aot_compile(name, model, **sample_kwargs):
    """Export ``model`` and AoT-compile it into a ``.pt2`` package.

    The model is traced with ``torch.export.export`` using keyword inputs
    only. AOTInductor then compiles and packages it with max-autotune and
    Triton CUDA graphs enabled.

    Args:
        name: Base name of the package, written to ``./{name}.pt2``.
        model: The module to export and compile.
        **sample_kwargs: Example keyword inputs for the model's forward call.
            No ``dynamic_shapes`` are passed to export, so the artifact is
            built for these example inputs.

    Returns:
        The package path returned by ``torch._inductor.aoti_compile_and_package``.
    """
    path = f"./{name}.pt2"
    options = {
        "max_autotune": True,
        "triton.cudagraphs": True,
    }
    # The package is only used for inference, so export and compile with
    # autograd disabled. This mirrors the AOTInductor tutorial and matches
    # the no_grad usage elsewhere in this script.
    with torch.no_grad():
        exported = torch.export.export(model, (), sample_kwargs)
        return torch._inductor.aoti_compile_and_package(
            exported,
            (),
            sample_kwargs,
            package_path=path,
            inductor_configs=options,
        )
def aot_load(path):
    """Load an AOTInductor ``.pt2`` package and return the loaded callable."""
    loaded = torch._inductor.aoti_load_package(path)
    return loaded
def f(model, **model_inputs):
    """Call ``model`` with keyword inputs while autograd is disabled.

    This is the callable that the main script hands to ``benchmark_fn``.
    """
    with torch.no_grad():
        return model(**model_inputs)
if __name__ == "__main__":
    # Run the whole quantize/export/compile/benchmark flow with autograd
    # disabled. The compiled package is only ever used for inference.
    with torch.no_grad():
        model = load_model()
        # Replace the weights with int8 weight-only quantized versions in place.
        quantize_(model, int8_weight_only())
        example_inputs = get_example_inputs()
        # NOTE(review): presumably converts torchao tensor-subclass weights into
        # a form torch.export can trace; applied in place on `model`.
        unwrap_tensor_subclass(model)
        package_path = aot_compile("bs_1_1024", model, **example_inputs)
        print(f"AoT compiled path {package_path}")

        compiled_func = aot_load(package_path)
        output = compiled_func(**example_inputs)[0]
        print(f"Output shape: {output.shape}")

        # Warm-up calls so that one-time costs are excluded from the timing.
        for _ in range(5):
            compiled_func(**example_inputs)

        # Avoid naming this `time`, which would shadow the stdlib module.
        mean_seconds = benchmark_fn(f, compiled_func, **example_inputs)
        print(f"Mean latency: {mean_seconds} seconds.")
import torch
from diffusers import DiffusionPipeline
if __name__ == "__main__":
    # Package written by aot_compile("bs_1_1024", ...) in aot_compile_with_int8_quant.py.
    package_path = "./bs_1_1024.pt2"

    # Load every pipeline component except the transformer. The transformer
    # is replaced below by the AoT-compiled package.
    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=None,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    # The loaded package is a plain callable rather than a torch.nn.Module, so
    # pipeline_flux.py must be patched to avoid reading `transformer.config`.
    pipeline.transformer = torch._inductor.aoti_load_package(package_path)

    # The package was exported from fixed example inputs (batch size 1 and
    # 1024x1024, judging by its name; confirm against serialized_inputs.pt).
    # Pin the shape-determining arguments explicitly instead of relying on
    # the pipeline's defaults.
    image = pipeline(
        "cute dog",
        height=1024,
        width=1024,
        num_images_per_prompt=1,
        guidance_scale=3.5,
        max_sequence_length=512,
        num_inference_steps=50,
    ).images[0]
    image.save("aot_compiled.png")

Running inference.py produces the following image:

[image not preserved in this copy]

Feel free to try other quantization techniques from torchao and combine them with torch.compile(); the diffusers-torchao repository is a handy reference.

Library versions:

  • diffusers: installed from the `main` branch.
  • torchao: installed from the `main` branch.
  • torch: 2.6.0.dev20241027+cu121

Tested on H100.

The serialized_inputs.pt file loaded by aot_compile_with_int8_quant.py was obtained by serializing the inputs passed to self.transformer (from here). You can download it from here.

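Additionally, to run inference with this AoT-compiled binary through DiffusionPipeline (as in inference.py), pipeline_flux.py needs the changes below. The loaded package is a plain callable rather than a torch.nn.Module, so it has no `config` attribute from which the pipeline can read the latent channel count or the guidance-embedding flag: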

diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 040d935f1..f24cd28c5 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -680,7 +680,7 @@ class FluxPipeline(
         )
 
         # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
+        num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.transformer, torch.nn.Module) else 16
         latents, latent_image_ids = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -714,7 +714,8 @@ class FluxPipeline(
         self._num_timesteps = len(timesteps)
 
         # handle guidance
-        if self.transformer.config.guidance_embeds:
+        # AoT-compiled packages have no `config`; assume they were built with guidance embeds (FLUX.1-dev).
+        if not isinstance(self.transformer, torch.nn.Module) or self.transformer.config.guidance_embeds:
             guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
             guidance = guidance.expand(latents.shape[0])
         else:

The compiled binary file ("bs_1_1024.pt2") used in inference.py can be found here.

Thanks to the PyTorch folks (especially @jerryzh168), who provided guidance in this thread.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment