@sftblw
Last active January 25, 2023 07:34
fix convert_original_stable_diffusion_to_diffusers.py for fixing merged models

To fix a broken model (mainly weights merged with AUTOMATIC1111's unlicensed UI) for use in InvokeAI, one can take the CKPT -> HF/diffusers -> CKPT route.

Environment setup

OmegaConf, PyTorch Lightning, and the usual Hugging Face diffusers environment.

Get scripts

I used the scripts at commit cd91fc06fe9513864fca6a57953ca85a7ae7836e.

why it fails

The conversion fails on the first attempt.

why it fails: main checkpoint

While debugging, I found that a loaded checkpoint (from torch.load(CKPT_PATH)) is essentially a Python dictionary of torch.Tensor values. (I guess the file is a Python pickle, i.e. a serialized Python object.)

So the main reason is that the merged checkpoint has no 'state_dict' key.

# My merged weight looks like this:
merged_weight = {
  'model.blahblah.key': torch.tensor(...)
}

# Another normal working file is like this:
normal_weight = {
  'state_dict': {
    'model.blahblah.key': torch.tensor(...)
  }
}
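One way to tell the two layouts apart without reading every key is to count the top-level key prefixes. The sketch below works on any dictionary; `prefix_counts` is a hypothetical helper name, and the example keys are made up to stand in for a real torch.load(...) result.

```python
from collections import Counter


def prefix_counts(checkpoint):
    """Count dictionary keys by the part before the first dot.

    `prefix_counts` is a hypothetical helper, not part of any library.
    """
    return Counter(str(key).split(".", 1)[0] for key in checkpoint)


# Made-up keys standing in for what torch.load(CKPT_PATH) might return.
merged_weight = {
    "model.a.weight": 0,
    "model.b.weight": 0,
    "first_stage_model.c.weight": 0,
}
normal_weight = {"state_dict": merged_weight, "global_step": 1}

print(prefix_counts(merged_weight))
print(prefix_counts(normal_weight))
```

A bare (merged) checkpoint shows weight prefixes such as `model` at the top level, while a wrapped one shows `state_dict` itself as a key.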

Essentially, HF's script discards everything other than 'state_dict', so we can easily adjust it to use the dictionary we were given.

# HF's script
global_step = checkpoint["global_step"]
checkpoint = checkpoint["state_dict"]
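The adjustment can be sketched as a small helper that accepts either layout. This is only an illustration on plain dictionaries: `normalize_checkpoint` and `default_global_step` are hypothetical names and are not part of HF's script.

```python
def normalize_checkpoint(raw, default_global_step=None):
    """Return (global_step, state_dict) for either checkpoint layout.

    `normalize_checkpoint` and `default_global_step` are hypothetical
    names used for illustration only.
    """
    if "state_dict" in raw:
        # Normal layout: the weights are nested under 'state_dict'.
        return raw.get("global_step", default_global_step), raw["state_dict"]
    # Merged layout: the top-level dictionary is the state dict itself.
    return default_global_step, raw


# Placeholder values instead of real tensors.
merged = {"model.blahblah.key": 1.0}
normal = {"state_dict": {"model.blahblah.key": 1.0}, "global_step": 5}

print(normalize_checkpoint(merged))
print(normalize_checkpoint(normal))
```

Returning a tuple instead of mutating the input keeps the original dictionary intact, which makes it easier to inspect when something goes wrong.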

why it fails: VAE

The VAE lives under the "first_stage_model." prefix of the main ckpt, according to def convert_ldm_vae_checkpoint(checkpoint, config): in HF's script.

Some models ship an external VAE checkpoint, but the keys in those ckpts do not start with "first_stage_model.". If you do not provide a VAE, the conversion uses whatever VAE is already in the main ckpt. Using the right VAE is critical for some models, especially ones similar to NovelAI's.

An external VAE ckpt can also be loaded with torch.load(VAE_CKPT_PATH).
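The key renaming described above can be tried on plain dictionaries first. The following is a minimal sketch, assuming the external VAE keys only lack the "first_stage_model." prefix; `merge_external_vae` is a hypothetical helper name, and the overwrite counter is an extra added here as a sanity check.

```python
VAE_PREFIX = "first_stage_model."


def merge_external_vae(state_dict, vae_state_dict, prefix=VAE_PREFIX):
    """Return a copy of state_dict with the external VAE weights written in,
    plus the number of existing keys that were overwritten.

    `merge_external_vae` is a hypothetical helper for illustration.
    """
    merged = dict(state_dict)  # shallow copy; the input is left untouched
    overwritten = 0
    for key, value in vae_state_dict.items():
        target = prefix + key
        if target in merged:
            overwritten += 1
        merged[target] = value
    return merged, overwritten


# Placeholder values instead of real tensors.
main_ckpt = {"first_stage_model.encoder.w": "old", "model.x": 1}
vae_ckpt = {"encoder.w": "new", "decoder.w": "new"}

merged_ckpt, count = merge_external_vae(main_ckpt, vae_ckpt)
print(count, sorted(merged_ckpt))
```

If the overwritten count comes back as zero for a main checkpoint that already contains a VAE, the external file's keys probably do not line up with the expected names and are worth inspecting.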

modifying convert_original_stable_diffusion_to_diffusers.py

So inserting two things fixes the script.

argument add

    parser.add_argument(
        "--vae_checkpoint_path", default=None, type=str, required=False, help="Path to an external VAE checkpoint to merge into the main checkpoint."
    )
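Because the option defaults to None, any code that uses the VAE checkpoint has to be guarded by a check on it. The stand-alone parser below is a sketch to confirm that behaviour; it is not the full parser from the script, and the paths are placeholders.

```python
import argparse

# Stand-alone parser with only the two options relevant here.
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", default=None, type=str, required=True)
parser.add_argument("--vae_checkpoint_path", default=None, type=str, required=False)

# Parse explicit argument lists instead of sys.argv (placeholder paths).
without_vae = parser.parse_args(["--checkpoint_path", "model.ckpt"])
with_vae = parser.parse_args(
    ["--checkpoint_path", "model.ckpt", "--vae_checkpoint_path", "vae.ckpt"]
)

print(without_vae.vae_checkpoint_path is None)
print(with_vae.vae_checkpoint_path)
```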

after loading checkpoint

    checkpoint = torch.load(args.checkpoint_path) # add below this
    
# +++++++++++++++++++++
    # wrap a bare state dict so the rest of the script finds 'state_dict'
    if 'state_dict' not in checkpoint.keys():
        checkpoint = {
            'state_dict': checkpoint
        }
    
    # the script looks for the `global_step` for discriminating between SD v1.x and v2.x
    if 'global_step' not in checkpoint.keys():
        checkpoint['global_step'] = 678999

    global_step = checkpoint["global_step"]
    checkpoint = checkpoint["state_dict"]

    if args.vae_checkpoint_path is not None:
        vae_checkpoint = torch.load(args.vae_checkpoint_path)

        if 'state_dict' in vae_checkpoint.keys():
            vae_checkpoint = vae_checkpoint['state_dict']

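        # overwrite the VAE weights in the main ckpt with the external ones
        # (this loop stays inside the `if`: vae_checkpoint only exists when a VAE path is given)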
        for k, v in vae_checkpoint.items():
            vae_key = "first_stage_model." + k
            checkpoint[vae_key] = v
        vae_checkpoint = None
# +++++++++++++++++++++

HF diffusers -> original SD checkpoint

No changes were needed, at least for me.

executing

For reference, my execution was like this:

# CKPT -> HF
python scripts/convert_original_stable_diffusion_to_diffusers.py \
  --checkpoint_path ./models_ckpt/mymodel.ckpt \
  --vae_checkpoint_path ./models_ckpt/mymodel_vae.ckpt \
  --dump_path ./models_hf/hf_model_mymodel \
  --scheduler_type euler-ancestral

# HF -> CKPT
python scripts/convert_diffusers_to_original_stable_diffusion.py \
  --model_path ./models_hf/hf_model_mymodel \
  --checkpoint_path ./models_ckpt/mymodel_modified.ckpt

etc

@ItCameFr0mMars

Traceback (most recent call last):
  File "C:\Users\username\Desktop\stable diff\convert_original_stable_diffusion_to_diffusers.py", line 824, in <module>
    for k, v in vae_checkpoint.items():
NameError: name 'vae_checkpoint' is not defined. Did you mean: 'checkpoint'?

Followed your instructions exactly, ended up with this error.
here is my file https://0bin.net/paste/TytzkTh6#HhVsJSgeHmL89969wCbwet8Lv6kyNw69AeB7axUZUUz

@sftblw
Author

sftblw commented Jan 24, 2023

    # the script looks for the `global_step` for discriminating between SD v1.x and v2.x
    if 'global_step' not in checkpoint.keys():
        checkpoint['global_step'] = 678999
 
>  # checkpoint is already ripped out here
    global_step = checkpoint["global_step"]
    checkpoint = checkpoint["state_dict"]
 
    if args.vae_checkpoint_path is not None:
        vae_checkpoint = torch.load(args.vae_checkpoint_path)
 
        if 'state_dict' in vae_checkpoint.keys():
            vae_checkpoint = vae_checkpoint['state_dict']
 
    # overwrite ckpt
<    for k, v in vae_checkpoint.items():
<        vae_key = "first_stage_model." + k
<        checkpoint[vae_key] = v
<    vae_checkpoint = None
>        # vae_checkpoint only exists when above line ( torch.load ) runs,
>        # which means additional ckpt is provided for it
>        # When I modified this part, I didn't test without VAE argument.
>        for k, v in vae_checkpoint.items():
>            vae_key = "first_stage_model." + k
>            checkpoint[vae_key] = v
>        vae_checkpoint = None
# +++++++++++++++++++++
<    global_step = checkpoint["global_step"]
<    checkpoint = checkpoint["state_dict"]

@ItCameFr0mMars

ItCameFr0mMars commented Jan 24, 2023

I tried that and got this error:

(virtualenv) C:\Users\username\Desktop\stable diff>python convert_original_stable_diffusion_to_diffusers.py --checkpoint_path="./aeros.ckpt" --dump_path="./aeros" --original_config_file=v1-inference.yaml
Traceback (most recent call last):
  File "C:\Users\username\Desktop\stable diff\convert_original_stable_diffusion_to_diffusers.py", line 910, in <module>
    unet.load_state_dict(converted_unet_checkpoint)
  File "C:\Users\username\Desktop\stable diff\virtualenv\lib\site-packages\torch\nn\modules\module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel:
        Missing key(s) in state_dict: "up_blocks.0.upsamplers.0.conv.weight", "up_blocks.0.upsamplers.0.conv.bias", "up_blocks.1.upsamplers.0.conv.weight", "up_blocks.1.upsamplers.0.conv.bias", "up_blocks.2.upsamplers.0.conv.weight", "up_blocks.2.upsamplers.0.conv.bias".
        Unexpected key(s) in state_dict: "up_blocks.0.attentions.2.conv.bias", "up_blocks.0.attentions.2.conv.weight".

@sftblw
Author

sftblw commented Jan 25, 2023

Maybe you are using an SD 2.x model? I don't know the structure of SD in detail, but the error message says there are some unexpected and missing weight matrices, especially in the UNet upsampling blocks, I guess. This script does not update the old one, and I didn't check whether it works with SD 2.x models.

@ItCameFr0mMars

@sftblw
Author

sftblw commented Jan 25, 2023

For now I'm not interested in a detailed investigation. The description of that model indicates that it is based on SD v1.5; I mainly tested with Waifu Diffusion v1.3, which is based on SD v1.4. Differently named weight matrices could be mapped to the correct ones with proper knowledge of the detailed model structure of the two libraries, which I don't have.

Maybe an updated version of the original script (before modularization) would be helpful?
