Skip to content

Instantly share code, notes, and snippets.

@akii-i
Created April 18, 2023 16:30
Show Gist options
  • Save akii-i/67ad27838890237aa52f2fa509405af1 to your computer and use it in GitHub Desktop.
Save akii-i/67ad27838890237aa52f2fa509405af1 to your computer and use it in GitHub Desktop.
from safetensors import safe_open
from safetensors.torch import save_file
def layer_name(i):
    """Map a tensor key to a coarse block label.

    The label is one of ``'text'``, ``'down_<n>'``, ``'up_<n>'`` or
    ``'mid_1'`` and is what the caller compares against its list of
    layers to keep.

    Args:
        i: Tensor key (a string), e.g.
            ``'lora_unet_down_blocks_1_attentions_0_proj_in.alpha'``.

    Returns:
        * ``'text'`` for keys under ``lora_te_text_model_encoder_layers_``.
        * ``'down_<1 + block * 3 + sub>'`` for ``lora_unet_down_blocks_`` keys.
        * ``'up_<block * 3 + sub>'`` for ``lora_unet_up_blocks_`` keys.
        * ``'mid_1'`` for every other key.

    Raises:
        ValueError: A down/up key has a non-integer token where the block
            or sub-block index is expected.
        IndexError: A down/up key has fewer than three ``_``-separated
            tokens after the prefix.
    """
    text_prefix = 'lora_te_text_model_encoder_layers_'
    down_prefix = 'lora_unet_down_blocks_'
    up_prefix = 'lora_unet_up_blocks_'

    if i.startswith(text_prefix):
        # Every text-encoder layer shares one label, so the layer index
        # that follows the prefix does not need to be parsed.
        return 'text'

    if i.startswith(down_prefix):
        # Remainder looks like '<block>_<kind>_<sub>_...'; the <kind> token
        # at position 1 is ignored, only the two indices are used.
        tokens = i[len(down_prefix):].split('_')
        block = int(tokens[0])
        sub = int(tokens[2])
        return f'down_{1 + block * 3 + sub}'

    if i.startswith(up_prefix):
        tokens = i[len(up_prefix):].split('_')
        block = int(tokens[0])
        sub = int(tokens[2])
        return f'up_{block * 3 + sub}'

    # Anything unrecognised is lumped into the single middle label.
    return 'mid_1'
# Base file name (without extension) of the file to prune.
name = 'some_lora'

# Labels, as produced by layer_name(), whose tensors are copied to the output.
layers_to_keep = [
    'text',
    'down_1', 'down_2', 'down_4', 'down_5', 'down_7', 'down_8',
    'mid_1',
    # 'up_3', 'up_4', 'up_5', 'up_6', 'up_7', 'up_8', 'up_9', 'up_10', 'up_11'
]

# Tensors selected for the pruned file, keyed by their original names.
tmp = dict()

# Open the source file as a context manager so the handle is released as
# soon as the wanted tensors have been read.
with safe_open(f'path/to/stable-diffusion-webui/models/lora/{name}.safetensors', framework='pt') as lora:
    for i in lora.keys():
        if layer_name(i) in layers_to_keep:
            tmp[i] = lora.get_tensor(i)
        # Tensors of all other layers are dropped entirely. To keep them
        # in the file but zeroed out instead, store
        # `lora.get_tensor(i) * 0.0` for those keys.

save_file(tmp, f'path/to/stable-diffusion-webui/models/lora/{name}_pruned.safetensors')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment