Skip to content

Instantly share code, notes, and snippets.

@Ttl
Created January 2, 2024 16:10
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Ttl/0d51f739dc59254b4b2183e259c97d82 to your computer and use it in GitHub Desktop.
Save Ttl/0d51f739dc59254b4b2183e259c97d82 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
"""
Calculate KL-divergence of two models output logits on data set.
First call the program with write_path and text_path using fp16 model.
./llama_kl.py -m <fp16 model> -t <wiki.test.raw> -w <logits.gz>
This writes logits to file. Then call the program with quantized model with read path
./llama_kl.py -m <quantized model> -r <logits.gz>
KL-divergence to the first run is calculated.
See ./llama_kl.py --help for more options.
"""
import llama_cpp
import numpy as np
import sys
import argparse
import os.path
import struct
import ast
from scipy.special import rel_entr, softmax
import gzip
import pickle
from scipy.stats.mstats import mquantiles_cimj
from scipy.stats import bayes_mvs
from scipy.stats import t as student_t
import random
import time
def kl_div(p, q):
    """Return KL(P || Q) where P = softmax(p) and Q = softmax(q).

    ``p`` and ``q`` are unnormalized logit vectors of equal length; the
    softmax is taken over all of their elements.  The result is a float64
    scalar, which is ``inf`` if some outcome has P > 0 but Q == 0.

    Both inputs are promoted to float64 before the softmax.  With float32
    logits (as stored in the logit files) ``exp`` underflows to exactly 0
    once a logit is more than roughly 100 below the maximum.  That turns a
    large-but-finite divergence into ``inf`` and loses precision when many
    tiny terms are summed.
    """
    p_prob = softmax(np.asarray(p, dtype=np.float64))
    q_prob = softmax(np.asarray(q, dtype=np.float64))
    return np.sum(rel_entr(p_prob, q_prob))
def write_header(f, args, ctx, vocab_len, batch):
    """Write the logit-file header to the binary stream ``f``.

    The header is the magic line ``llama_kl_divergence_v1`` followed by one
    line containing ``str()`` of a dict.  The dict holds every option in the
    ``args`` namespace, with ``n_ctx``, ``n_vocab`` and ``n_batch`` set to
    ``ctx``, ``vocab_len`` and ``batch``.  ``read_header`` parses it back
    with ``ast.literal_eval``.

    ``args`` itself is left unmodified.
    """
    # vars() returns the namespace's own __dict__; copy it so the values
    # recorded in the header do not overwrite the caller's parsed options.
    d = dict(vars(args))
    d["n_ctx"] = ctx
    d["n_vocab"] = vocab_len
    d["n_batch"] = batch
    f.write("llama_kl_divergence_v1\n".encode('utf-8'))
    f.write((str(d) + "\n").encode('utf-8'))
def read_header(f):
    """Check the magic line of a logit file and return its options dict.

    ``f`` is a binary stream positioned at the start of the file.  The line
    after the magic is parsed with ``ast.literal_eval``, which accepts
    Python literals only and never executes code.

    Raises ValueError if the stream does not begin with the expected magic.
    """
    magic = b"llama_kl_divergence_v1\n"
    if f.read(len(magic)) != magic:
        raise ValueError("Invalid header in input logit file")
    options_text = f.readline().decode('utf-8').strip()
    return ast.literal_eval(options_text)
def write_logits(f, tokens, logits):
    """Append one record (an evaluated token window and its logits) to ``f``.

    Record layout, every field little-endian:
      uint32 n_tokens, uint32 n_logits, uint32 n_vocab,
      n_tokens x uint32 token ids,
      n_logits x n_vocab float32 logits, row-major.

    The payload dtypes are explicitly little-endian to match the header
    fields packed with ``<I``.  That keeps files readable on a machine with
    a different native byte order.

    Raises ValueError if ``logits`` is not 2-D (rows x vocabulary).
    """
    token_arr = np.asarray(tokens, dtype='<u4')
    logit_arr = np.asarray(logits, dtype='<f4')
    # Real validation instead of assert (which is stripped under -O).
    if logit_arr.ndim != 2:
        raise ValueError(
            f"logits must be 2-D (n_logits, n_vocab), got shape {logit_arr.shape}")
    n_logits, n_vocab = logit_arr.shape
    f.write(struct.pack("<III", len(token_arr), n_logits, n_vocab))
    f.write(token_arr.tobytes())
    f.write(logit_arr.tobytes())
def read_logits(f):
    """Read the next record written by ``write_logits`` from ``f``.

    Returns ``(tokens, logits)``:
    - ``tokens`` is a list of ints.
    - ``logits`` is a read-only float32 array of shape ``(n_logits, n_vocab)``.

    Returns ``(None, None)`` when ``f`` is exactly at end of file.

    Raises EOFError if the file ends part-way through a record.  A truncated
    file is therefore reported as such, rather than failing later with an
    opaque struct/reshape error.
    """
    def read_exact(n):
        # Read exactly n bytes or signal a truncated record.
        data = f.read(n)
        if len(data) != n:
            raise EOFError(f"truncated record: expected {n} bytes, got {len(data)}")
        return data

    head = f.read(4)
    if not head:
        # Clean end of file: no more records.
        return None, None
    if len(head) != 4:
        raise EOFError("truncated record header")
    n_tokens = struct.unpack("<I", head)[0]
    n_logits, n_vocab = struct.unpack("<II", read_exact(8))
    # Explicit little-endian dtypes, matching the "<I" header fields.
    tokens = np.frombuffer(read_exact(n_tokens * 4), dtype='<u4').tolist()
    logits = np.frombuffer(read_exact(n_logits * n_vocab * 4),
                           dtype='<f4').reshape(n_logits, n_vocab)
    return tokens, logits
def _chunk_tokens(tokens, ctx, bos, seed=123):
    """Split ``tokens`` into evaluation windows of at most ``ctx`` tokens.

    If ``bos`` is not None it is prepended to every window, so each window
    carries at most ``ctx - 1`` dataset tokens.  The windows are shuffled
    with a fixed seed.  Compared with the original order, this reduces the
    correlation between consecutive windows and so improves the running
    error estimate printed during evaluation.  It does not affect the final
    result.
    """
    step = ctx - (1 if bos is not None else 0)
    chunks = [tokens[start:start + step] for start in range(0, len(tokens), step)]
    if bos is not None:
        for chunk in chunks:
            chunk.insert(0, bos)
    random.Random(seed).shuffle(chunks)
    return chunks


def _bin_conf(p, n, z):
    """Return the half-width of a normal-approximation binomial confidence
    interval, ``z * sqrt(p * (1 - p) / n)``.

    When ``p`` is degenerate (exactly 0 or 1) the width would collapse to 0.
    In that case the Bayes estimate ``1 / (n + 2)`` (or ``1 - 1 / (n + 2)``)
    is used instead.
    """
    if p == 0:
        p = 1 / (n + 2)
    elif p == 1:
        p = 1 - 1 / (n + 2)
    return z * np.sqrt(p * (1 - p) / n)


def _hits(top_indices, targets):
    """Count rows ``r`` for which ``targets[r]`` is among ``top_indices[r]``."""
    return int(np.count_nonzero(np.any(top_indices == targets[:, None], axis=1)))


def main(args):
    """Evaluate ``args.model`` on a dataset, writing and/or comparing logits.

    - With ``args.read_path``: token windows and reference logits are read
      from that gzip file (written by an earlier run with ``write_path``).
      The context size stored in its header overrides ``args.n_ctx``.  The
      model's logits are compared with the reference, and KL-divergence and
      top-n agreement statistics are printed.
    - Otherwise ``args.text_path`` (UTF-8) is tokenized and split into
      windows.
    - With ``args.write_path``: the model's logits for every window are
      written to that gzip file.

    Evaluation stops after ``args.n_tokens`` evaluated tokens (negative means
    no limit) or on Ctrl-C, and the statistics gathered so far are still
    reported.  Both gzip files are closed even if evaluation fails.  When any
    KL values were computed, they are pickled to ``<model file name>.kls.p``
    in the current directory.
    """
    # Confidence interval bound (two-sided 1 - alpha intervals).
    alpha = 0.01
    max_tokens = args.n_tokens if args.n_tokens >= 0 else float('inf')

    kls = []
    top1 = top5 = top10 = 0
    eval_top5 = eval_top10 = 0
    samples = 0         # positions compared against reference logits
    processed = 0       # tokens evaluated, in any mode (for the -n limit)
    written = 0
    written_tokens = 0
    errors = 0

    ctx = args.n_ctx
    read_file = None
    write_file = None
    try:
        if args.read_path is not None:
            print(f"Computing KL-divergence against: {args.read_path}")
            read_file = gzip.open(args.read_path, "rb")
            input_args = read_header(read_file)
            ctx = input_args["n_ctx"]
        model = llama_cpp.Llama(model_path=args.model, n_ctx=ctx, n_batch=args.n_batch,
                                logits_all=True, n_gpu_layers=args.n_gpu_layers,
                                verbose=args.verbose)
        model_name = os.path.split(args.model)[1]

        tokens = None
        if args.text_path and args.read_path is None:
            # Decode explicitly as UTF-8: the text is re-encoded as UTF-8 for
            # the tokenizer, so the locale's default encoding could garble or
            # reject the dataset.
            with open(args.text_path, "r", encoding="utf-8") as f:
                prompt = f.read()
            print(f"Computing logits from text file: {args.text_path}")
            tokens = _chunk_tokens(model.tokenize(prompt.encode('utf-8')), ctx,
                                   model.token_bos())

        if args.write_path is not None:
            write_file = gzip.open(args.write_path, "wb")
            write_header(write_file, args, model.n_ctx(), model.n_vocab(), model.n_batch)

        def next_sample():
            """Yield ``(reference_logits, token_window)``; logits are None unless reading."""
            if read_file is not None:
                while True:
                    try:
                        t, logits = read_logits(read_file)
                    except EOFError:
                        print("EOF at unexpected location")
                        return
                    if t is None:
                        return
                    yield logits, t
            elif tokens is not None:
                for t in tokens:
                    yield None, t

        try:
            for i, (logits, chunk) in enumerate(next_sample()):
                model.reset()
                model.eval(chunk)
                eval_logits = np.asarray(model.eval_logits)
                if np.any(np.isnan(eval_logits)):
                    errors += 1
                    print("Nan in logits!")
                    eval_logits = np.nan_to_num(eval_logits)
                if write_file is not None:
                    write_logits(write_file, model.eval_tokens, eval_logits)
                    written_tokens += len(model.eval_tokens)
                    written += 1
                    # The window count is only known when tokenizing a text file.
                    total = f"/{len(tokens)}" if tokens is not None else ""
                    print(f"[{written}{total}] tokens {written_tokens}")
                if logits is not None:
                    n = len(logits)
                    # Per-position KL(P_eval || P_ref): the evaluated model's
                    # logits are the first argument of kl_div.  The first
                    # positions of every window are always the same, but
                    # unlike perplexity this is compared to a reference, so
                    # they are kept.  (This is the slow part.)
                    new_kls = [kl_div(eval_logits[j], logits[j]) for j in range(n)]
                    if np.any(np.isnan(new_kls)):
                        errors += 1
                        print("Nan in computed kls!")
                        new_kls = np.nan_to_num(new_kls)
                    kls.extend(new_kls)
                    samples += n

                    ev = eval_logits[:n]
                    eval_argmax = np.argmax(ev, axis=-1)
                    ref_argmax = np.argmax(logits, axis=-1)
                    eval_part5 = np.argpartition(ev, -5, axis=-1)[:, -5:]
                    ref_part5 = np.argpartition(logits, -5, axis=-1)[:, -5:]
                    eval_part10 = np.argpartition(ev, -10, axis=-1)[:, -10:]
                    ref_part10 = np.argpartition(logits, -10, axis=-1)[:, -10:]
                    top1 += int(np.count_nonzero(eval_argmax == ref_argmax))
                    top5 += _hits(eval_part5, ref_argmax)
                    top10 += _hits(eval_part10, ref_argmax)
                    eval_top5 += _hits(ref_part5, eval_argmax)
                    eval_top10 += _hits(ref_part10, eval_argmax)
                    if samples:
                        print(f"[{i}] kl {np.mean(kls):.4g}, top1 {top1 / samples:.4g}",
                              flush=True)
                # Count evaluated tokens in every mode so -n also limits
                # write-only runs.
                processed += len(chunk)
                if processed >= max_tokens:
                    print("Token limit reached")
                    break
        except KeyboardInterrupt:
            print("Interrupted")
    finally:
        # Close in all cases so a gzip output file is never left truncated.
        if write_file is not None:
            write_file.close()
            print(f"Finished writing file: {args.write_path}")
        if read_file is not None:
            read_file.close()
            print(f"Finished reading file: {args.read_path}")

    if kls:
        # Critical value for the binomial intervals below.
        z = student_t.ppf(1 - alpha / 2, samples)
        print()
        print("Model:", model_name)
        model_size = llama_cpp.llama_model_size(model.model)
        bpw = 8 * model_size / llama_cpp.llama_model_n_params(model.model)
        print(f"Size: {model_size / 1024**3:.3g} GiB, (BPW {bpw:.2f})")
        print("Tokens:", samples)
        print("KL-divergence:")
        # Intervals assume i.i.d. samples, which likely isn't strictly true.
        m, _, _ = bayes_mvs(kls, 1 - alpha)
        print(f"mean: {m[0]:.6g}, [{m[1][0]:.6g} - {m[1][1]:.6g}]")
        q90 = np.quantile(kls, 0.90)
        q95 = np.quantile(kls, 0.95)
        q99 = np.quantile(kls, 0.99)
        # Pass alpha explicitly: the scipy default (0.05) would give 95%
        # bounds here while every other interval is 1 - alpha.
        q_bounds = mquantiles_cimj(kls, prob=[0.90, 0.95, 0.99], alpha=alpha)
        print(f"q90: {q90:.4g}, [{q_bounds[0][0]:.4g} - {q_bounds[1][0]:.4g}]")
        print(f"q95: {q95:.4g}, [{q_bounds[0][1]:.4g} - {q_bounds[1][1]:.4g}]")
        print(f"q99: {q99:.4g}, [{q_bounds[0][2]:.4g} - {q_bounds[1][2]:.4g}]")
        print(f"max: {np.max(kls):.4g}")
        print("Reference top token in eval top-n probability:")
        print(f"ref_top1: {top1 / samples:.4g} ± {_bin_conf(top1 / samples, samples, z):.4g}")
        print(f"ref_top5: {top5 / samples:.4g} ± {_bin_conf(top5 / samples, samples, z):.4g}")
        print(f"ref_top10: {top10 / samples:.4g} ± {_bin_conf(top10 / samples, samples, z):.4g}")
        print("Eval top token in reference top-n probability:")
        print(f"eval_top5: {eval_top5 / samples:.4g} ± {_bin_conf(eval_top5 / samples, samples, z):.4g}")
        print(f"eval_top10: {eval_top10 / samples:.4g} ± {_bin_conf(eval_top10 / samples, samples, z):.4g}")
        print(f"errors: {errors}")
        with open(model_name + ".kls.p", 'wb') as f:
            pickle.dump(kls, f)
def build_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        prog='llama.cpp KL-divergence',
        # Keep the description's explicit line breaks in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Calculate KL-divergence of two models output logits on data set.\n"
                    "First call the program with write_path and text_path using fp16 model.\n"
                    "This writes logits to file. Then call the program with quantized model with read path\n"
                    "KL-divergence to the first run is calculated\n")
    parser.add_argument('-m', '--model', help="Model path", required=True)
    parser.add_argument('-t', '--text_path', help="Text dataset path", required=False)
    parser.add_argument('-c', '--n_ctx', help="Context size", default=512, type=int, required=False)
    parser.add_argument('-b', '--n_batch', help="Batch size", default=512, type=int, required=False)
    parser.add_argument('-w', '--write_path', help="Output logits file", required=False)
    parser.add_argument('-r', '--read_path', help="Input logits file", required=False)
    parser.add_argument('-n', '--n_tokens', help="Number of tokens to evaluate. (-1 = whole file)", default=-1, type=int, required=False)
    parser.add_argument('-ngl', '--n-gpu-layers', help="Number of GPU layers", default=0, type=int, required=False)
    parser.add_argument('-v', '--verbose', help="Verbose output", action="store_true")
    return parser


def validate_args(args):
    """Return an error message if ``args`` describes an unusable run, else None.

    A run needs an input (a text dataset or a reference logit file), at least
    one of reading or writing logits, and a write path that does not exist
    yet.
    """
    if args.read_path is None and args.text_path is None:
        return "Either text dataset or input logit file should be specified"
    if args.write_path is None and args.read_path is None:
        return "At least one of read_path or write_path needs to be specified"
    if args.write_path is not None and os.path.exists(args.write_path):
        return f"write_path {args.write_path} already exists"
    return None


if __name__ == "__main__":
    args = build_parser().parse_args()
    error = validate_args(args)
    if error is not None:
        # Every validation failure is fatal; without an input there is
        # nothing to evaluate.
        print(error)
        sys.exit(1)
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment