@Quasimondo
Created October 22, 2022 15:23
Quick script to merge the fine-tuned StabilityAI autoencoder into the RunwayML Stable Diffusion 1.5 checkpoint
```python
import torch

# USE AT YOUR OWN RISK

# Local path to the RunwayML SD 1.5 checkpoint (https://huggingface.co/runwayml/stable-diffusion-v1-5)
ckpt_15 = "./v1-5-pruned-emaonly.ckpt"
# Local path to the StabilityAI fine-tuned autoencoder (https://huggingface.co/stabilityai/sd-vae-ft-mse)
ckpt_vae = "./vae-ft-mse-840000-ema-pruned.ckpt"
# Path to save the merged model to
ckpt_out = "./v1-5-pruned-emaonly_new_vae.ckpt"

# Load both checkpoints onto the CPU; the weights live under the "state_dict" key
pl_sd = torch.load(ckpt_15, map_location="cpu")
sd = pl_sd["state_dict"]
over_sd = torch.load(ckpt_vae, map_location="cpu")["state_dict"]

# Inside the full SD checkpoint, the autoencoder weights carry the "first_stage_model." prefix
sdk = sd.keys()
for key in over_sd.keys():
    if "first_stage_model." + key in sdk:
        sd["first_stage_model." + key] = over_sd[key]
        print(key, "overwritten")

# pl_sd still holds a reference to the modified sd, so saving it writes the merged model
torch.save(pl_sd, ckpt_out)
```
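The merge loop above can also be sketched as a reusable function that reports what it did, not just what it overwrote. This is a minimal sketch, not part of the original gist: the name `merge_vae_weights`, its return values, and the shape check are hypothetical additions. It works on any dict-like state dict (such as the ones returned by `torch.load`), but the toy example uses placeholder strings and made-up keys so it runs without PyTorch.

```python
def merge_vae_weights(model_sd, vae_sd, prefix="first_stage_model."):
    """Copy autoencoder weights into a full Stable Diffusion state dict, in place.

    Returns three lists of VAE keys: overwritten, missing (no matching
    prefixed key in model_sd) and mismatched (shapes differ, left untouched).
    """
    overwritten, missing, mismatched = [], [], []
    for key, value in vae_sd.items():
        target = prefix + key
        if target not in model_sd:
            missing.append(key)
            continue
        # Extra safety check, not in the original script: compare shapes when
        # the values have them (torch tensors do; plain placeholders do not).
        old_shape = getattr(model_sd[target], "shape", None)
        new_shape = getattr(value, "shape", None)
        if old_shape != new_shape:
            mismatched.append(key)
            continue
        model_sd[target] = value
        overwritten.append(key)
    return overwritten, missing, mismatched


# Toy example: hypothetical keys, placeholder strings instead of tensors
model_sd = {
    "first_stage_model.decoder.conv_out.weight": "old VAE weight",
    "model.diffusion_model.some.weight": "UNet weight",
}
vae_sd = {"decoder.conv_out.weight": "new VAE weight", "unused.key": "extra"}
overwritten, missing, mismatched = merge_vae_weights(model_sd, vae_sd)
print("overwritten:", overwritten)
print("missing:", missing)
print("mismatched:", mismatched)
```

With the script's variables, the call would be `merge_vae_weights(sd, over_sd)` before `torch.save(pl_sd, ckpt_out)`. An empty `overwritten` list is a hint that the VAE file does not use the key names the script expects.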
@zhuofengli

I got the following error:
KeyError: 'state_dict'

Any idea how to fix that?

@Quasimondo
Author

Hard to tell. I suspect you are either using the wrong checkpoint, or the checkpoints at the download locations have changed.
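The `KeyError` comes from the script's assumption that both files keep their weights under a top-level `"state_dict"` key (as PyTorch Lightning checkpoints do). A minimal diagnostic sketch, assuming the loaded file is a dict, is to check for that key and report the top-level keys when it is missing, so you can see what kind of file you actually downloaded. The helper name `unwrap_state_dict` and its `allow_bare` flag are hypothetical; `allow_bare=True` assumes the file is already a bare `{parameter name: tensor}` dict, which is worth confirming from the reported keys before relying on it.

```python
from collections.abc import Mapping


def unwrap_state_dict(ckpt, allow_bare=False):
    """Return the weights mapping from an object returned by torch.load."""
    if not isinstance(ckpt, Mapping):
        raise TypeError(f"expected a dict-like checkpoint, got {type(ckpt).__name__}")
    if "state_dict" in ckpt:
        return ckpt["state_dict"]
    if allow_bare:
        # Assumption: the file already is a flat {parameter name: tensor} dict
        return ckpt
    # Show a few top-level keys to help tell which kind of file was loaded
    preview = ", ".join(sorted(str(k) for k in ckpt)[:5])
    raise KeyError(f"no 'state_dict' entry; top-level keys include: {preview}")
```

In the script, `sd = pl_sd["state_dict"]` would become `sd = unwrap_state_dict(pl_sd)`, and the VAE line would wrap its `torch.load(...)` call the same way; the error message then shows the first few top-level keys instead of only `'state_dict'`.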
