Skip to content

Instantly share code, notes, and snippets.

@rockerBOO
Created November 12, 2023 18:27
Show Gist options
  • Save rockerBOO/6d2dbc7827c83bf4273e7381636ce9ff to your computer and use it in GitHub Desktop.
Save rockerBOO/6d2dbc7827c83bf4273e7381636ce9ff to your computer and use it in GitHub Desktop.
import argparse
import torch
from safetensors.torch import load_file, safe_open
from library import model_util
def load_state_dict(file_name, dtype):
    """Load a LoRA state dict and its metadata, casting every tensor to ``dtype``.

    Args:
        file_name: Path to the model file. Files for which
            ``model_util.is_safetensors`` is true are read with safetensors;
            anything else is read with ``torch.load`` onto the CPU.
        dtype: ``torch.dtype`` that every tensor value is converted to.
            Non-tensor values are left unchanged.

    Returns:
        ``(state_dict, metadata)``. ``metadata`` is whatever the safetensors
        header reports (``safe_open(...).metadata()``). For non-safetensors
        files it is ``None``.
    """
    if model_util.is_safetensors(file_name):
        sd = load_file(file_name)
        # Re-open only to read the header metadata; the tensors come from load_file.
        with safe_open(file_name, framework="pt") as f:
            metadata = f.metadata()
    else:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only run this on checkpoints from trusted sources
        # (or pass weights_only=True where the installed torch supports it).
        sd = torch.load(file_name, map_location="cpu")
        metadata = None

    # isinstance (not `type(...) ==`) so tensor subclasses such as
    # nn.Parameter are cast as well.
    for key, value in list(sd.items()):
        if isinstance(value, torch.Tensor):
            sd[key] = value.to(dtype)
    return sd, metadata
def _lora_delta(up, down):
    """Return the un-scaled weight delta ``up @ down`` in the patched layer's layout.

    * 1x1 conv (both factors have 1x1 kernels): matrix product over the
      channel dims, then restore the two trailing singleton dims.
    * 3x3 conv (either factor has a 3x3 kernel): compose the kernels by
      running ``conv2d`` over the permuted down weight.
    * otherwise (linear): plain matrix product.
    """
    if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
        return (
            (up.squeeze(3).squeeze(2) @ down.squeeze(3).squeeze(2))
            .unsqueeze(2)
            .unsqueeze(3)
        )
    if up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
        return torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(
            1, 0, 2, 3
        )
    return up @ down


def get_norms(state_dict, device):
    """Compute the Frobenius norm of each LoRA module's scaled weight delta.

    For every key containing both ``lora_down`` and ``weight``, the matching
    ``lora_up`` weight and ``alpha`` entry are looked up. The low-rank
    product ``up @ down`` is formed and multiplied by ``alpha / dim``, where
    ``dim`` is ``down.shape[0]``. When no ``alpha`` entry exists, alpha is
    taken to be ``dim`` (scale 1.0).

    Args:
        state_dict: Mapping of key -> tensor using
            ``<module>.lora_down.weight`` / ``<module>.lora_up.weight`` /
            ``<module>.alpha`` naming.
        device: Device the tensors are moved to before computing.

    Returns:
        ``(norms, longest_key)``. ``norms`` is a list of single-entry dicts
        ``{"<module>.weight": norm}`` in state-dict iteration order.
        ``longest_key`` is the length of the longest such key, or 0 when no
        LoRA modules are found.

    Raises:
        KeyError: if a ``lora_down`` weight has no matching ``lora_up`` weight.
    """
    norms = []
    longest_key = 0
    for down_key in state_dict:
        if "lora_down" not in down_key or "weight" not in down_key:
            continue
        up_key = down_key.replace("lora_down", "lora_up")
        alpha_key = down_key.replace("lora_down.weight", "alpha")

        down = state_dict[down_key].to(device)
        up = state_dict[up_key].to(device)
        dim = down.shape[0]
        if alpha_key in state_dict:
            alpha = state_dict[alpha_key].to(device)
        else:
            # No stored alpha: treat it as alpha == rank, i.e. no extra scaling,
            # instead of crashing with KeyError.
            alpha = dim
        scale = alpha / dim

        updown = _lora_delta(up, down) * scale
        save_key = down_key.replace(".lora_down", "")
        longest_key = max(longest_key, len(save_key))
        norms.append({save_key: updown.norm().item()})
    return norms, longest_key
def main(args):
    """Load the LoRA file named by ``args.model`` and print one norm per module.

    Tensors are cast to float32 and processed on CUDA when
    ``torch.cuda.is_available()`` is true, otherwise on CPU. Each output line
    is the module key, left-aligned and padded to the longest key, followed
    by its norm.
    """
    lora_sd, _metadata = load_state_dict(args.model, torch.float32)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    results, pad = get_norms(lora_sd, device)
    for entry in results:
        for name, value in entry.items():
            print(f"{name:<{pad}} {value}")
if __name__ == "__main__":
    # Command-line entry point: one positional argument naming the LoRA file.
    parser = argparse.ArgumentParser(
        description="Check the norm values for the weights in a LoRA model"
    )
    parser.add_argument("model", help="LoRA model to check the norms of")
    args = parser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment