@Lyken17
Created August 21, 2023 15:58
import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import LoRAAttnProcessor
from utils import print_gpu_utilization  # local helper, not included in this gist (a sketch appears at the end)
# pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
)
device = "cuda"
unet = pipe.unet
vae = pipe.vae
text_encoder = pipe.text_encoder
# Freeze all base weights; trainable parameters are set up below depending on set_lora.
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.enable_xformers_memory_efficient_attention()
set_lora = True
if not set_lora:
    # Full fine-tuning: make every UNet weight trainable again.
    unet.requires_grad_(True)
else:
    print("Setup LoRA")
    # https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py#L461
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        # Self-attention layers (attn1) have no cross-attention dimension.
        cross_attention_dim = (
            None
            if name.endswith("attn1.processor")
            else unet.config.cross_attention_dim
        )
        # Pick the hidden size that matches the block this processor lives in.
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        lora_attn_procs[name] = LoRAAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4,
        )
    unet.set_attn_processor(lora_attn_procs)
pipe = pipe.to(device)
bs = 8
rs = 512
# Random latents, timesteps, and text embeddings as dummy inputs for the memory measurement.
data = torch.randn(bs, 4, rs // 8, rs // 8).to(device).half()
tstamp = torch.randn(bs,).to(device).half()
text = torch.randn(bs, 3, 768).to(device).half()
with torch.no_grad():
    # Inference-only forward pass: no activations are kept for backward.
    out = unet(data, tstamp, text, added_cond_kwargs={}).sample
    print_gpu_utilization(0, prefix="forward")

# Training-style forward + backward to measure fine-tuning memory.
out = unet(data, tstamp, text, added_cond_kwargs={}).sample
out.sum().backward()
print_gpu_utilization(0, prefix="finetune")
exit(0)
import torch
rs = 512
bs = 1
data = torch.randn(bs, 4, rs, rs)
unet
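
The script imports print_gpu_utilization from a local utils module that is not included in this gist. A minimal sketch of such a helper, assuming it queries device memory through pynvml (the actual utils.py may differ):

import pynvml

def print_gpu_utilization(device_index=0, prefix=""):
    # Query NVML for the memory currently used on the given GPU and print it.
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    print(f"{prefix}: GPU {device_index} memory used: {info.used / 1024 ** 2:.0f} MiB")
    pynvml.nvmlShutdown()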