
@sayakpaul
Created June 5, 2024 03:55
Run `HunyuanDiTPipeline` from Diffusers using under 6 GB of GPU VRAM.
"""
Make sure you have `diffusers`, `accelerate`, `transformers`, and `bitsandbytes` installed.
You also need PyTorch set up with CUDA support.
Once the dependencies are installed, you can run `python run_hunyuan_dit_less_memory.py`.
"""
from diffusers import HunyuanDiTPipeline
from transformers import T5EncoderModel
import torch
import gc
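

def flush():
    # Collect unreachable Python objects, then release PyTorch's unused cached GPU memory.
    gc.collect()
    torch.cuda.empty_cache()


def bytes_to_giga_bytes(bytes):
    # Divides by 1024 three times, so the result is in GiB (2**30 bytes).
    return bytes / 1024 / 1024 / 1024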


id = "Tencent-Hunyuan/HunyuanDiT-Diffusers"

# Stage 1: load only the text encoders. The T5 encoder (text_encoder_2) is loaded
# in 8-bit via bitsandbytes; the transformer and VAE are not loaded at all.
text_encoder_2 = T5EncoderModel.from_pretrained(
    id,
    subfolder="text_encoder_2",
    load_in_8bit=True,
    device_map="auto",
)
pipeline = HunyuanDiTPipeline.from_pretrained(
    id,
    text_encoder_2=text_encoder_2,
    transformer=None,
    vae=None,
    torch_dtype=torch.float16,
    device_map="balanced",
)
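
# Compute the prompt embeddings for both text encoders without tracking gradients.
with torch.no_grad():
    prompt = "一个宇航员在骑马"  # "An astronaut riding a horse"
    # Embeddings from the first text encoder (text_encoder_index defaults to 0).
    prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask = pipeline.encode_prompt(prompt)
    # Embeddings from the second (T5) text encoder.
    (
        prompt_embeds_2,
        negative_prompt_embeds_2,
        prompt_attention_mask_2,
        negative_prompt_attention_mask_2,
    ) = pipeline.encode_prompt(
        prompt=prompt,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
        max_sequence_length=256,
        text_encoder_index=1,
    )

# Free the text encoders before loading the transformer and VAE.
del text_encoder_2
del pipeline
flush()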

# Stage 2: load the transformer and VAE in fp16 without any text encoders,
# and denoise from the precomputed embeddings.
pipe = HunyuanDiTPipeline.from_pretrained(
    id,
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    prompt_embeds_2=prompt_embeds_2,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_prompt_embeds_2=negative_prompt_embeds_2,
    prompt_attention_mask=prompt_attention_mask,
    prompt_attention_mask_2=prompt_attention_mask_2,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    negative_prompt_attention_mask_2=negative_prompt_attention_mask_2,
    num_images_per_prompt=1,
).images[0]

# Peak GPU memory allocated by PyTorch tensors during the whole run.
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
image.save("memory_optimized.png")
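The `del text_encoder_2`, `del pipeline`, `flush()` sequence in the script relies on ordinary Python reference counting: the first pipeline was built with `text_encoder_2=text_encoder_2`, so it keeps its own reference to the T5 encoder, and the encoder can only be reclaimed once both names are gone. The following is a minimal CPU-only sketch of that behavior, assuming nothing else holds a reference; `BigModel` and `Pipeline` are hypothetical stand-ins, not Diffusers classes, and a `weakref` probe reports whether the encoder object is still alive.

```python
import gc
import weakref


class BigModel:
    """Hypothetical stand-in for a large model such as the T5 text encoder."""

    def __call__(self, prompt):
        # Return a small output, standing in for prompt embeddings.
        return [len(prompt)]


class Pipeline:
    """Hypothetical stand-in for a pipeline that stores a reference to its encoder."""

    def __init__(self, text_encoder_2):
        self.text_encoder_2 = text_encoder_2


def demo():
    text_encoder_2 = BigModel()
    pipeline = Pipeline(text_encoder_2=text_encoder_2)
    # A weak reference observes the encoder without keeping it alive.
    probe = weakref.ref(text_encoder_2)

    # Keep only the small output, like the script keeps only the embeddings.
    embeds = pipeline.text_encoder_2("an astronaut riding a horse")

    del text_encoder_2
    gc.collect()
    # Still alive: the pipeline holds another reference to the encoder.
    alive_after_first_del = probe() is not None

    del pipeline
    gc.collect()
    # Now no strong references remain, so the encoder has been reclaimed.
    freed_after_second_del = probe() is None

    return embeds, alive_after_first_del, freed_after_second_del


if __name__ == "__main__":
    print(demo())  # ([27], True, True)
```

Deleting only `text_encoder_2` leaves the encoder reachable through `pipeline.text_encoder_2`. In the real script, once both are gone, `torch.cuda.empty_cache()` can hand the now-unused cached blocks back from PyTorch's allocator.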
@xhoxye

xhoxye commented Jun 7, 2024

Yes, I installed diffusers from source:

pip install diffusers
pip list
Package            Version
------------------ ------------
accelerate         0.30.1
bitsandbytes       0.43.1
diffusers          0.28.2

@sayakpaul (author)

In that case, the diffusers version would have a "dev" suffix appended to it.

The following is the right way to install it from source:

pip install git+https://github.com/huggingface/diffusers
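To check which kind of install you have without importing the library, you can read the installed version from package metadata. The sketch below is a hypothetical helper, not part of Diffusers; it assumes, as described in the comment above, that source installs carry a PEP 440 `.devN` suffix (as in `0.29.0.dev0`) while PyPI releases such as `0.28.2` do not.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Matches a normalized PEP 440 dev-release suffix such as ".dev0",
# optionally followed by a local version label such as "+g1a2b3c".
_DEV_SUFFIX = re.compile(r"\.dev\d*(\+[A-Za-z0-9.]+)?$")


def is_dev_version(version_string: str) -> bool:
    """Return True if the version string ends in a .devN segment."""
    return _DEV_SUFFIX.search(version_string) is not None


def installed_version(package: str):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_version("diffusers")
    if v is None:
        print("diffusers is not installed")
    elif is_dev_version(v):
        print(f"diffusers {v}: development build (likely installed from source)")
    else:
        print(f"diffusers {v}: release build (e.g. from PyPI)")
```

A `.dev` suffix is a strong hint rather than proof of a source install, but it matches the distinction discussed in this thread.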

@xhoxye

xhoxye commented Jun 7, 2024

Thank you so much
pip install git+https://github.com/huggingface/diffusers
Successfully installed diffusers-0.29.0.dev0
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:31<00:00, 1.83s/it]
Max memory allocated: 5.386903762817383 GB
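A note on units: `bytes_to_giga_bytes` divides by 1024 three times, so the reported 5.386903762817383 is in GiB (2^30 bytes), even though the script labels it "GB". The sketch below inverts that figure to a byte count and re-expresses it in decimal GB (10^9 bytes); the byte count is back-computed from the printed value, not separately measured. Either way the peak stays under 6.

```python
GIB = 1024 ** 3  # bytes per GiB, the unit bytes_to_giga_bytes actually returns
GB = 10 ** 9     # bytes per decimal gigabyte


def bytes_to_giga_bytes(num_bytes):
    # Same arithmetic as the gist's helper: bytes -> GiB.
    return num_bytes / 1024 / 1024 / 1024


reported_gib = 5.386903762817383        # value printed in the run above
peak_bytes = round(reported_gib * GIB)  # back-computed byte count
peak_gb = peak_bytes / GB               # the same amount in decimal GB

if __name__ == "__main__":
    print(peak_bytes)  # 5784143872
    print(peak_gb)     # 5.784143872
```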

@s9anus98a

Could HunyuanDiT's results be improved via a CLIP merge with SDXL, a latent upscaler or refiner, or PAG (Perturbed-Attention Guidance) or AYS (Align Your Steps)? Which of these could make it better?

@sayakpaul (author)

Thanks for the questions. It is probably better to open a discussion on the Diffusers repository about this.

@metercai

This code occasionally raises an error at line 47:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

I think line 36 (`device_map="balanced"` in the `HunyuanDiTPipeline.from_pretrained` call) should be changed to `device_map="auto"`, but after that change I get: NotImplementedError: auto not supported. Supported strategies are: balanced

What should I do to solve it?
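One possible first diagnostic step for an error like this is to see where each module actually landed. Transformers attaches the resulting placement to models loaded with `device_map` as an `hf_device_map` dict (module name to device). The helper below is a hypothetical sketch that groups such a mapping by device; the example mapping is made up for illustration and is not output from this script. Mixed placement is not an error by itself, so treat this as a way to gather information for a reproducible report, not as a fix.

```python
from collections import defaultdict


def group_by_device(device_map):
    """Group module names by the device they were placed on.

    `device_map` is assumed to look like the `hf_device_map` dict that
    Transformers attaches to models loaded with `device_map=...`.
    """
    groups = defaultdict(list)
    for module_name, device in device_map.items():
        # Devices may be ints (GPU index) or strings like "cpu"; normalize to str.
        groups[str(device)].append(module_name)
    return dict(groups)


def is_split_across_devices(device_map):
    """Return True if the modules live on more than one device."""
    return len(group_by_device(device_map)) > 1


if __name__ == "__main__":
    # Hypothetical placement, for illustration only.
    example = {"shared": 0, "encoder.block.0": 0, "encoder.block.23": "cpu"}
    print(group_by_device(example))
    print(is_split_across_devices(example))
    # In the real script (requires a GPU and the models), one could inspect:
    # print(text_encoder_2.hf_device_map)
```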

@sayakpaul (author)

I need deterministic, reproducible code to investigate this.
