Skip to content

Instantly share code, notes, and snippets.

@TomoshibiAkira
Last active June 4, 2024 16:51
Show Gist options
  • Save TomoshibiAkira/151a2353b946aa9cd8d4d2cdabc31245 to your computer and use it in GitHub Desktop.
Save TomoshibiAkira/151a2353b946aa9cd8d4d2cdabc31245 to your computer and use it in GitHub Desktop.
Compare MiniCPM and LLaMa3V weights.
import os
import glob
import numpy as np
from safetensors import safe_open
# Directories holding each checkpoint's *.safetensors shards (see get_keys).
# These are relative paths, so they are resolved against the current working
# directory the script is launched from.
LLAMA3V = 'llama3v'
MINICPM = 'minicpm-llama3'
def get_keys(path):
    """Collect tensor names from every ``*.safetensors`` shard in *path*.

    Args:
        path: Directory holding the checkpoint shards. A relative path is
            resolved against the current working directory by ``glob``.

    Returns:
        A pair ``(names, name_files)``:

        * ``names`` -- set of every tensor name found across all shards.
        * ``name_files`` -- list of ``(tensor_name, shard_filename)`` tuples,
          ordered by shard filename, then by each shard's ``keys()`` order.
          A name stored in more than one shard appears once per shard here,
          but only once in ``names``.

    Raises:
        FileNotFoundError: if *path* contains no ``*.safetensors`` files
            (e.g. a mistyped directory). Returning empty results instead
            would make every later name comparison pass vacuously.
    """
    fns = sorted(glob.glob(os.path.join(path, '*.safetensors')))
    if not fns:
        raise FileNotFoundError(f"no *.safetensors files found in {path!r}")
    names = set()
    name_files = []
    for fn in fns:
        # Only names are read here (no get_tensor call), so open on CPU.
        with safe_open(fn, framework='pt', device='cpu') as f:
            keys = list(f.keys())
        names.update(keys)
        name_files.extend((k, fn) for k in keys)
    return names, name_files
# Tensor names per checkpoint: a set for comparison, plus (name, shard file)
# pairs sorted by name so both checkpoints can be walked in lockstep.
l3v, l3vf = get_keys(LLAMA3V)
cpm, cpmf = get_keys(MINICPM)
l3vf = sorted(l3vf, key=lambda x: x[0])
cpmf = sorted(cpmf, key=lambda x: x[0])
# make sure the names are the same. they actually are.
# Explicit raises rather than `assert`, which `python -O` strips.
if l3v != cpm:
    raise ValueError(
        "checkpoints have different tensor names: "
        f"only in {LLAMA3V}: {sorted(l3v - cpm)}, "
        f"only in {MINICPM}: {sorted(cpm - l3v)}"
    )
# A name stored in more than one shard collapses in the set but is repeated
# in the pair list, which would misalign the name-sorted lockstep walk.
if len(l3vf) != len(l3v) or len(cpmf) != len(cpm):
    raise ValueError("a tensor name appears in more than one shard")
# Per-tensor statistics of the element-wise difference between the two
# checkpoints, collected so their distributions can be summarised at the end.
means = []  # mean of (L3V - CPM), one entry per tensor
stds = []   # population (divide-by-N) std of (L3V - CPM), per tensor
awms = []   # mean of |L3V - CPM|, per tensor
for (n, l3v_file), (cpm_name, cpm_file) in zip(l3vf, cpmf):
    if n != cpm_name:
        raise ValueError(f"tensor name mismatch: {n!r} vs {cpm_name!r}")
    # device=0 as in the original (presumably the first CUDA GPU).
    with safe_open(l3v_file, framework='pt', device=0) as f:
        x = f.get_tensor(n)
    with safe_open(cpm_file, framework='pt', device=0) as f:
        y = f.get_tensor(n)
    # Refuse to silently broadcast tensors of different shapes.
    if x.shape != y.shape:
        raise ValueError(f"{n}: shape mismatch {tuple(x.shape)} vs {tuple(y.shape)}")
    # Upcast before any arithmetic. If the weights are stored in a
    # half-precision dtype (not visible here -- TODO confirm, e.g. bf16), the
    # subtraction and reductions would otherwise keep only ~3 significant
    # digits, swamping the small differences this script measures.
    # NOTE(review): this would downcast float64 weights; assumed <= 32-bit.
    x = x.reshape(-1).float()
    y = y.reshape(-1).float()
    lwm = x.abs().mean().item()
    cwm = y.abs().mean().item()
    delta = x - y
    awm = delta.abs().mean().item()
    mean_t = delta.mean()
    # Population standard deviation: sum of squared deviations / N.
    std = ((delta - mean_t) ** 2).mean().sqrt().item()
    mean = mean_t.item()
    means.append(mean)
    stds.append(std)
    awms.append(awm)
    print(f"{n}: delta mean={mean}, delta std={std}, abs(delta) mean={awm}, L3V weight mean={lwm}, CPM weight mean={cwm}")
    # Free this tensor's buffers before the next load to lower peak memory.
    del x, y, delta, mean_t
print(np.histogram(means))
print(np.histogram(stds))
print(np.histogram(awms))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment