Skip to content

Instantly share code, notes, and snippets.

@jachiam
Last active July 18, 2024 10:24
Show Gist options
  • Save jachiam/8a5c0b607e38fcc585168b90c686eb05 to your computer and use it in GitHub Desktop.
Save jachiam/8a5c0b607e38fcc585168b90c686eb05 to your computer and use it in GitHub Desktop.
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.
# Written by jachiam
import argparse
import os.path as osp
import torch
# =================#
# UNet Conversion #
# =================#
# Top-level UNet modules that are renamed as whole parameter names.
# Each pair is (stable-diffusion module, HF Diffusers module); the
# weight/bias parameters of each module are expanded below, weight first.
_UNET_TOP_LEVEL_MODULES = [
    ("time_embed.0", "time_embedding.linear_1"),
    ("time_embed.2", "time_embedding.linear_2"),
    ("input_blocks.0.0", "conv_in"),
    ("out.0", "conv_norm_out"),
    ("out.2", "conv_out"),
]

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    (f"{sd_module}.{param}", f"{hf_module}.{param}")
    for sd_module, hf_module in _UNET_TOP_LEVEL_MODULES
    for param in ("weight", "bias")
]

# Sub-layer names inside one resnet block, keyed by the stable-diffusion
# name with the HF Diffusers name as value. These are used as substring
# replacements on keys that belong to a resnet.
_RESNET_SUBLAYERS_SD_TO_HF = {
    "in_layers.0": "norm1",
    "in_layers.2": "conv1",
    "out_layers.0": "norm2",
    "out_layers.3": "conv2",
    "emb_layers.1": "time_emb_proj",
    "skip_connection": "conv_shortcut",
}

# (stable-diffusion, HF Diffusers) pairs, in the insertion order above.
unet_conversion_map_resnet = list(_RESNET_SUBLAYERS_SD_TO_HF.items())
def _iter_unet_layer_prefixes():
    """Yield (stable-diffusion prefix, HF Diffusers prefix) pairs for UNet blocks.

    The pair order is significant: convert_unet_state_dict applies these as
    chained substring replacements in exactly this sequence.
    The block/layer counts are hard-coded for this one network layout;
    other networks would need smarter logic.
    """
    for block in range(4):
        # Down path: two resnets per block, each followed by an attention
        # layer except in down_blocks.3.
        for layer in range(2):
            sd_index = 3 * block + layer + 1
            yield f"input_blocks.{sd_index}.0.", f"down_blocks.{block}.resnets.{layer}."
            if block < 3:
                yield f"input_blocks.{sd_index}.1.", f"down_blocks.{block}.attentions.{layer}."

        # Up path: three resnets per block, each followed by an attention
        # layer except in up_blocks.0.
        for layer in range(3):
            sd_index = 3 * block + layer
            yield f"output_blocks.{sd_index}.0.", f"up_blocks.{block}.resnets.{layer}."
            if block > 0:
                yield f"output_blocks.{sd_index}.1.", f"up_blocks.{block}.attentions.{layer}."

        # No downsampler in down_blocks.3 and no upsampler in up_blocks.3.
        if block < 3:
            yield f"input_blocks.{3 * (block + 1)}.0.op.", f"down_blocks.{block}.downsamplers.0.conv."
            # up_blocks.0 has no attention layer, so its upsampler sits at
            # sub-index 1 of the SD output block instead of 2.
            upsample_slot = 1 if block == 0 else 2
            yield f"output_blocks.{3 * block + 2}.{upsample_slot}.", f"up_blocks.{block}.upsamplers.0."

    # Middle block: SD orders it resnet (0), attention (1), resnet (2);
    # the attention pair is emitted first, as before.
    yield "middle_block.1.", "mid_block.attentions.0."
    for layer in range(2):
        yield f"middle_block.{2 * layer}.", f"mid_block.resnets.{layer}."


unet_conversion_map_layer = list(_iter_unet_layer_prefixes())
def convert_unet_state_dict(unet_state_dict):
    """Re-key a HF Diffusers UNet state dict to stable-diffusion parameter names.

    The HF-name -> SD-name mapping is built in three passes over the
    module-level tables (``unet_conversion_map``, ``unet_conversion_map_resnet``,
    ``unet_conversion_map_layer``). The values are carried over unchanged;
    only the keys are renamed.

    Args:
        unet_state_dict: dict from HF Diffusers UNet parameter names to values
            (presumably tensors loaded via torch.load — confirm at call site).

    Returns:
        A new dict holding the same values under stable-diffusion key names.
        If two HF keys end up with the same SD name, the later one silently wins.

    Raises:
        KeyError: if a name listed in ``unet_conversion_map`` is absent from
            the input. Pass 1 inserts it into ``mapping`` unconditionally, so
            the final lookup into ``unet_state_dict`` fails.
    """
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    # Start from the identity mapping: every HF key initially maps to itself.
    mapping = {k: k for k in unet_state_dict.keys()}
    # Pass 1: whole-name renames for the top-level layers.
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    # Pass 2: resnet sub-layer renames. Note the "resnets" test is made on the
    # original HF key k, not on the partially renamed value v.
    # Reassigning existing keys while iterating is safe: the dict size is unchanged.
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    # Pass 3: block-prefix renames, applied to every key. The replacements are
    # chained in table order, so each one sees the output of the previous one.
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict
# ================#
# VAE Conversion #
# ================#
def _iter_vae_block_prefixes():
    """Yield (stable-diffusion prefix, HF Diffusers prefix) pairs for VAE blocks.

    The order matches the original construction loop, because the pairs are
    applied as chained substring replacements.
    """
    for level in range(4):
        # HF numbers the up blocks in reverse relative to SD.
        sd_up_level = 3 - level

        # Encoder down blocks have two resnets each.
        for layer in range(2):
            yield f"encoder.down.{level}.block.{layer}.", f"encoder.down_blocks.{level}.resnets.{layer}."

        # There is no down/upsampler at the last level.
        if level < 3:
            yield f"down.{level}.downsample.", f"down_blocks.{level}.downsamplers.0."
            yield f"up.{sd_up_level}.upsample.", f"up_blocks.{level}.upsamplers.0."

        # Decoder up blocks have three resnets each.
        for layer in range(3):
            yield f"decoder.up.{sd_up_level}.block.{layer}.", f"decoder.up_blocks.{level}.resnets.{layer}."

    # Mid-block resnets, shared naming for both the encoder and the decoder.
    for layer in range(2):
        yield f"mid.block_{layer + 1}.", f"mid_block.resnets.{layer}."


vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
    *_iter_vae_block_prefixes(),
]
# Extra substring renames for VAE attention keys. convert_vae_state_dict
# applies them, in order, only to keys whose HF name contains "attentions".
vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    # Legacy diffusers names for the attention projections.
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
    # Newer diffusers releases store the same projections as
    # to_q / to_k / to_v / to_out.0. Without these pairs, such keys were left
    # unconverted. Legacy keys never contain these substrings, so old
    # checkpoints are unaffected.
    ("q.", "to_q."),
    ("k.", "to_k."),
    ("v.", "to_v."),
    ("proj_out.", "to_out.0."),
]
def reshape_weight_for_sd(w):
    """Append two trailing singleton dimensions to ``w``.

    This turns an HF linear-layer weight into the shape of an SD conv2d
    weight with a 1x1 kernel. The data itself is unchanged.
    """
    conv_shape = (*w.shape, 1, 1)
    return w.reshape(conv_shape)
def convert_vae_state_dict(vae_state_dict):
    """Re-key a HF Diffusers VAE state dict to stable-diffusion names.

    Every key gets the substring substitutions from ``vae_conversion_map``.
    Keys whose original HF name contains "attentions" also get those from
    ``vae_conversion_map_attn``. Finally, the mid-block attention q/k/v/proj_out
    weights are reshaped with ``reshape_weight_for_sd``, and a message is
    printed for each one.

    Returns:
        A new dict with renamed keys. If two keys collide after renaming,
        the later one wins.
    """
    def _rename(hf_key):
        sd_key = hf_key
        for sd_part, hf_part in vae_conversion_map:
            sd_key = sd_key.replace(hf_part, sd_part)
        # The attention test looks at the original HF key, not the renamed one.
        if "attentions" in hf_key:
            for sd_part, hf_part in vae_conversion_map_attn:
                sd_key = sd_key.replace(hf_part, sd_part)
        return sd_key

    converted = {_rename(hf_key): tensor for hf_key, tensor in vae_state_dict.items()}

    reshape_targets = [f"mid.attn_1.{name}.weight" for name in ("q", "k", "v", "proj_out")]
    # Snapshot the items: values are replaced in place, and each reshape
    # starts from the value the key held before this loop.
    for sd_key, tensor in list(converted.items()):
        for target in reshape_targets:
            if target in sd_key:
                print(f"Reshaping {sd_key} for SD format")
                converted[sd_key] = reshape_weight_for_sd(tensor)
    return converted
# =========================#
# Text Encoder Conversion #
# =========================#
# pretty much a no-op
def convert_text_enc_state_dict(text_enc_dict):
    """Return the text-encoder state dict as-is.

    No keys are renamed here (the caller adds the SD key prefix). The very
    same dict object is handed back, not a copy.
    """
    converted = text_enc_dict
    return converted
# A Git LFS pointer file (what a clone contains when git-lfs is not installed)
# starts with this text instead of real weight data. torch.load then fails
# with an opaque "invalid load key, 'v'" error.
_LFS_POINTER_PREFIX = b"version https://git-lfs"


def _load_state_dict(path):
    """Load a pickled PyTorch weight file from ``path`` onto the CPU.

    Args:
        path: path to one of the ``.bin`` weight files of a diffusers pipeline.

    Returns:
        Whatever ``torch.load`` returns for the file.

    Raises:
        FileNotFoundError: ``path`` does not exist. The message describes the
            expected directory layout.
        ValueError: ``path`` is a Git LFS pointer stub rather than the weights.
    """
    if not osp.isfile(path):
        raise FileNotFoundError(
            f"{path} not found: --model_path must be a diffusers pipeline directory "
            "containing unet/, vae/ and text_encoder/ sub-folders with .bin weights "
            "(.safetensors weights are not supported by this script)."
        )
    with open(path, "rb") as weight_file:
        header = weight_file.read(len(_LFS_POINTER_PREFIX))
    if header == _LFS_POINTER_PREFIX:
        raise ValueError(
            f"{path} is a Git LFS pointer, not the actual weights: install git-lfs, "
            "run `git lfs install`, and clone the model again."
        )
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
    # Only convert checkpoints that come from a source you trust.
    return torch.load(path, map_location="cpu")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    # required=True already makes argparse reject missing paths, so no extra
    # checks are needed (asserts would also be stripped under `python -O`).
    args = parser.parse_args()

    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

    # Load everything up front, so that a bad --model_path is reported as a
    # clean usage error instead of a traceback.
    try:
        unet_state_dict = _load_state_dict(unet_path)
        vae_state_dict = _load_state_dict(vae_path)
        text_enc_dict = _load_state_dict(text_enc_path)
    except (FileNotFoundError, ValueError) as err:
        parser.error(str(err))  # prints usage plus the message, then exits

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        # Only cast floating-point tensors. Any non-floating tensors (e.g.
        # integer index buffers) must keep their dtype rather than become fp16.
        state_dict = {k: v.half() if v.is_floating_point() else v for k, v in state_dict.items()}
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.checkpoint_path)
@Roderich96
Copy link

Hello, this is really interesting code. From what I read, it can help use the model created in DreamBooth and download it.

I've just been wondering, as a noob at coding: how can I initialize it? Is there any guide I can use?

Thank you so much in advance for responding to my comment.

@hashnag
Copy link

hashnag commented Oct 19, 2022

From what I've seen:
python .\convert_diffusers_to_sd.py --model_path .\your-model-directory --checkpoint_path .\your-ckpt-directory\your-model.ckpt

@Christopher-Hayes
Copy link

Thanks @jachiam this script works great. If anyone wants a bash script to make running this a little easier, I put one here: https://gist.github.com/Christopher-Hayes/636ba25e0ae2e7020722d5386ac2571b
Which would allow you to run just ./toCkpt.sh ./model_folder in the CLI.

@D-Ogi
Copy link

D-Ogi commented Nov 3, 2022

Hi! Great script! I'm wondering how to reverse that procedure. I've got a nice merge of SD models that I'd like to alter with DreamBooth. Do you know how to reverse the conversion process? I guess it's not only about starting from the end; some associated files need to be generated as well.

@megaoutput
Copy link

I am getting this issue below

C:\Users\megao\Desktop\aiimage\ai_test\ai_worlds\older>python .\convert_diffusers_to_sd.py --model_path .\smilingfriendstyle --checkpoint_path .\smilingfriendstyle\smilingfriendstyle.ckpt
Traceback (most recent call last):
File "C:\Users\megao\Desktop\aiimage\ai_test\ai_worlds\older\convert_diffusers_to_sd.py", line 216, in <module>
unet_state_dict = torch.load(unet_path, map_location='cpu')
File "C:\Users\megao\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\serialization.py", line 699, in load
with _open_file_like(f, 'rb') as opened_file:
File "C:\Users\megao\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\serialization.py", line 230, in _open_file_like
return _open_file(name_or_buffer, mode)
File "C:\Users\megao\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\serialization.py", line 211, in __init__
super(_open_file, self).__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: '.\smilingfriendstyle\unet\diffusion_pytorch_model.bin'

C:\Users\megao\Desktop\aiimage\ai_test\ai_worlds\older>

@MiyatoKyo
Copy link

(d2sd) E:\convert_diffusers_to_sd\d2sd\Scripts>python convert_diffusers_to_sd.py --model_path diffusion_pytorch_model.bin --checkpoint_path artstation-diffusion.ckpt
Traceback (most recent call last):
File "E:\convert_diffusers_to_sd\d2sd\Scripts\convert_diffusers_to_sd.py", line 216, in <module>
unet_state_dict = torch.load(unet_path, map_location='cpu')
File "E:\convert_diffusers_to_sd\d2sd\lib\site-packages\torch\serialization.py", line 771, in load
with _open_file_like(f, 'rb') as opened_file:
File "E:\convert_diffusers_to_sd\d2sd\lib\site-packages\torch\serialization.py", line 270, in _open_file_like
return _open_file(name_or_buffer, mode)
File "E:\convert_diffusers_to_sd\d2sd\lib\site-packages\torch\serialization.py", line 251, in __init__
super(_open_file, self).__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: 'diffusion_pytorch_model.bin\unet\diffusion_pytorch_model.bin'

@0xsamgreen
Copy link

0xsamgreen commented Nov 23, 2022

Hi,

I'm trying to use this Aitrepreneur tutorial: https://www.youtube.com/watch?v=tgRiZzwSdXg

I downloaded this repo from the sd-dreambooth-library then I ran

python convert_diffusers_to_sd.py --model_path persona-5-shigenori-style --checkpoint_path persona-5-shigenori-style/shigenoridiffusion.ckpt

and I got the following error:

Traceback (most recent call last):
  File "/Users/sam/Documents/sd-experiments/convert_diffusers_to_sd.py", line 216, in <module>
    unet_state_dict = torch.load(unet_path, map_location='cpu')
  File "/Users/sam/opt/anaconda3/envs/sd-experiments/lib/python3.9/site-packages/torch/serialization.py", line 795, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/Users/sam/opt/anaconda3/envs/sd-experiments/lib/python3.9/site-packages/torch/serialization.py", line 1002, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, 'v'.

I've tried this script with three other styles downloaded from Hugging Face and I get the same error. Has anyone seen or fixed this?

@Miyuutsu
Copy link

So, for anyone who hasn't figured it out: the errors above are entirely unrelated to this script. Make sure you fully follow the instructions. You need to install git-lfs and run git lfs install, then clone the model you want again.

@janoschsimon
Copy link

Hey there, first, thanks for the amazing work :) I managed to get DreamBooth running on my trusty 2080 Ti, and the training works. I have results in the 400 and 800 folders, but when I try to load the ckpt, it doesn't appear in stable-diffusion-webui. I'm using this script for the conversion to ckpt: https://gist.github.com/Christopher-Hayes/636ba25e0ae2e7020722d5386ac2571b — any idea why it doesn't show up in stable-diffusion-webui?

Here is the code I run DreamBooth with. I can also upload the ckpt or the result folder to Google Drive if that helps.

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="training"
export CLASS_DIR="classes"
export OUTPUT_DIR="../../../output"

accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--revision="fp16" \
--with_prior_preservation --prior_loss_weight=1.0 \
--seed=3434554 \
--resolution=512 \
--train_batch_size=1 \
--train_text_encoder \
--mixed_precision="fp16" \
--use_8bit_adam \
--gradient_accumulation_steps=1 \
--learning_rate=1e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=50 \
--sample_batch_size=1 \
--max_train_steps=800 \
--save_interval=400 \
--save_sample_prompt="photo of smzkz" \
--concepts_list="concepts_list.json"

thx and have a nice day
Janosch

@tedliosu
Copy link

tedliosu commented Sep 3, 2023

Thank you so much for the work you've put into making this script, but unfortunately now the script no longer works because the weights under text_encoder, unet, and vae directories are saved using Huggingface's new safetensors file format instead of pytorch's traditional bin file format. Also stable diffusion frontends like AUTOMATIC1111's stable-diffusion-webui now expects a yaml config file for every custom model like the one here, so this script will have to also generate at least one additional file for every custom model generated using dreambooth as well if I'm not mistaken.

Please let us know when you may get around to updating this script so that it works with safetensors files too instead of just the traditional pytorch bin files.

@tedliosu
Copy link

tedliosu commented Oct 19, 2023

Thank you so much for the work you've put into making this script, but unfortunately now the script no longer works because the weights under text_encoder, unet, and vae directories are saved using Huggingface's new safetensors file format instead of pytorch's traditional bin file format. Also stable diffusion frontends like AUTOMATIC1111's stable-diffusion-webui now expects a yaml config file for every custom model like the one here, so this script will have to also generate at least one additional file for every custom model generated using dreambooth as well if I'm not mistaken.

Please let us know when you may get around to updating this script so that it works with safetensors files too instead of just the traditional pytorch bin files.

My utmost apologies; please see here and here, both of which show that the version of this script that works with safetensors files already exists.

@Yntec
Copy link

Yntec commented Jul 18, 2024

(EDIT - This post originally was asking for help, but the space that could successfully convert diffusers to safetensors was created here: https://huggingface.co/spaces/John6666/convert_repo_to_safetensors_sd )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment