Skip to content

Instantly share code, notes, and snippets.

@lucataco
Created September 11, 2023 19:58
Show Gist options
  • Save lucataco/bbd420ab927fe2cfb8d8631fc880e07e to your computer and use it in GitHub Desktop.
Save lucataco/bbd420ab927fe2cfb8d8631fc880e07e to your computer and use it in GitHub Desktop.
Replicate-LoRA-manual-load-weights-take2
import os
import json
import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
from safetensors import safe_open
from dataset_and_utils import TokenEmbeddingsHandler
from safetensors.torch import load_file
from diffusers.models.attention_processor import LoRAAttnProcessor2_0
# Build the SDXL base pipeline from the fp16 safetensors variant of the
# checkpoint, then move the whole pipeline onto the GPU.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
pipe = pipe.to("cuda")
# Load the LoRA tensors from disk and install matching LoRA attention
# processors on the pipeline's UNet.
tensors = load_file("weights/lora.safetensors")
unet = pipe.unet

# Map each attention-processor name to the rank stored in the file.
# Keys ending in "up.weight" have the form "<processor name>.<layer>.up.weight",
# so dropping the last three dot-separated parts leaves the processor name.
# The "up" tensor is (N, d); its second dimension is used as the rank.
name_rank_map = {
    ".".join(key.split(".")[:-3]): weight.shape[1]
    for key, weight in tensors.items()
    if key.endswith("up.weight")
}

block_out_channels = unet.config.block_out_channels
# Up blocks are indexed against the channel list in reverse order.
reversed_block_out_channels = list(reversed(block_out_channels))

unet_lora_attn_procs = {}
for name in unet.attn_processors:
    # Processors whose name ends in "attn1.processor" are built without a
    # cross-attention dimension (presumably the self-attention layers);
    # all others use the UNet's configured cross-attention dimension.
    cross_attention_dim = (
        None
        if name.endswith("attn1.processor")
        else unet.config.cross_attention_dim
    )

    # Pick the hidden size from the block the processor belongs to. The
    # block index is the second dot-separated component of the name; parsing
    # the whole component (not a single character) keeps indices >= 10 valid.
    if name.startswith("mid_block"):
        hidden_size = block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name.split(".")[1])
        hidden_size = reversed_block_out_channels[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name.split(".")[1])
        hidden_size = block_out_channels[block_id]
    else:
        # Fail loudly instead of silently reusing the previous iteration's
        # hidden_size (or hitting a NameError on the first iteration).
        raise ValueError(f"Unexpected attention processor name: {name!r}")

    # Raises KeyError if the LoRA file holds no "up.weight" entry for this
    # processor.
    module = LoRAAttnProcessor2_0(
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
        rank=name_rank_map[name],
    )
    unet_lora_attn_procs[name] = module.to("cuda")

unet.set_attn_processor(unet_lora_attn_procs)
# strict=False: the loaded tensors are not required to cover every UNet
# parameter (the file presumably contains only the LoRA weights).
unet.load_state_dict(tensors, strict=False)
# Load the saved token embeddings for both text encoder / tokenizer pairs.
# NOTE(review): TokenEmbeddingsHandler comes from the project-local
# dataset_and_utils module; presumably load_embeddings installs the trained
# token embeddings into the encoders -- confirm against that module.
handler = TokenEmbeddingsHandler(
    [pipe.text_encoder, pipe.text_encoder_2],
    [pipe.tokenizer, pipe.tokenizer_2],
)
handler.load_embeddings("weights/embeddings.pti")

# Load the special params. The parsed JSON object is used later in the script
# as a mapping of prompt substrings to their replacements.
# An explicit encoding keeps JSON decoding independent of the platform locale.
with open("weights/special_params.json", "r", encoding="utf-8") as f:
    params = json.load(f)
token_map = params
# Generation settings passed to the pipeline call: output size plus the
# "scale" forwarded to the attention processors via cross_attention_kwargs.
sdxl_kwargs = {
    "width": 1024,
    "height": 1024,
    "cross_attention_kwargs": {"scale": 0.6},
}

# K_EULER scheduler, rebuilt from the pipeline's current scheduler config.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

# Fixed seed for reproducible output; set to None to draw a random one.
seed = 57727
if seed is None:
    # Two random bytes give a seed in the range 0..65535.
    seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")
generator = torch.Generator("cuda").manual_seed(seed)

# Apply every substitution from token_map to the prompt before generation.
prompt = "A TOK emoji of a man"
for placeholder, replacement in token_map.items():
    prompt = prompt.replace(placeholder, replacement)

common_args = {
    "prompt": prompt,
    "guidance_scale": 7.5,
    "generator": generator,
    "num_inference_steps": 50,
}

# Run the pipeline and save the first generated image.
output = pipe(**common_args, **sdxl_kwargs).images[0]
output.save("output.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment