Skip to content

Instantly share code, notes, and snippets.

@wkpark
Forked from ProGamerGov/replace_vae.py
Last active June 22, 2023 16:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wkpark/f70bc55f23c479e302dc4a5ebd5ae1c2 to your computer and use it in GitHub Desktop.
Replace the VAE in a Stable Diffusion model with a new VAE. Tested on v1.4 & v1.5 SD models
#
# Script by https://github.com/ProGamerGov
#
# ChangeLog:
# - support safetensors, save float16 if needed, check filename 2023/06/22 by wkpark
#
import os
import sys
import copy
import torch
from pathlib import Path
from safetensors.torch import load_file, save_file
def load_model(path):
    """Load a checkpoint from *path* and return it as a flat state-dict.

    ``.safetensors`` files are read with safetensors' ``load_file``; any
    other file is treated as a torch checkpoint, unwrapping its
    ``"state_dict"`` entry when one is present.
    """
    if path.suffix != ".safetensors":
        checkpoint = torch.load(path, map_location="cpu")
        return checkpoint.get("state_dict", checkpoint)
    return load_file(path, device="cpu")
# --- Command-line handling and VAE file lookup -------------------------------
if len(sys.argv) == 1:
    print("Usage: replace_vae.py model_file vae_file")
    exit(1)

model_file_path = Path(sys.argv[1])

# The VAE argument is optional; default to the MSE-finetuned VAE filename.
if len(sys.argv) > 2:
    vae_file_path = Path(sys.argv[2])
else:
    vae_file_path = Path("vae-ft-mse-840000-ema-pruned.safetensors")

if not vae_file_path.exists():
    # Search the current directory and ../VAE for the default VAE,
    # trying .safetensors before .ckpt in each directory.
    default_vae = "vae-ft-mse-840000-ema-pruned"
    for search_dir in (".", "../VAE"):
        for ext in ("safetensors", "ckpt"):
            vae_file_path = Path(os.path.join(search_dir, default_vae + "." + ext))
            if vae_file_path.exists():
                break
        if vae_file_path.exists():
            print(f"- vae file {str(vae_file_path)} found!")
            break
    if not vae_file_path.exists():
        print(f"no default vae file {default_vae} found!")
        exit(1)
# Resolve the model path: a bare name with no extension may be given, in
# which case .safetensors and then .ckpt are tried.
if not model_file_path.exists() and model_file_path.suffix == "":
    model_file = sys.argv[1]
    for ext in ("safetensors", "ckpt"):
        model_file_path = Path(model_file + "." + ext)
        if model_file_path.exists():
            break
    if not model_file_path.exists():
        print(f"no model file {model_file} found!")
        exit(1)
# Fail early with a clear message; previously a missing path that *did*
# carry an extension fell through and crashed later inside load_model().
if not model_file_path.exists():
    print(f"no model file {model_file_path} found!")
    exit(1)
print(f"- vae file = {str(vae_file_path)}")
print(f"- model file = {str(model_file_path)}")
# Name to use for new model file: "<stem>-vae<suffix>" next to the original.
new_model_path = model_file_path.parent / (model_file_path.stem + "-vae" + model_file_path.suffix)

# Load both state-dicts.
vae_model = load_model(vae_file_path)
full_model = load_model(model_file_path)

# Detect whether the model is stored in half precision by probing a known
# text-encoder weight; fall back to the first floating-point tensor so
# models lacking that exact key (previously a KeyError) still work.
probe = full_model.get("cond_stage_model.transformer.text_model.embeddings.position_embedding.weight")
if probe is None:
    probe = next(
        (t for t in full_model.values() if isinstance(t, torch.Tensor) and t.is_floating_point()),
        None,
    )
half = probe is not None and probe.dtype != torch.float32

# Replace the VAE weights under "first_stage_model.", skipping keys whose
# first four characters are "loss"/"mode" (training-only tensors,
# presumably loss.* and model_ema.* entries — not needed for inference).
for key, value in vae_model.items():
    if key[0:4] in ("loss", "mode"):
        continue
    tensor = value.clone() if isinstance(value, torch.Tensor) else copy.deepcopy(value)
    # Convert to float16 only when the host model itself is half precision.
    if half and isinstance(tensor, torch.Tensor) and tensor.dtype == torch.float32:
        tensor = tensor.half()
    full_model["first_stage_model." + key] = tensor

# Save model with new VAE, in the same container format as the input.
if new_model_path.suffix == ".safetensors":
    save_file(full_model, str(new_model_path))
else:
    torch.save({"state_dict": full_model}, str(new_model_path))
print(f"new file {str(new_model_path)} saved!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment