Forked from jachiam/convert_diffusers_to_sd.py. Last active February 12, 2023 19:08.
-
-
Save ArcturusForge/6d5684d2ec55c8e001e7b4ca7136e873 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.
# Written by jachiam
import argparse
import os.path as osp

import torch
# =================#
# UNet Conversion #
# =================#

# Complete parameter names that differ between the two layouts.
# Each entry is a (stable-diffusion name, HF Diffusers name) pair;
# convert_unet_state_dict looks entries up by the HF Diffusers name.
unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]
# Sub-module names inside a resnet block.
# Each entry is a (stable-diffusion part, HF Diffusers part) pair;
# convert_unet_state_dict applies them as substring replacements, in this
# order, to every key that contains "resnets".
unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]
# Key prefixes that locate a layer inside the UNet.
# Each entry is a (stable-diffusion prefix, HF Diffusers prefix) pair;
# convert_unet_state_dict applies them as substring replacements in list order.
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        # the upsampler sits at index 1 in output block 2 (which has no
        # attention entry, see above) and at index 2 in the later ones
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

# the middle block is laid out as resnet (0), attention (1), resnet (2)
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
def convert_unet_state_dict(unet_state_dict):
    """Rename the keys of a HF Diffusers UNet state dict to stable-diffusion names.

    Args:
        unet_state_dict: mapping from HF Diffusers parameter name to its value.

    Returns:
        A new dict holding the same value objects under the renamed keys.

    Raises:
        KeyError: if a HF Diffusers name listed in ``unet_conversion_map`` is
            missing from ``unet_state_dict`` (those names are added to the
            rename table unconditionally and looked up at the end).
    """
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}

    # 1) whole-name renames for the top-level layers
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name

    # 2) sub-module renames, only for keys that belong to a resnet
    #    (only values are reassigned, so iterating the dict here is safe)
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v

    # 3) block-prefix renames for every key
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v

    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict
# ================#
# VAE Conversion #
# ================#

# Key fragments that differ between the two VAE layouts.
# Each entry is a (stable-diffusion part, HF Diffusers part) pair;
# convert_vae_state_dict applies them as substring replacements in list order.
vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        # these two prefixes carry no encoder./decoder. qualifier; they are
        # matched as plain substrings of the key
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
# Sub-module names inside the VAE attention block.
# Each entry is a (stable-diffusion part, HF Diffusers part) pair;
# convert_vae_state_dict applies them as substring replacements, in this
# order, to keys whose original name contains "attentions".
vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]
def reshape_weight_for_sd(w):
    """Return ``w`` reshaped with two extra trailing dimensions of size 1.

    Used to convert HF linear weights to SD conv2d weights.
    """
    return w.reshape(*w.shape, 1, 1)
def convert_vae_state_dict(vae_state_dict):
    """Rename the keys of a HF Diffusers VAE state dict to stable-diffusion names.

    Args:
        vae_state_dict: mapping from HF Diffusers parameter name to its value.

    Returns:
        A new dict under the renamed keys. Values are passed through unchanged,
        except the mid-block attention ``q``/``k``/``v``/``proj_out`` weights,
        which go through ``reshape_weight_for_sd``. A line is printed for each
        reshaped weight.
    """
    mapping = {k: k for k in vae_state_dict.keys()}

    # 1) layout renames for every key
    #    (only values are reassigned, so iterating the dict here is safe)
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v

    # 2) attention sub-module renames, selected by the ORIGINAL key name
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v

    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}

    # 3) reshape the mid-block attention projection weights
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict
# =========================#
# Text Encoder Conversion #
# =========================#
# pretty much a no-op


def convert_text_enc_state_dict(text_enc_dict):
    """Return ``text_enc_dict`` itself, unchanged (no keys are renamed)."""
    return text_enc_dict
if __name__ == "__main__":
    # Command-line entry point: read the three sub-model weight files from a
    # Diffusers model folder, rename their keys, and save one combined file.
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")

    # Both paths are declared required=True, so argparse already exits with a
    # usage error when either is missing; no further None checks are needed.
    args = parser.parse_args()

    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

    # NOTE(security): torch.load unpickles the file it reads. Only convert
    # model folders that come from a source you trust.

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location='cpu')
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location='cpu')
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location='cpu')
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        # NOTE(review): .half() is applied to every entry, including any
        # non-floating-point tensors the files may contain - confirm intended.
        state_dict = {k: v.half() for k, v in state_dict.items()}
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.checkpoint_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
:Init
:: Converts a diffusers model folder that is already on disk.
:: Arguments: OUTPUT MODELPATH OUTPUTPATH
::   OUTPUT     - checkpoint path relative to the current directory, without ".ckpt"
::   MODELPATH  - model folder relative to the current directory
::   OUTPUTPATH - folder that is created when it does not exist yet
@echo off
:: %~N drops surrounding quotes, so callers may quote values that contain spaces.
set "OUTPUT=%~1"
set "MODELPATH=%~2"
set "OUTPUTPATH=%~3"
:: Arguments are positional: if the last one is present, so are the others.
if "%OUTPUTPATH%"=="" (
echo ERROR: expected arguments OUTPUT MODELPATH OUTPUTPATH
exit /b 1
)
if not exist "%OUTPUTPATH%" mkdir "%OUTPUTPATH%"
goto Install
:Install
:: Will ensure pytorch is installed.
:: "python -m pip" installs into the same interpreter that runs the conversion below.
@echo off
echo.
echo Installing Pytorch...
python -m pip install torch
goto Convert
:Convert
:: Applies mods to conversion command.
@echo off
echo.
echo Executing model conversion...
@echo on
python .\convert_diffusers_to_sd.py --model_path ".\%MODELPATH%" --checkpoint_path ".\%OUTPUT%.ckpt"
@echo off
echo.
pause
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
:Init
:: Clones a diffusers repo with git and converts the cloned folder.
:: Arguments: OUTPUT REPO PROMPT OUTPUTPATH
::   OUTPUT     - checkpoint path relative to the current directory, without ".ckpt"
::   REPO       - URL handed to "git clone"
::   PROMPT     - prompt key; used as the name of a .txt file that stores the repo URL
::   OUTPUTPATH - folder to clone into; created when it does not exist yet
@echo off
:: %~N drops surrounding quotes, so callers may quote values that contain spaces.
set "OUTPUT=%~1"
set "REPO=%~2"
set "PROMPT=%~3"
set "OUTPUTPATH=%~4"
:: Arguments are positional: if the last one is present, so are the others.
if "%OUTPUTPATH%"=="" (
echo ERROR: expected arguments OUTPUT REPO PROMPT OUTPUTPATH
exit /b 1
)
if not exist "%OUTPUTPATH%" mkdir "%OUTPUTPATH%"
echo %REPO% > "%OUTPUTPATH%\%PROMPT%.txt"
goto Download
:Fail
:: Will run if auto folder detection fails to find sub-folder.
@echo off
echo.
echo ERROR: No subfolder found
echo Please manually assign local folder locations.
:: NOTE(review): assumes Use_Converter.bat lives in the directory this script was started from.
start Use_Converter.bat
:: Stop here; otherwise execution would fall through into :Download and clone again.
exit /b 1
:Download
:: Will run the git clone command.
@echo off
echo.
echo Downloading repo...
:: pushd/popd always return to the starting directory, even when OUTPUTPATH
:: is nested several levels deep (a plain "cd .." only climbs one level).
pushd "%OUTPUTPATH%" || goto Fail
git clone "%REPO%"
:: Attempts to locate the newest created folder to use in conversion command.
:: The listing is sorted newest-first, so only the first entry is kept.
set "DIR="
for /F "delims=" %%i in ('dir /b /ad-h /t:c /o-d') do (
if not defined DIR set "DIR=%%i"
)
popd
if not defined DIR goto Fail
goto Install
:Install
:: Will ensure pytorch is installed.
:: "python -m pip" installs into the same interpreter that runs the conversion below.
@echo off
echo.
echo Installing Pytorch...
python -m pip install torch
goto Convert
:Convert
:: Applies mods to conversion command.
@echo off
echo.
echo Executing model conversion...
@echo on
python .\convert_diffusers_to_sd.py --model_path ".\%OUTPUTPATH%\%DIR%" --checkpoint_path ".\%OUTPUT%.ckpt"
@echo off
echo.
pause
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
:Init
:: Interactive front end: asks where to write the checkpoint and where the
:: model comes from, then hands off to ConvertFromRepo.bat or ConvertFromLocal.bat.
:: NOTE: folder and file names are passed on unquoted, so they must not contain spaces.
@echo off
echo.
echo Output from current directory for checkpoint file?
echo (repos will be downloaded there too)
echo Example: DiscoDiffusionStyle
set /p "OUTPUTPATH=Input: "
echo.
echo Checkpoint filename?
echo Example: DiscoDiffusionModel
set /p "MODNAME=Input: "
set "FINALOUTPUT=%OUTPUTPATH%\%MODNAME%"
goto Source
:Source
@echo off
echo.
echo Git download diffusers repo?
:: Clear first: set /p leaves the variable untouched when the user just presses
:: Enter, which would silently reuse the answer from a previous pass.
set "GIT="
set /p "GIT=(Y/N): "
:: Quoted, case-insensitive compare: an empty answer no longer breaks the "if".
if /i "%GIT%"=="Y" goto Repo
goto Local
:Repo
@echo off
echo.
echo Repo path
echo Example: https://huggingface.co/sd-dreambooth-library/disco-diffusion-style
set /p "REPO=Input: "
echo.
echo Model prompt key?
echo Example: ddfusion style
set /p "PROMPT=Input: "
:: The prompt key may contain spaces (see the example above), so it is passed
:: quoted to keep it a single argument.
call ConvertFromRepo.bat %FINALOUTPUT% %REPO% "%PROMPT%" %OUTPUTPATH%
goto Query
:Local
@echo off
echo.
echo Path to model from current directory?
echo Example: DiscoDiffusionStyle\disco-diffusion-style
set /p "MODELPATH=Input: "
call ConvertFromLocal.bat %FINALOUTPUT% %MODELPATH% %OUTPUTPATH%
goto Query
:Query
@echo off
echo.
echo Convert another model?
:: Cleared for the same reason as GIT above; just pressing Enter now means "no".
set "AGAIN="
set /p "AGAIN=(Y/N): "
if /i "%AGAIN%"=="Y" goto Init
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Find the compiled exe version of this here:
https://arcturusforge.itch.io/diff-to-ckpt