Skip to content

Instantly share code, notes, and snippets.

@loboere
Created January 6, 2024 17:31
Show Gist options
  • Save loboere/1793f595c2a7369d28af65857fdaf532 to your computer and use it in GitHub Desktop.
Save loboere/1793f595c2a7369d28af65857fdaf532 to your computer and use it in GitHub Desktop.
interpolate rvc models
from collections import OrderedDict
import torch
def extract(ckpt):
    """Pull the parameter tensors out of a checkpoint that stores them under ``"model"``.

    Every entry of ``ckpt["model"]`` whose name contains ``"enc_q"`` is
    skipped (presumably the training-only posterior encoder — confirm); all
    other entries are kept in their original order.

    Args:
        ckpt: mapping with a ``"model"`` key holding a name -> tensor mapping.

    Returns:
        An ``OrderedDict`` of the form ``{"weight": {name: tensor, ...}}``.
    """
    source = ckpt["model"]
    kept = {name: tensor for name, tensor in source.items() if "enc_q" not in name}
    result = OrderedDict()
    result["weight"] = kept
    return result
# Example: blend two models 50/50 and write the result to "modelC.pth"
# interpolate_rvc_models("/models/modelA.pth",
#                        "/models/modelB.pth",
#                        0.5,
#                        "modelC")
def _load_rvc_weights(path):
    """Load an RVC checkpoint and return ``(raw_checkpoint, weights)``.

    ``weights`` is always a flat name -> tensor mapping. A checkpoint with a
    ``"model"`` key is passed through ``extract`` (which drops ``enc_q``
    entries); otherwise the checkpoint's ``"weight"`` entry is used directly.
    """
    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # from trusted sources.
    raw = torch.load(path, map_location="cpu")
    if "model" in raw:
        # extract() wraps its result as {"weight": {...}}; unwrap it so both
        # layouts yield the same flat mapping (the original code did not, and
        # crashed on this branch).
        return raw, extract(raw)["weight"]
    return raw, raw["weight"]


def interpolate_rvc_models(modelA, modelB, inter, out_name, *, sr=None, f0=None, info="100epoch"):
    """Linearly blend two RVC models and save the result to ``out_name + ".pth"``.

    Every output tensor is ``inter * B + (1 - inter) * A``, computed in
    float32 and stored as float16. ``inter=0`` reproduces modelA and
    ``inter=1`` reproduces modelB (up to the fp16 cast).

    Args:
        modelA: path to the first checkpoint.
        modelB: path to the second checkpoint. Its ``"config"`` is copied to
            the output, and so are its ``"sr"``/``"f0"``/``"version"``
            entries when present.
        inter: blend factor applied to modelB; modelA gets ``1 - inter``.
        out_name: output path without the ``.pth`` suffix.
        sr: sample-rate label to store. Defaults to modelB's ``"sr"`` entry,
            or ``"40k"`` if it has none.
        f0: f0 flag to store. Defaults to modelB's ``"f0"`` entry, or ``1``.
        info: free-form info string stored in the output.

    Returns:
        ``None`` on success, or an error-message string if the two models do
        not contain the same set of parameter names.

    Raises:
        KeyError: if modelB's checkpoint has no ``"config"`` entry.
    """
    raw_b, weights_b = _load_rvc_weights(modelB)
    _, weights_a = _load_rvc_weights(modelA)
    cfg = raw_b["config"]

    if set(weights_b) != set(weights_a):
        return "Fail to merge the models. The model architectures are not the same."

    blended = {}
    for key, tensor_b in weights_b.items():
        tensor_a = weights_a[key]
        if key == "emb_g.weight" and tensor_b.shape != tensor_a.shape:
            # emb_g (presumably the speaker-embedding table) may differ in row
            # count between models; blend only the rows both models have.
            rows = min(tensor_b.shape[0], tensor_a.shape[0])
            tensor_b, tensor_a = tensor_b[:rows], tensor_a[:rows]
        blended[key] = (inter * tensor_b.float() + (1 - inter) * tensor_a.float()).half()

    out = OrderedDict()
    out["weight"] = blended
    out["config"] = cfg
    # Take the metadata from the model whose config is reused, instead of
    # always labelling the result "40k" / f0=1.
    out["sr"] = sr if sr is not None else raw_b.get("sr", "40k")
    out["f0"] = f0 if f0 is not None else raw_b.get("f0", 1)
    out["info"] = info
    if "version" in raw_b:
        # NOTE(review): RVC loaders appear to read this key to choose the
        # model architecture (v1/v2) — confirm against the consuming code.
        out["version"] = raw_b["version"]
    torch.save(out, f"{out_name}.pth")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment