Skip to content

Instantly share code, notes, and snippets.

@saftle
Forked from ProGamerGov/replace_vae.py
Last active January 21, 2023 10:00
Show Gist options
  • Save saftle/c5e222c6231e7b19f01bb93ac9fcc191 to your computer and use it in GitHub Desktop.
Save saftle/c5e222c6231e7b19f01bb93ac9fcc191 to your computer and use it in GitHub Desktop.
Fixed the script to work with models that were merged with Automatic1111, and added the numpy dependency that the script uses.
# Script by https://github.com/ProGamerGov
import copy
import torch
import numpy as np
# Paths to the VAE and model checkpoint files that you want to merge
vae_file_path = "vae-ft-mse-840000-ema-pruned.ckpt"
model_file_path = "v1-5-pruned-emaonly.ckpt"
# Name to use for the new model file
new_model_name = "v1-5-pruned-emaonly_ema_vae.ckpt"

# Prefix under which the model checkpoint stores its VAE weights.
VAE_KEY_PREFIX = "first_stage_model."
# VAE entries whose keys start with any of these prefixes are not copied.
# NOTE(review): presumably meant to drop the VAE's training-loss weights
# ("loss.*") and EMA copies ("model_ema.*"); confirm before narrowing.
SKIPPED_VAE_KEY_PREFIXES = ("loss", "mode")


def unwrap_state_dict(checkpoint):
    """Return the weights mapping stored in *checkpoint*.

    Some checkpoints nest their weights under a ``"state_dict"`` key, while
    models merged with Automatic1111 keep them at the top level. The result
    is the same object held by *checkpoint* (not a copy), so mutating it
    also updates *checkpoint*.
    """
    return checkpoint.get("state_dict", checkpoint)


def replace_vae(model_state, vae_state):
    """Overwrite the VAE weights in *model_state* with those in *vae_state*.

    Every entry ``k`` of *vae_state* whose key does not start with one of
    ``SKIPPED_VAE_KEY_PREFIXES`` is deep-copied into *model_state* under
    ``VAE_KEY_PREFIX + k``. *model_state* is modified in place.

    Returns:
        The number of entries copied.
    """
    copied = 0
    for key, value in vae_state.items():
        if key.startswith(SKIPPED_VAE_KEY_PREFIXES):
            continue
        model_state[VAE_KEY_PREFIX + key] = copy.deepcopy(value)
        copied += 1
    return copied


def main(vae_path=vae_file_path, model_path=model_file_path,
         output_path=new_model_name):
    """Load both checkpoints, swap in the new VAE and save the result.

    The whole model checkpoint is saved, not just its weights mapping. A
    ``"state_dict"`` wrapper and any other top-level entries in the input
    file are therefore kept in the output file instead of being dropped.
    """
    # SECURITY: torch.load unpickles the file, which can run arbitrary code.
    # Only use checkpoints from sources you trust.
    vae_checkpoint = torch.load(vae_path, map_location="cpu")
    model_checkpoint = torch.load(model_path, map_location="cpu")

    # unwrap_state_dict returns the nested dict itself, so replace_vae
    # updates model_checkpoint in place.
    copied = replace_vae(unwrap_state_dict(model_checkpoint),
                         unwrap_state_dict(vae_checkpoint))

    torch.save(model_checkpoint, output_path)
    print(f"Copied {copied} VAE entries into {output_path}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment