Skip to content

Instantly share code, notes, and snippets.

@ntrpnr
Last active October 27, 2022 08:49
Show Gist options
  • Save ntrpnr/f5cc57090d5be82284b5a72db7807136 to your computer and use it in GitHub Desktop.
Save ntrpnr/f5cc57090d5be82284b5a72db7807136 to your computer and use it in GitHub Desktop.
Merge an inpainting model (theta_0) with a non-inpainting model (theta_1) in the AUTOMATIC1111 Stable Diffusion web UI
# The inpainting model must be selected as model A (the primary model); its checkpoint is the one that gets saved.
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
    """Merge two (or three) checkpoints and save the result as a new checkpoint.

    Model A (``primary_model_name``) may be an inpainting model while model B is a
    regular one. In that case the first UNet input convolution
    (``model.diffusion_model.input_blocks.0.0.weight``) has more input channels in
    A than in B. Only the leading latent channels shared by both are merged; the
    extra channels of A are kept unchanged.

    Args:
        primary_model_name: Key into ``sd_models.checkpoints_list`` for model A
            (theta_0). A's loaded checkpoint object is what gets saved.
        secondary_model_name: Key for model B (theta_1).
        teritary_model_name: Key for model C (theta_2). Required by
            "Add difference" and ignored by "Weighted sum". A name not found in
            ``checkpoints_list`` is treated as "no tertiary model".
        interp_method: "Weighted sum" computes (1 - m) * A + m * B.
            "Add difference" computes A + m * (B - C).
        multiplier: Interpolation factor ``m``.
        save_as_half: When truthy, floating-point tensors of the merged state
            dict are converted to float16 before saving.
        custom_name: Output file stem. An empty string selects an auto-generated name.

    Returns:
        A status message followed by four ``gr.Dropdown.update`` objects whose
        choices are the refreshed checkpoint titles.

    Raises:
        KeyError: Unknown ``interp_method``, or unknown primary/secondary model name.
        ValueError: "Add difference" was requested without a tertiary model.
    """
    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    def get_difference(theta1, theta2):
        return theta1 - theta2

    def add_difference(theta0, theta1_2_diff, alpha):
        return theta0 + (alpha * theta1_2_diff)

    def load_checkpoint(model_info):
        """Load a checkpoint on CPU; return (checkpoint object, its state dict)."""
        print(f"Loading {model_info.filename}...")
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only merge checkpoints from trusted sources.
        checkpoint = torch.load(model_info.filename, map_location='cpu')
        return checkpoint, sd_models.get_state_dict_from_checkpoint(checkpoint)

    theta_funcs = {
        "Weighted sum": (None, weighted_sum),
        "Add difference": (get_difference, add_difference),
    }
    theta_func1, theta_func2 = theta_funcs[interp_method]

    primary_model_info = sd_models.checkpoints_list[primary_model_name]
    secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
    tertiary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)

    # Fail before loading multi-GB files if the requested method cannot run.
    if theta_func1 is not None and tertiary_model_info is None:
        raise ValueError(f"{interp_method} requires a tertiary model (model C).")

    primary_model, theta_0 = load_checkpoint(primary_model_info)
    _, theta_1 = load_checkpoint(secondary_model_info)

    if theta_func1 is not None:
        # Turn theta_1 into the difference (B - C) so that theta_func2 can add it onto A.
        _, theta_2 = load_checkpoint(tertiary_model_info)
        print("Computing difference...")
        for key in tqdm.tqdm(theta_1.keys()):
            if key in theta_2 and theta_1[key].shape == theta_2[key].shape:
                theta_1[key] = theta_func1(theta_1[key], theta_2[key])
            else:
                # No comparable C weight: use a zero difference so A is left unchanged here.
                print(f"Unable to compute difference for {key}; leaving it unchanged.")
                theta_1[key] = torch.zeros_like(theta_1[key])
        del theta_2

    print("Merging...")
    for key in tqdm.tqdm(theta_0.keys()):
        if key not in theta_1:
            print(f"key {key} does not exist in theta_1")
            continue
        a, b = theta_0[key], theta_1[key]
        if a.shape == b.shape:
            theta_0[key] = theta_func2(a, b, multiplier)
            print(f"merged {key}")
        elif key == "model.diffusion_model.input_blocks.0.0.weight":
            # The first 4 input channels are the latent shared with regular SD
            # models. Inpainting models append extra conditioning channels, which
            # only A has; those are kept as-is. Slice 0:4 (not 0:3) so that all
            # four latent channels are merged.
            theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b[:, 0:4, :, :], multiplier)
            print(f"merged {key} with different shapes")
        else:
            print(f"Unable to merge {key}. Different shapes.")
    print()

    if save_as_half:
        for key in theta_0.keys():
            value = theta_0[key]
            if isinstance(value, torch.Tensor) and value.is_floating_point():
                theta_0[key] = value.half()

    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
    if custom_name == '':
        filename = (
            f"{primary_model_info.model_name}_{round(1 - multiplier, 2)}-"
            f"{secondary_model_info.model_name}_{round(multiplier, 2)}-"
            f"{interp_method.replace(' ', '_')}-merged.ckpt"
        )
    else:
        filename = custom_name + '.ckpt'
    output_modelname = os.path.join(ckpt_dir, filename)

    print(f"Saving to {output_modelname}...")
    # NOTE(review): saving primary_model (not theta_0) relies on
    # get_state_dict_from_checkpoint returning a dict that lives inside (or is)
    # primary_model, so the edits to theta_0 above are persisted. Confirm.
    torch.save(primary_model, output_modelname)
    sd_models.list_models()
    print("Checkpoint saved.")
    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment