
@sayakpaul
Last active August 14, 2024 21:30
The script shows how to run SD3 with `torch.compile()`
import time

import torch
from diffusers import StableDiffusion3Pipeline

# Let float32 matmuls use faster, lower-precision kernels (e.g. TF32) where available.
torch.set_float32_matmul_precision("high")

model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
pipeline = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to("cuda")
pipeline.set_progress_bar_config(disable=True)

# Inductor tuning flags. `coordinate_descent_check_all_directions` does not
# exist in older releases such as PyTorch 2.1 (see the comments below).
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

# Use the channels_last memory format for the transformer and the VAE.
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)

# Compile the transformer and the VAE decoder; compilation happens lazily on the first call.
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "a photo of a cat"

# Warm-up runs: the first call triggers compilation and is slow.
for _ in range(3):
    _ = pipeline(
        prompt=prompt,
        num_inference_steps=50,
        guidance_scale=5.0,
        generator=torch.manual_seed(1),
    )

# Time 10 runs and report the average.
start = time.time()
for _ in range(10):
    _ = pipeline(
        prompt=prompt,
        num_inference_steps=50,
        guidance_scale=5.0,
        generator=torch.manual_seed(1),
    )
end = time.time()
avg_inference_time = (end - start) / 10
print(f"Average inference time: {avg_inference_time:.3f} seconds.")

# Generate one final image and save it as diffusers_a_photo_of_a_cat.png.
image = pipeline(
    prompt=prompt,
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=torch.manual_seed(1),
).images[0]
filename = "_".join(prompt.split(" "))
image.save(f"diffusers_{filename}.png")
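The warm-up-then-time pattern in the script can be factored into a small reusable helper. This is a minimal sketch; the `benchmark` name and its parameters are hypothetical and not part of diffusers. Wall-clock timing only gives meaningful GPU numbers if the callable blocks until the work is finished: a pipeline call that returns PIL images should already do this, otherwise call `torch.cuda.synchronize()` inside the callable.

```python
import statistics
import time
from typing import Any, Callable


def benchmark(fn: Callable[[], Any], warmup: int = 3, iters: int = 10) -> float:
    """Return the mean wall-clock time (seconds) of `fn` over `iters` timed calls.

    The first `warmup` calls run but are not timed, so one-off costs such as
    torch.compile's first-call compilation are excluded from the average.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.mean(timings)
```

For example, `benchmark(lambda: pipeline(prompt=prompt, num_inference_steps=50, guidance_scale=5.0, generator=torch.manual_seed(1)))` mirrors the 3 warm-up and 10 timed runs above. Timing each call separately with `perf_counter()` also makes it easy to report spread (e.g. `statistics.stdev`) instead of only the mean.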
@sayakpaul
Author

The first run triggers compilation, so it will be slow; subsequent runs will be faster.

Sorry, I won't have time to test on WSL.

@gkalstn000

I get the same issue when applying the code below:
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

Compiling the VAE works fine.

I installed the latest versions of diffusers, huggingface, transformers, and PyTorch.


python main.py -
/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: Transformer2DModelOutput is deprecated and will be removed in version 1.0.0. Importing Transformer2DModelOutput from diffusers.models.transformer_2d is deprecated and this will be removed in a future version. Please use from diffusers.models.modeling_outputs import Transformer2DModelOutput, instead.
deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
INFO: Started server process [299818]
INFO: Waiting for application startup.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.05it/s]
Loading pipeline components...: 56%|██████████████████████████████████████████████████████▍ | 5/9 [00:01<00:01, 2.82it/s]You set add_prefix_space. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:03<00:00, 2.85it/s]
pipeline setting done!
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:7860 (Press CTRL+C to quit)
0%| | 0/20 [00:00<?, ?it/s]
INFO: 127.0.0.1:34250 - "POST /generate_image/ HTTP/1.1" 500 Internal Server Error
ERROR: Exception in ASGI application
Traceback (most recent call last):
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py", line 399, in run_asgi
result = await app( # type: ignore[func-returns-value]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 70, in __call__
return await self.app(scope, receive, send)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/fastapi/applications.py", line 1054, in __call__
await super().__call__(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/applications.py", line 123, in __call__
await self.middleware_stack(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/middleware/errors.py", line 186, in __call__
raise exc
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/middleware/errors.py", line 164, in __call__
await self.app(scope, receive, _send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 65, in __call__
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
raise exc
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
await app(scope, receive, sender)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/routing.py", line 756, in __call__
await self.middleware_stack(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/routing.py", line 776, in app
await route.handle(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/routing.py", line 297, in handle
await self.app(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/routing.py", line 77, in app
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
raise exc
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
await app(scope, receive, sender)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/routing.py", line 72, in app
response = await func(request)
^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/fastapi/routing.py", line 278, in app
raw_response = await run_endpoint_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/fastapi/routing.py", line 191, in run_endpoint_function
return await dependant.call(**values)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/jector_client_ai/main.py", line 49, in generate_image
image = run_t2i_model(pipeline, request_data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/jector_client_ai/run_generator.py", line 56, in run_t2i_model
image = pipeline(prompt=request_data.prompt,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 828, in __call__
noise_pred = self.transformer(
^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
return _compile(
^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
transformations(instructions, code_options)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
tracer.run()
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
super().run()
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1802, in CALL
self.call_function(fn, args, kwargs)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 335, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 289, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2329, in inline_call_
result = InliningInstructionTranslator.check_inlineable(func)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2306, in check_inlineable
unimplemented(
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 190, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: 'inline in skipfiles: Logger.warning | warning /home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/logging/__init__.py, skipped according trace_rules.lookup'

from user code:
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/diffusers/models/transformers/transformer_sd3.py", line 285, in forward
logger.warning(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
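To get the extra detail that the last line of the error suggests, the two environment variables can also be set from Python. This is a minimal sketch; the helper name is hypothetical, and it assumes torch reads both variables when it is first imported, so they must be set before `import torch`. Exporting them in the shell when launching the server sidesteps that ordering question.

```python
import os
import sys
import warnings


def enable_dynamo_debug_logging() -> None:
    """Set the debug variables suggested by the Dynamo error message."""
    if "torch" in sys.modules:
        # Assumption: torch reads these variables at import time, so setting
        # them after torch is imported may have no effect in this process.
        warnings.warn("torch is already imported; set these variables before importing it")
    os.environ["TORCH_LOGS"] = "+dynamo"
    os.environ["TORCHDYNAMO_VERBOSE"] = "1"


enable_dynamo_debug_logging()
# Import torch and diffusers only after this point.
```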

@sayakpaul
Author

I see some HTTP-related traces in the logs. I am still unable to reproduce the issue with the provided snippet on my setup.

@gkalstn000

Yes, I'm creating a text-to-image API using FastAPI with SD3, so there are HTTP-related logs.

I created an Ubuntu x86 instance with an L4 GPU on GCP and installed the NVIDIA driver.

I also installed GCC-related libraries.

sudo apt-get update
sudo apt-get install build-essential
export CC=gcc

Then, I installed PyTorch using the official Conda installation code.

conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
I installed diffusers[torch], xformers, transformers, etc., without pinning specific versions.

Then, I initialized the SD3 pipeline with the following code:

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16
).to("cuda")

pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)

# pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

When generating the image as shown below, the above error occurred:

image = pipeline(prompt=request_data.prompt,
                 negative_prompt='worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch',
                 height=1024,
                 width=1024,
                 num_inference_steps=20,
                 guidance_scale=7,
                 ).images[0]

Could you please check if I installed anything incorrectly?
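Since no versions were pinned, posting the exact installed versions makes this kind of question easier to answer. A minimal sketch using only the standard library's `importlib.metadata`; the helper name and package list are hypothetical, based on the packages mentioned above, and assume the installed distribution names match their PyPI names.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Hypothetical list, based on the packages mentioned in this comment.
PACKAGES = ("torch", "torchvision", "torchaudio", "diffusers", "transformers", "xformers")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

For a fuller report, PyTorch ships `python -m torch.utils.collect_env` and diffusers provides `diffusers-cli env`.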

@gkalstn000

I solved the issue by installing peft:
pip install peft

I'm not sure what the main problem was exactly, but the error was caused here:

diffusers/models/transformers/transformer_sd3.py", line 285, in forward
logger.warning(
    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

That `logger.warning` call caused the error when `torch.compile` (with `fullgraph=True`) was applied to the transformer.
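Since installing `peft` made the error go away, a quick check that it is importable can save time when setting up a new environment. A minimal sketch; the helper name is hypothetical. It only checks that `peft` can be found, whereas, as far as I know, diffusers' own `USE_PEFT_BACKEND` flag (importable from `diffusers.utils`) also requires minimum peft and transformers versions.

```python
import importlib.util


def peft_available() -> bool:
    """Return True if the `peft` package can be found in this environment."""
    return importlib.util.find_spec("peft") is not None


if not peft_available():
    print("peft not found; in this thread, `pip install peft` avoided the "
          "fullgraph=True compile error in the SD3 transformer.")
```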

Compilation reduces the inference time at 1024x1024 resolution by 19.8%:
baseline: 12.2532 sec
compile: 9.82578 sec
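"19.8% faster" here is the reduction in per-image latency; expressed as a speedup factor, the same two timings give about 1.25x. The arithmetic, using only the numbers reported above (the helper name is hypothetical):

```python
from typing import Tuple


def latency_change(baseline_s: float, optimized_s: float) -> Tuple[float, float]:
    """Return (percent latency reduction, speedup factor)."""
    reduction = (baseline_s - optimized_s) / baseline_s * 100
    speedup = baseline_s / optimized_s
    return reduction, speedup


reduction, speedup = latency_change(12.2532, 9.82578)
print(f"{reduction:.1f}% lower latency, {speedup:.2f}x speedup")
# → 19.8% lower latency, 1.25x speedup
```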

@johnzhangzzzz

Hello, I ran this program in Docker (torch 2.1) and hit the following problem:

Traceback (most recent call last):
File "/workspace/std3/a.py", line 18, in <module>
torch._inductor.config.coordinate_descent_check_all_directions = True
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/config_utils.py", line 71, in __setattr__
raise AttributeError(f"{self.__name__}.{name} does not exist")
AttributeError: torch._inductor.config.coordinate_descent_check_all_directions does not exist

@sayakpaul
Author

You should use PyTorch 2.3; the `coordinate_descent_check_all_directions` option does not exist in PyTorch 2.1.
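A version guard can fail early with a clear message instead of the `AttributeError` above. A minimal sketch under two assumptions: that 2.3 is a minimum (later releases also accept the flag), and that the version string starts with `major.minor`, as `torch.__version__` does (e.g. `2.1.0+cu121`). The helper names are hypothetical.

```python
import re
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings such as '2.3.1' or '2.1.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def meets_minimum(version: str, minimum: Tuple[int, int] = (2, 3)) -> bool:
    """True if `version` is at least `minimum`, compared on major.minor only."""
    return parse_major_minor(version) >= minimum


for v in ("2.1.0+cu121", "2.3.1", "2.4.0.dev20240601"):
    print(v, meets_minimum(v))
```

In the script, `meets_minimum(torch.__version__)` could be checked before setting `torch._inductor.config.coordinate_descent_check_all_directions`. Testing for the flag directly with `hasattr(torch._inductor.config, "coordinate_descent_check_all_directions")` should also work, since the config module raises `AttributeError` for unknown names.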

@johnzhangzzzz

Switching to PyTorch 2.3 solved this problem, thanks.
