def make_unet_conversion_map():
    unet_conversion_map_layer = []

    # unet
    # https://github.com/kohya-ss/sd-scripts/blob/2d7389185c021bc527b414563c245c5489d6328a/library/sdxl_model_util.py#L293
    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}"
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1"
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            # if i > 0: commented out for sdxl
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}"
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1"
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv"
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op"
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0"
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}"  # change for sdxl
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0"
    sd_mid_atn_prefix = "middle_block.1"
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    ]

    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}"
        sd_time_embed_prefix = f"time_embed.{j*2}"
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}"
        sd_label_embed_prefix = f"label_emb.0.{j*2}"
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0", "conv_in"))
    # controlnet
    # created by chatgpt
    mapping_dict = {
        "input_hint_block.0": "controlnet_cond_embedding.conv_in",
        # the rest of the input_hint_block mappings are defined below
    }

    # add the input_hint_block mappings
    orig_index = 2  # start from 2, since index 0 is already defined above
    diffusers_index = 0
    while diffusers_index < 6:
        mapping_dict[f"input_hint_block.{orig_index}"] = f"controlnet_cond_embedding.blocks.{diffusers_index}"
        diffusers_index += 1
        orig_index += 2

    # add the mapping for the final conv_out
    mapping_dict[f"input_hint_block.{orig_index}"] = "controlnet_cond_embedding.conv_out"

    # add the mappings for the down blocks and the mid block
    num_input_blocks = 12
    for i in range(num_input_blocks):
        mapping_dict[f"zero_convs.{i}.0"] = f"controlnet_down_blocks.{i}"
    mapping_dict["middle_block_out.0"] = "controlnet_mid_block"

    mapping_dict.update({t[0]: t[1] for t in unet_conversion_map})

    return mapping_dict
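# Illustration only: a few of the (sgm key -> diffusers key) pairs the returned dict holds,
# e.g. "input_blocks.1.0.in_layers.0" -> "down_blocks.0.resnets.0.norm1",
#      "zero_convs.0.0"               -> "controlnet_down_blocks.0",
#      "middle_block_out.0"           -> "controlnet_mid_block".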
def convert_key(key, mapping_dict):
    new_key = key
    for k, v in mapping_dict.items():
        new_key = new_key.replace(v, k)  # diffusers -> sgm (the reverse direction of the map)
    return new_key
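# Example (illustrative): a diffusers module name such as "controlnet_down_blocks.0"
# converts back to the sgm-style key "zero_convs.0.0", which is how the
# control-lora state dict addresses its tensors.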
def get_weight(down, up):
    # reconstruct the full-rank weight delta from a LoRA (down, up) pair
    in_channel = down.shape[1]
    rank = down.shape[0]
    out_channel = up.shape[0]
    kernel = () if down.dim() == 2 else down.shape[2:]  # conv LoRAs carry the kernel dims on down
    shape = (out_channel, in_channel) + kernel
    down = down.reshape(rank, -1).to("cuda")
    up = up.reshape(-1, rank).to("cuda")
    weight = up @ down
    weight = weight.reshape(shape)
    return weight.to("cpu")
def merge_lora(controlnet, lora_weights):
    mapping_dict = make_unet_conversion_map()
    for name, module in controlnet.named_modules():
        sgm_key = convert_key(name, mapping_dict)
        if sgm_key + ".down" in lora_weights:
            # low-rank delta: add up @ down onto the existing weight
            weight = get_weight(lora_weights[sgm_key + ".down"], lora_weights[sgm_key + ".up"])
            module.weight.data = module.weight.data + weight.to(module.weight.data)
        if sgm_key + ".weight" in lora_weights:
            # full tensor: replace the weight outright
            weight = lora_weights[sgm_key + ".weight"]
            module.weight.data = weight.to(module.weight.data)
        if sgm_key + ".bias" in lora_weights:
            weight = lora_weights[sgm_key + ".bias"]
            module.bias.data = weight.to(module.bias.data)
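# Note: control-lora checkpoints mix low-rank pairs (".down"/".up") with full
# replacement tensors (".weight"/".bias", e.g. for the zero convs and the hint block,
# which do not exist in the base UNet), hence the three branches above.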
if __name__ == "__main__":
    from diffusers import StableDiffusionXLControlNetPipeline, AutoencoderKL, ControlNetModel, UNet2DConditionModel
    from diffusers.utils import load_image
    from safetensors.torch import load_file, save_file
    import torch
    import numpy as np
    import cv2
    from PIL import Image

    # load unet
    unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet").to("cuda", torch.float16)

    # load controlnet
    controlnet = ControlNetModel.from_unet(unet).to("cuda", dtype=torch.float16)
    controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())  # copy explicitly to avoid a from_unet bug

    lora_weights = load_file("control-lora-canny-rank128.safetensors")
    merge_lora(controlnet, lora_weights)
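    # Optional step (an assumption, not in the original flow): persist the merged
    # ControlNet so the conversion only has to run once, e.g.
    # save_file(controlnet.state_dict(), "controlnet-canny-sdxl.safetensors")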
    # load pipe
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, controlnet=controlnet, vae=vae, torch_dtype=torch.float16
    )
    pipe.to("cuda")

    # https://huggingface.co/docs/diffusers/api/pipelines/controlnet_sdxl
    prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
    negative_prompt = "low quality, bad quality, sketches"

    # download an image
    image = load_image(
        "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
    )

    # get canny image
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    canny_image = Image.fromarray(image)

    # generate image
    image = pipe(
        prompt=prompt, negative_prompt=negative_prompt, controlnet_conditioning_scale=1.0, image=canny_image
    ).images[0]
    image.save("output.png")