@sayakpaul
Last active June 17, 2024 07:47
The script shows how to run SD3 with `torch.compile()`
import time

import torch
from diffusers import StableDiffusion3Pipeline

torch.set_float32_matmul_precision("high")

model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
pipeline = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16
).to("cuda")
pipeline.set_progress_bar_config(disable=True)

# Inductor settings used for this run.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

# Switch to channels_last, then compile the transformer and the VAE decoder.
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "a photo of a cat"

# Warm-up runs: the first call triggers compilation and is not timed.
for _ in range(3):
    _ = pipeline(
        prompt=prompt,
        num_inference_steps=50,
        guidance_scale=5.0,
        generator=torch.manual_seed(1),
    )

# Timed runs.
start = time.time()
for _ in range(10):
    _ = pipeline(
        prompt=prompt,
        num_inference_steps=50,
        guidance_scale=5.0,
        generator=torch.manual_seed(1),
    )
end = time.time()
avg_inference_time = (end - start) / 10
print(f"Average inference time: {avg_inference_time:.3f} seconds.")

# Generate one more image and save it.
image = pipeline(
    prompt=prompt,
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=torch.manual_seed(1),
).images[0]
filename = "_".join(prompt.split(" "))
image.save(f"diffusers_{filename}.png")
@dhcracchiolo

Excellent! Could you please post your python version and any other relevant versions?

So far I've been unable to run this script:

  File "/home/ubuntu/miniconda3/envs/image_server/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2306, in check_inlineable
    unimplemented(
  File "/home/ubuntu/miniconda3/envs/image_server/lib/python3.9/site-packages/torch/_dynamo/exc.py", line 190, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: 'inline in skipfiles: Logger.warning | warning /home/ubuntu/miniconda3/envs/image_server/lib/python3.9/logging/__init__.py, skipped according trace_rules.lookup'

from user code:
   File "/home/ubuntu/miniconda3/envs/image_server/lib/python3.9/site-packages/diffusers/models/transformers/transformer_sd3.py", line 285, in forward
    logger.warning(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

@sayakpaul
Author

It can be fixed with PyTorch 2.3.

@xxwizardxx117

How? I am using torch 2.3.1 with CUDA 12.1, and I have also tried CUDA 11.8.

@sayakpaul
Author

Then I am not sure what is going on with your setup. I am unable to reproduce it on PyTorch 2.3 and CUDA 12.2.
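
When setups differ like this, it can help to post the exact versions side by side. The following is a minimal sketch that uses only the standard library; the helper name `collect_versions` and the package list are illustrative assumptions, not something the gist prescribes.

```python
import platform
from importlib import metadata


def collect_versions(packages):
    """Map each package name to its installed version, or 'not installed'."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


# Hypothetical list: adjust to whatever your environment actually depends on.
report = collect_versions(["torch", "diffusers", "transformers", "accelerate", "peft"])
for name, version in report.items():
    print(f"{name}=={version}")
```

Pasting this output into a comment makes it easier to spot a missing or mismatched package than comparing full tracebacks.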

@xxwizardxx117

error encountered :
Exception in Tkinter callback
Traceback (most recent call last):
  File "S:\devsetup\envs\sd3\lib\tkinter\__init__.py", line 1921, in __call__
    return self.func(*args)
  File "S:\devsetup\envs\sd3\lib\site-packages\customtkinter\widgets\ctk_button.py", line 377, in clicked
    self.command()
  File "c:\Users\sharm\Desktop\minor\app.py", line 106, in generate
    image = pipe(prompt.get(), guidance_scale=7.0, num_inference_steps=20,height=256,
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\diffusers\pipelines\stable_diffusion_3\pipeline_stable_diffusion_3.py", line 828, in __call__
    noise_pred = self.transformer(
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "S:\devsetup\envs\sd3\lib\contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\convert_frame.py", line 500, in transform
    tracer.run()
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2149, in run
    super().run()
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 810, in run
    and self.step()
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1219, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\variables\functions.py", line 335, in call_function
    return super().call_function(tx, args, kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\variables\functions.py", line 289, in call_function
    return super().call_function(tx, args, kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\variables\functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2329, in inline_call_
    result = InliningInstructionTranslator.check_inlineable(func)
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2306, in check_inlineable
    unimplemented(
  File "S:\devsetup\envs\sd3\lib\site-packages\torch\_dynamo\exc.py", line 190, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: 'inline in skipfiles: Logger.warning | warning S:\devsetup\envs\sd3\lib\logging\__init__.py, skipped according trace_rules.lookup'

from user code:
  File "S:\devsetup\envs\sd3\lib\site-packages\diffusers\models\transformers\transformer_sd3.py", line 285, in forward
    logger.warning(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Code where it is most likely happening (the interface itself opens as designed):
def generate():
    # Generate the image
    with autocast(device):
        image = pipe(prompt.get(), guidance_scale=7.0, num_inference_steps=20, height=256,
                     width=256,
                     )["images"][0]
        # height=1024,
        # width=1024,

    image.save('generated_image.png')
    original_img = Image.open('generated_image.png').convert('RGB')  # modify
    resized_img = original_img.resize((512, 440), Image.Resampling.LANCZOS)

    # Convert the Image object to a PhotoImage object
    img = ImageTk.PhotoImage(resized_img)

    img_ref.img = img  # keep a reference to the image
    img_ref.configure(image=img)

Custom button code used:

# Button to trigger image generation
trigger = ctk.CTkButton(height=40, width=120, text_font=("Arial", 20), text_color="white", fg_color="blue", command=generate)
trigger.configure(text="Generate")
trigger.place(x=240, y=60)  # generate button
app.bind("<Return>", lambda event=None: generate())  # bind the Enter key to generate

current dependencies
accelerate==0.31.0
bitsandbytes==0.43.1
certifi==2024.6.2
charset-normalizer==3.3.2
colorama==0.4.6
customtkinter==4.6.2
darkdetect==0.8.0
diffuser==0.0.1
diffusers==0.29.0
filelock==3.15.1
fsspec==2024.6.0
huggingface-hub==0.23.4
idna==3.7
importlib_metadata==7.1.0
intel-openmp==2021.4.0
Jinja2==3.1.4
MarkupSafe==2.1.5
mkl==2021.4.0
mpmath==1.3.0
networkx==3.3
numpy==1.26.4
packaging==24.1
pillow==10.3.0
protobuf==5.27.1
psutil==5.9.8
PyYAML==6.0.1
regex==2024.5.15
requests==2.32.3
safetensors==0.4.3
sentencepiece==0.2.0
sympy==1.12.1
tbb==2021.12.0
tk==0.1.0
tokenizers==0.19.1
torch==2.3.1+cu121
torchaudio==2.3.1+cu121
torchvision==0.18.1+cu121
tqdm==4.66.4
transformers==4.41.2
typing_extensions==4.12.2
urllib3==2.2.1
zipp==3.19.2

@sayakpaul
Author

Seems like a Windows bug. I am still unable to reproduce on Linux.

@xxwizardxx117

Does generation actually become faster after compilation? When I run a 50-frame compute, it takes about 9 minutes for a 512 × 1024 image, and the interface also takes a long time to open and close when running on the GPU.

If possible, could you suggest a way to run my code on WSL2 Linux? I have it installed on my laptop. The code consists of a single file, app.py.

@sayakpaul
Author

The first run triggers compilation, so it will be slow; subsequent runs will be faster.
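
To make the warm-up point concrete, here is a minimal sketch of a timing helper that runs a few untimed calls before measuring, mirroring the 3 warm-up and 10 timed iterations in the script. The names `benchmark` and `fake_pipeline` are hypothetical, and the fake workload only simulates a slow first call with `time.sleep`; it does not involve torch at all.

```python
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 3, runs: int = 10) -> float:
    """Return mean wall-clock seconds per call over `runs`, after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()  # untimed: absorbs one-off costs such as compilation
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs


# Stand-in workload (an assumption for illustration): slow on the first call only.
calls = {"n": 0}


def fake_pipeline():
    calls["n"] += 1
    time.sleep(0.05 if calls["n"] == 1 else 0.001)


avg = benchmark(fake_pipeline)
print(f"Average time per call after warm-up: {avg:.4f} s")
```

With a real pipeline you would pass a zero-argument callable, for example a `lambda` wrapping the `pipeline(...)` call. For GPU work the callable should only return once the result is actually available, otherwise the measured time can be too low.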

Sorry, won’t have the time to test on WSL.

@gkalstn000

I get the same issue when applying the code below.
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

Compiling the VAE works fine.

I installed the latest diffusers, huggingface, transformers, and pytorch.


python main.py -
/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: Transformer2DModelOutput is deprecated and will be removed in version 1.0.0. Importing Transformer2DModelOutput from diffusers.models.transformer_2d is deprecated and this will be removed in a future version. Please use from diffusers.models.modeling_outputs import Transformer2DModelOutput, instead.
deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
INFO: Started server process [299818]
INFO: Waiting for application startup.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.05it/s]
Loading pipeline components...: 56%|██████████████████████████████████████████████████████▍ | 5/9 [00:01<00:01, 2.82it/s]You set add_prefix_space. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:03<00:00, 2.85it/s]
pipeline setting done!
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:7860 (Press CTRL+C to quit)
0%| | 0/20 [00:00<?, ?it/s]
INFO: 127.0.0.1:34250 - "POST /generate_image/ HTTP/1.1" 500 Internal Server Error
ERROR: Exception in ASGI application
Traceback (most recent call last):
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py", line 399, in run_asgi
result = await app( # type: ignore[func-returns-value]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 70, in __call__
return await self.app(scope, receive, send)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/fastapi/applications.py", line 1054, in __call__
await super().__call__(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/applications.py", line 123, in __call__
await self.middleware_stack(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/middleware/errors.py", line 186, in __call__
raise exc
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/middleware/errors.py", line 164, in __call__
await self.app(scope, receive, _send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 65, in __call__
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
raise exc
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
await app(scope, receive, sender)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/routing.py", line 756, in __call__
await self.middleware_stack(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/routing.py", line 776, in app
await route.handle(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/routing.py", line 297, in handle
await self.app(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/routing.py", line 77, in app
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
raise exc
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
await app(scope, receive, sender)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/starlette/routing.py", line 72, in app
response = await func(request)
^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/fastapi/routing.py", line 278, in app
raw_response = await run_endpoint_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/fastapi/routing.py", line 191, in run_endpoint_function
return await dependant.call(**values)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/jector_client_ai/main.py", line 49, in generate_image
image = run_t2i_model(pipeline, request_data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/jector_client_ai/run_generator.py", line 56, in run_t2i_model
image = pipeline(prompt=request_data.prompt,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 828, in __call__
noise_pred = self.transformer(
^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
return _compile(
^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
transformations(instructions, code_options)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
tracer.run()
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
super().run()
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1802, in CALL
self.call_function(fn, args, kwargs)
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 335, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 289, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2329, in inline_call_
result = InliningInstructionTranslator.check_inlineable(func)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2306, in check_inlineable
unimplemented(
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 190, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: 'inline in skipfiles: Logger.warning | warning /home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/logging/__init__.py, skipped according trace_rules.lookup'

from user code:
File "/home/gkalstn000/anaconda3/envs/jector_cli_ai/lib/python3.11/site-packages/diffusers/models/transformers/transformer_sd3.py", line 285, in forward
logger.warning(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

@sayakpaul
Author

I see some HTTP allocation traces in the logs. I am still unable to reproduce the error with the provided snippet on my setup.

@gkalstn000

Yes, I'm creating a text-to-image API using FastAPI with SD3, so there are HTTP-related logs.

I created an Ubuntu x86 L4 instance on GCP and installed the Nvidia driver.

I also installed GCC-related libraries.

sudo apt-get update
sudo apt-get install build-essential
export CC=gcc

Then, I installed PyTorch using the official Conda installation code.

conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
I installed diffusers[torch], xformers, transformers, etc., without specifying a specific version.

Then, I initialized the SD3 pipeline with the following code:

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16
).to("cuda")

pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)

# pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

When generating the image as shown below, the above error occurred:

image = pipeline(prompt=request_data.prompt,
                 negative_prompt='worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch',
                 height=1024,
                 width=1024,
                 num_inference_steps=20,
                 guidance_scale=7,
                 ).images[0]

Could you please check if I installed anything incorrectly?

@gkalstn000

I solved the issue by installing peft:
pip install peft

I'm not sure what the main problem was exactly, but the error was caused here:

diffusers/models/transformers/transformer_sd3.py", line 285, in forward
logger.warning(
    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

That logger call caused the error when torch.compile was applied to the transformer.
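
Based only on what this thread reports, a missing `peft` install is what sends the transformer down the `logger.warning` path. A small pre-flight check before compiling can surface that early. This is a sketch under that assumption: `missing_packages` is a hypothetical helper, and it only checks that the package can be found, not that its version satisfies what diffusers requires.

```python
import importlib.util


def missing_packages(names):
    """Return the subset of top-level package names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumption from this thread: peft must be present for the compiled
# SD3 transformer to avoid the untraceable logger.warning call.
missing = missing_packages(["peft"])
if missing:
    print(f"Install before compiling with fullgraph=True: pip install {' '.join(missing)}")
else:
    print("peft found; proceeding to torch.compile.")
```

Running the check at startup gives a one-line hint instead of a long Dynamo traceback on the first request.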

With compilation, inference time drops by 19.8% at 1024x1024 resolution:
baseline: 12.2532 sec
compiled: 9.82578 sec
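
The 19.8% figure can be checked directly from the two timings. The variable names below are illustrative; the numbers are the ones reported in this comment.

```python
baseline_s = 12.2532  # seconds per image without torch.compile (reported above)
compiled_s = 9.82578  # seconds per image with torch.compile (reported above)

# Relative reduction in wall-clock time, and the equivalent speedup factor.
time_saved_pct = (baseline_s - compiled_s) / baseline_s * 100
speedup = baseline_s / compiled_s

print(f"Time saved: {time_saved_pct:.1f}%")  # Time saved: 19.8%
print(f"Speedup: {speedup:.3f}x")            # Speedup: 1.247x
```

Note that 19.8% is the reduction in time per image; expressed as throughput, the same measurement is roughly a 1.25x speedup.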
