-
-
Save sayakpaul/3ae0f847001d342af27018a96f467e4e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pip install -U accelerate transformers bitsandbytes
# pip install -U git+https://github.com/huggingface/diffusers
from transformers import T5EncoderModel | |
from diffusers import PixArtAlphaPipeline | |
import torch | |
import gc | |
def flush() -> None:
    """Release memory held by objects that are no longer referenced.

    Runs a full Python garbage collection first, so that dropped objects
    (including ones kept alive only by reference cycles) are actually
    destroyed, and then asks PyTorch to empty its CUDA cache.
    """
    gc.collect()
    torch.cuda.empty_cache()
def bytes_to_giga_bytes(bytes: int) -> float:
    """Convert a byte count to gigabytes using binary units.

    The divisor is 1024**3, so the result is strictly in GiB.

    Args:
        bytes: Number of bytes. The name shadows the ``bytes`` builtin inside
            this function; it is kept so existing keyword callers still work.

    Returns:
        The size expressed in units of 1024**3 bytes.
    """
    return bytes / (1024 ** 3)
# Stage 1: encode the prompt with only the text encoder loaded.
# Loading in 8 bits needs `bitsandbytes`.
text_encoder = T5EncoderModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="text_encoder",
    load_in_8bit=True,
    device_map="auto",
)
# The transformer is left out (`transformer=None`) for this stage.
# NOTE(review): the pipeline-level strategy is "balanced" because recent
# diffusers releases reject "auto" for pipelines - confirm against the
# installed diffusers version.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=text_encoder,
    transformer=None,
    device_map="balanced",
)

with torch.no_grad():
    prompt = "cute cat"
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_embeds,
        negative_prompt_attention_mask,
    ) = pipe.encode_prompt(prompt)

# Drop the text encoder and the first pipeline before loading the next one.
del text_encoder
del pipe
flush()

# Stage 2: denoise. The pipeline is reloaded without a text encoder and fed
# the embeddings computed in stage 1.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=None,
    torch_dtype=torch.float16,
).to("cuda")

latents = pipe(
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_images_per_prompt=1,
    output_type="latent",
).images

# Stage 3: decode. The transformer is removed first; only the VAE and the
# image processor are used from here on.
del pipe.transformer
flush()

with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")
image[0].save("cat.png")

print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment