Skip to content

Instantly share code, notes, and snippets.

@edicam
Created September 24, 2024 15:26
Show Gist options
  • Save edicam/7d4974e81aa6970fa97ba0f17a2d2e3d to your computer and use it in GitHub Desktop.
Save edicam/7d4974e81aa6970fa97ba0f17a2d2e3d to your computer and use it in GitHub Desktop.
Script to convert the .pt slider to a usable .safetensors slider
import os.path
from collections import OrderedDict
from safetensors.torch import save_file
import torch
state_dict=torch.load("path to your slider.pt")
alpha_keys = [
'lora_unet_single_transformer_blocks_0_attn_to_q.alpha'
]
rank_idx0_keys = [
'lora_unet_single_transformer_blocks_0_attn_to_q.lora_down.weight'
]
alpha = None
rank = None
for key in rank_idx0_keys:
if key in state_dict:
rank = int(state_dict[key].shape[0])
break
if rank is None:
raise ValueError(f'Could not find rank in state dict')
for key in alpha_keys:
if key in state_dict:
alpha = int(state_dict[key])
break
if alpha is None:
# set to rank if not found
alpha = rank
#up_multiplier = alpha / rank
up_multiplier=0.125 # manual multiplier setting
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith('.alpha'):
continue
orig_dtype = value.dtype
new_val = value.float() * up_multiplier
new_key = key
new_key = new_key.replace('lora_unet_', 'transformer.')
for i in range(100):
new_key = new_key.replace(f'transformer_blocks_{i}_', f'transformer_blocks.{i}.')
new_key = new_key.replace('lora_down', 'lora_A')
new_key = new_key.replace('lora_up', 'lora_B')
new_key = new_key.replace('_lora', '.lora')
new_key = new_key.replace('attn_', 'attn.')
new_key = new_key.replace('ff_', 'ff.')
new_key = new_key.replace('context_net_', 'context.net.')
new_key = new_key.replace('0_proj', '0.proj')
new_key = new_key.replace('norm_linear', 'norm.linear')
new_key = new_key.replace('norm_out_linear', 'norm_out.linear')
new_key = new_key.replace('to_out_', 'to_out.')
new_state_dict[new_key] = new_val.to(orig_dtype)
meta = OrderedDict()
meta['format'] = 'pt'
save_file(new_state_dict,"path to your output safentensors file") # must be .safetensors
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment