Created
September 24, 2024 15:26
-
-
Save edicam/7d4974e81aa6970fa97ba0f17a2d2e3d to your computer and use it in GitHub Desktop.
Script to convert the .pt slider to a usable .safetensors slider
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import os.path
import re
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
from safetensors.torch import save_file
# Edit these placeholders, or pass the paths on the command line.
INPUT_PATH = 'path to your slider.pt'
OUTPUT_PATH = 'path to your output safentensors file'  # must be .safetensors

# Scale applied to every converted LoRA tensor (manual multiplier setting).
# Set to None to derive it from the checkpoint as alpha / rank instead.
MANUAL_MULTIPLIER = 0.125

# Exact keys probed for the network alpha and for the LoRA rank; the rank is
# read from dim 0 of the lora_down weight.
ALPHA_KEYS = [
    'lora_unet_single_transformer_blocks_0_attn_to_q.alpha',
]
RANK_IDX0_KEYS = [
    'lora_unet_single_transformer_blocks_0_attn_to_q.lora_down.weight',
]

# Matches the block index in flattened names such as "transformer_blocks_12_"
# (also inside "single_transformer_blocks_12_"), for any number of digits.
_BLOCK_INDEX_RE = re.compile(r'transformer_blocks_(\d+)_')

# (old, new) substring rewrites applied, in order, after the block index has
# been converted to dotted form.
_KEY_REWRITES = (
    ('lora_down', 'lora_A'),
    ('lora_up', 'lora_B'),
    ('_lora', '.lora'),
    ('attn_', 'attn.'),
    # Feed-forward modules are "ff" and "ff_context", each holding "net.<i>";
    # a generic "ff_" -> "ff." rule would corrupt "ff_context".
    ('ff_context_net_', 'ff_context.net.'),
    ('ff_net_', 'ff.net.'),
    ('0_proj', '0.proj'),
    ('norm1_context_linear', 'norm1_context.linear'),
    ('norm1_linear', 'norm1.linear'),
    ('norm_linear', 'norm.linear'),
    ('norm_out_linear', 'norm_out.linear'),
    ('to_out_', 'to_out.'),
)


def convert_key(key: str) -> str:
    """Map a flattened, underscore-separated LoRA key to a dotted key.

    The result uses the ``transformer.`` prefix and PEFT ``lora_A``/``lora_B``
    naming, e.g. ``lora_unet_single_transformer_blocks_0_attn_to_q.lora_down.weight``
    becomes ``transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight``.
    """
    new_key = key.replace('lora_unet_', 'transformer.')
    new_key = _BLOCK_INDEX_RE.sub(r'transformer_blocks.\1.', new_key)
    for old, new in _KEY_REWRITES:
        new_key = new_key.replace(old, new)
    return new_key


def find_rank_and_alpha(state_dict: dict) -> Tuple[int, float]:
    """Return ``(rank, alpha)`` read from *state_dict*.

    The rank is dim 0 of the first ``RANK_IDX0_KEYS`` entry present. Alpha is
    the first ``ALPHA_KEYS`` entry present, falling back to the rank.

    Raises:
        ValueError: if none of ``RANK_IDX0_KEYS`` is in *state_dict*.
    """
    rank = next(
        (int(state_dict[key].shape[0]) for key in RANK_IDX0_KEYS if key in state_dict),
        None,
    )
    if rank is None:
        raise ValueError('Could not find rank in state dict')
    # float(), not int(): int() truncated fractional alphas such as 0.5 to 0.
    alpha = next(
        (float(state_dict[key]) for key in ALPHA_KEYS if key in state_dict),
        None,
    )
    if alpha is None:
        alpha = float(rank)  # set to rank if not found
    return rank, alpha


def compute_multiplier(
    state_dict: dict, manual_multiplier: Optional[float] = MANUAL_MULTIPLIER
) -> float:
    """Return the weight multiplier: *manual_multiplier*, or alpha / rank if None.

    Raises:
        ValueError: if the rank key is missing. This check runs even when a
            manual multiplier is given, so it acts as a sanity check on the key layout.
    """
    rank, alpha = find_rank_and_alpha(state_dict)
    if manual_multiplier is not None:
        return manual_multiplier
    return alpha / rank


def convert_state_dict(state_dict: dict, multiplier: float) -> Dict[str, torch.Tensor]:
    """Return a new dict with renamed keys and tensors scaled by *multiplier*.

    ``.alpha`` entries are dropped. Each output tensor keeps its input dtype;
    the scaling itself is done in float32.
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.endswith('.alpha'):
            continue
        # NOTE(review): the multiplier is applied to lora_down AND lora_up, so
        # the effective delta (up @ down) is scaled by multiplier ** 2. Scaling
        # only lora_up by alpha / rank would reproduce the original LoRA scale;
        # kept as-is so existing conversions stay unchanged -- confirm intent.
        new_val = value.float() * multiplier
        new_state_dict[convert_key(key)] = new_val.to(value.dtype)
    return new_state_dict


def convert(input_path: str, output_path: str) -> None:
    """Convert the .pt slider at *input_path* into a .safetensors file.

    Raises:
        ValueError: if *output_path* does not end with ``.safetensors`` or the
            rank key is missing from the checkpoint.
    """
    if not output_path.endswith('.safetensors'):
        raise ValueError(f'Output path must end with .safetensors: {output_path!r}')
    # .pt files are pickles. weights_only=True refuses arbitrary objects (only
    # tensors and plain containers are needed here), and map_location='cpu'
    # lets CUDA-saved tensors load on machines without a GPU.
    state_dict = torch.load(input_path, map_location='cpu', weights_only=True)
    multiplier = compute_multiplier(state_dict)
    new_state_dict = convert_state_dict(state_dict, multiplier)
    meta = OrderedDict()
    meta['format'] = 'pt'  # conventional safetensors header entry for PyTorch
    save_file(new_state_dict, output_path, metadata=meta)


def main(argv: Optional[List[str]] = None) -> None:
    """Parse optional input/output paths (defaulting to the constants) and convert."""
    parser = argparse.ArgumentParser(
        description='Convert a .pt slider to a usable .safetensors slider.'
    )
    parser.add_argument('input_path', nargs='?', default=INPUT_PATH,
                        help='input .pt slider (default: INPUT_PATH)')
    parser.add_argument('output_path', nargs='?', default=OUTPUT_PATH,
                        help='output .safetensors file (default: OUTPUT_PATH)')
    args = parser.parse_args(argv)
    convert(args.input_path, args.output_path)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment