Skip to content

Instantly share code, notes, and snippets.

@opparco
Created September 14, 2022 13:48
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save opparco/62511cea8b185fc5238c62099a89f593 to your computer and use it in GitHub Desktop.
Save opparco/62511cea8b185fc5238c62099a89f593 to your computer and use it in GitHub Desktop.
Generate a merged model by linearly interpolating the weights of two models.
import argparse
import torch
def _blend_state_dicts(theta_0, theta_1, alpha):
    """Blend ``theta_1`` into ``theta_0`` in place and return ``theta_0``.

    Only keys containing the substring ``'model'`` are considered:

    * key present in both and both values are floating-point tensors:
      ``theta_0[key] = (1 - alpha) * theta_0[key] + alpha * theta_1[key]``
    * key present in both but a value is not a floating-point tensor:
      ``theta_0[key]`` is kept as is (see the comment below)
    * key present only in ``theta_1``: the value is copied into ``theta_0``

    Keys that do not contain ``'model'`` are left untouched in ``theta_0``
    and are not copied from ``theta_1``.
    """
    for key, value_1 in theta_1.items():
        if 'model' not in key:
            continue
        if key not in theta_0:
            theta_0[key] = value_1
            continue
        value_0 = theta_0[key]
        # Multiplying an integer tensor by a Python float promotes it to a
        # float tensor, which would silently change the dtype (and possibly
        # the values) of integer entries. Interpolate only real weights.
        both_float = (
            torch.is_tensor(value_0)
            and torch.is_tensor(value_1)
            and value_0.is_floating_point()
            and value_1.is_floating_point()
        )
        if both_float:
            theta_0[key] = (1 - alpha) * value_0 + alpha * value_1
    return theta_0


def main():
    """Merge two checkpoints by linear interpolation and save the result.

    Command-line options:
        --ckpt0    path to the checkpoint of model 0
        --ckpt1    path to the checkpoint of model 1
        --alpha    blend factor: (1 - alpha) * theta_0 + alpha * theta_1
        --outpath  path the merged checkpoint is written to

    Both checkpoints must contain a ``'state_dict'`` entry (``KeyError``
    otherwise). The merged state dict replaces the one inside checkpoint 0,
    and that whole checkpoint object is what gets saved.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt0",
        type=str,
        default="sd-v1-4.ckpt",
        help="path to checkpoint of model 0",
    )
    parser.add_argument(
        "--ckpt1",
        type=str,
        default="wd-v1-2-full-ema.ckpt",
        help="path to checkpoint of model 1",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.5,
        help="alpha: (1-alpha) * theta_0 + alpha * theta_1",
    )
    parser.add_argument(
        "--outpath",
        type=str,
        default="tempered-waifu.ckpt",
        help="path to checkpoint of tempered model",
    )
    opt = parser.parse_args()

    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    # map_location="cpu" lets checkpoints that were saved from a GPU be merged
    # on a machine without one.
    model_0 = torch.load(opt.ckpt0, map_location="cpu")
    model_1 = torch.load(opt.ckpt1, map_location="cpu")

    outpath = opt.outpath
    # theta_0 is the dict stored inside model_0, so blending in place updates
    # the object that is saved below.
    _blend_state_dicts(model_0['state_dict'], model_1['state_dict'], opt.alpha)

    print(f"Save checkpoint of tempered model: \n{outpath} \n")
    torch.save(model_0, outpath)
# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment