# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.
# Written by jachiam
|
|
|
import argparse
import os.path as osp

import torch
|
|
|
|
|
# =================# |
|
# UNet Conversion # |
|
# =================# |
|
|
|
# One-to-one renames for the UNet parameters that live outside the
# down/mid/up block hierarchy. Each entry is a full parameter name pair:
# (stable-diffusion name, HF Diffusers name).
unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]
|
|
|
# Sub-module renames applied (as substring replacements) to parameter names
# that belong to a resnet block. Each entry is
# (stable-diffusion fragment, HF Diffusers fragment).
unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]
|
|
|
# Prefix renames for every block of the UNet, built programmatically.
# Each entry is (stable-diffusion prefix, HF Diffusers prefix); both end in
# "." so they only match whole path components.
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        # The upsampler follows the resnet (and the attention, when the
        # block has one), hence sub-index 1 for block 0 and 2 otherwise.
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

# The middle block is attention sandwiched between two resnets:
# SD indices 0 / 1 / 2 correspond to resnet 0 / attention 0 / resnet 1.
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
|
|
|
|
def convert_unet_state_dict(unet_state_dict):
    """Rename the keys of an HF Diffusers UNet state dict to Stable Diffusion names.

    The values are passed through untouched; only the keys change.

    Args:
        unet_state_dict: mapping from HF Diffusers parameter name to value.

    Returns:
        A new dict with the same values keyed by the converted names.

    Raises:
        KeyError: if a HF name listed in ``unet_conversion_map`` is missing
            from ``unet_state_dict``.
    """
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.

    # Start from the identity mapping (HF name -> HF name) ...
    mapping = {k: k for k in unet_state_dict.keys()}
    # ... then overwrite the entries that have a direct full-name rename.
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    # Rename resnet sub-modules first, while the HF "resnets" marker is
    # still present in the key.
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            # Only values of existing keys are reassigned, so mutating the
            # dict during iteration is safe.
            mapping[k] = v
    # Finally rewrite the block prefixes.
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict
|
|
|
|
|
# ================# |
|
# VAE Conversion # |
|
# ================# |
|
|
|
# Substring renames for VAE parameter names. Each entry is
# (stable-diffusion fragment, HF Diffusers fragment). The list starts with a
# few fixed renames and is then extended programmatically per block.
vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        # only the first three blocks get a down/upsampler entry
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
|
|
|
|
# Substring renames applied only to VAE attention parameters. Each entry is
# (stable-diffusion fragment, HF Diffusers fragment); the trailing "." keeps
# the match anchored to the end of a path component.
vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]
|
|
|
|
|
def reshape_weight_for_sd(w):
    """Return ``w`` reshaped with two extra trailing dimensions of size 1.

    Used to convert HF linear weights to SD conv2d weights: the element
    order is unchanged, only ``(1, 1)`` is appended to the shape.

    Args:
        w: an array-like object exposing ``shape`` and ``reshape``.

    Returns:
        The result of ``w.reshape(*w.shape, 1, 1)``.
    """
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)
|
|
|
|
|
def convert_vae_state_dict(vae_state_dict):
    """Rename the keys of an HF Diffusers VAE state dict to Stable Diffusion names.

    Besides renaming, the mid-block attention ``q``/``k``/``v``/``proj_out``
    weights are passed through ``reshape_weight_for_sd``; a message is
    printed for every weight that is reshaped.

    Args:
        vae_state_dict: mapping from HF Diffusers parameter name to value.

    Returns:
        A new dict keyed by the converted names.
    """
    # Start from the identity mapping (HF name -> HF name).
    mapping = {k: k for k in vae_state_dict.keys()}
    # General block/prefix renames apply to every key.
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        # Only values of existing keys are reassigned, so mutating the dict
        # during iteration is safe.
        mapping[k] = v
    # Attention-specific renames are keyed off the *original* HF name, which
    # still contains "attentions" even after the first pass rewrote the value.
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    # The mid-block attention projections need the conv-style shape.
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict
|
|
|
|
|
# =========================# |
|
# Text Encoder Conversion # |
|
# =========================# |
|
# pretty much a no-op |
|
|
|
|
|
def convert_text_enc_state_dict(text_enc_dict):
    """Return the text encoder state dict unchanged.

    Kept as a function so the text encoder goes through the same
    load -> convert -> prefix steps as the UNet and the VAE.

    Args:
        text_enc_dict: the text encoder state dict.

    Returns:
        The very same object that was passed in (no copy is made).
    """
    return text_enc_dict
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: load the three Diffusers sub-models from
    # --model_path, convert their key names, and write one combined
    # checkpoint to --checkpoint_path.
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")

    # Both paths are declared required=True, so argparse itself exits with a
    # usage error when either is missing; no additional assert is needed
    # (and asserts would be stripped under `python -O` anyway).
    args = parser.parse_args()

    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only convert models that come from a trusted source.

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location='cpu')
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location='cpu')
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location='cpu')
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        # Only floating-point tensors are cast; integer tensors (such as
        # index buffers) must keep their dtype instead of becoming float16.
        state_dict = {k: v.half() if v.is_floating_point() else v for k, v in state_dict.items()}
    # The weights are nested under a top-level "state_dict" key.
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.checkpoint_path)
@Owlcan the parent folder containing "vae" and "unet" is the folder for the model. You'll want to enter the parent folder's name (I can't tell what it is from your screenshot) into the CLI command, and I believe the resulting `.ckpt` file will have the same name as the parent folder, but with a `.ckpt` extension.

So, in other words, if hypothetically those folder paths were `/Users/bob/ai/my-new-model/vae` and `/Users/bob/ai/my-new-model/unet`, then the parent folder would be `/Users/bob/ai/my-new-model`. The command should be run from inside the `/Users/bob/ai` folder, and look like `./toCkpt.sh ./my-new-model`. And lastly, the checkpoint file would be created at `/Users/bob/ai/my-new-model.ckpt`.

I've only tested on Linux, so I'm unable to confirm how well things work on Windows. If you're able to run this inside WSL, that would be ideal. Otherwise, a Unix shell emulator like Git Bash might still work, but it's hit or miss.
The Bash script is super short, so I'd recommend just running the Python command directly. For the example I provided above, the Python command in PowerShell/Command Prompt would be:
`python convertToCkpt.py --model_path=./my-new-model --checkpoint_path=./my-new-model.ckpt`