@zucchini-nlp
Last active May 9, 2024 14:06
Calculate the perplexity of Llama with different cache implementations
"""
Adapted from https://github.com/mit-han-lab/streaming-llm
Note: Although this script measures latency, it is not optimized whatsoever!
The latency is only tracked to see the impact of speed over time.
Usage:
python benchmark/perplexity.py --experiment dynamicCacheInt4 --cache_implementation dynamic
python benchmark/perplexity.py --experiment quantCacheInt4 --cache_implementation quantized --nbits 2
python benchmark/perplexity.py --experiment quantCacheInt4 --cache_implementation quantized --nbits 4
Plot perplexity after obtaining all .csv files by running main_plot():
python benchmark/plot_perplexity.py
The script is tested on https://github.com/zucchini-nlp/transformers/tree/quant (commit_id 5f3046a)
Thanks for the script to [Clementine](https://huggingface.co/clefourrier) and [Joao](https://huggingface.co/joaogante)
"""
import argparse
import time
import itertools
from collections import defaultdict
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
from datasets import load_dataset
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import QuantCache


def plot(
    output_dir: str = "outputs",
    title: Optional[str] = None,
    perplexity_limit: Optional[float] = None,
    skip_first: int = 100,
):
    """Plot the log-perplexity curves stored as .csv files in `output_dir`, one line per experiment."""
    output_dir = Path(output_dir)

    fig, ax = plt.subplots()
    ax.set_xlabel("Input Sequence Length")
    ax.set_ylabel("Perplexity (log), lower is better")

    for file in output_dir.glob("*.csv"):
        experiment = file.stem
        df = pd.read_csv(file)
        df = df.groupby(["input_length"]).mean()
        X = df.index[skip_first:]
        Y = df["overall_ppl"][skip_first:]
        Y = np.log(Y)
        ax.plot(X, Y, "-", label=f"{experiment} perplexity")

    if perplexity_limit:
        ax.set_ylim(top=min(ax.get_ylim()[1], perplexity_limit))

    ax.legend(loc=[1, 2, 7][0])  # upper right, upper left, center right
    ax.set_title(title.replace("\\n", "\n") if title else "Log perplexity as a function of input lengths")
    fig.tight_layout()

    return fig


def compute_perplexity(
    model,
    tokenizer,
    dataset,
    experiment: str,
    cache_implementation: str,
    output_dir: str = "outputs",
    data_column: str = "text",
    num_samples: Optional[int] = 1,
    num_tokens: Optional[int] = None,
    nbits: int = 4,  # bit width of the quantized cache; only used when cache_implementation == "quantized"
    overwrite: bool = False,
) -> None:
    """Run token-by-token perplexity evaluation and log nll/ppl/VRAM/latency to `{output_dir}/{experiment}.csv`."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{experiment}.csv"
    if output_file.exists() and not overwrite:
        raise ValueError(
            f"The {output_file!r} output file already exists - if you really want to overwrite it, use `--overwrite`."
        )

    logs = defaultdict(list)
    loss_fn = CrossEntropyLoss(reduction="none")
    num_data_elements = 0
    for data_element in itertools.islice(dataset, num_samples):
        encodings = tokenizer(data_element[data_column], return_tensors="pt")
        seq_len = encodings.input_ids.size(1)
        pbar = tqdm(range(0, seq_len - 1))
        num_processed_tokens = 0

        if cache_implementation == "quantized":
            past_key_values = QuantCache(nbits=nbits)
        else:
            past_key_values = None

        for idx in pbar:
            start_t = time.time()
            input_ids = encodings.input_ids[:, idx : idx + 1].to(model.device)
            with torch.no_grad():
                outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
                logits = outputs.logits.view(-1, model.config.vocab_size)
                past_key_values = outputs.past_key_values
                label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1)
                neg_log_likelihood = loss_fn(logits, label)
                perplexity = neg_log_likelihood.exp()
            pbar.set_description(f"nll: {neg_log_likelihood.item():>5.2f}, ppl: {perplexity.item():>8.2f}")

            # Store data and save every 10 tokens
            logs["data_idx"].append(num_data_elements + 1)
            logs["input_length"].append(idx + 1)
            logs["nll"].append(neg_log_likelihood.item())
            logs["ppl"].append(perplexity.item())
            logs["overall_ppl"].append(torch.tensor(logs["nll"]).mean().exp().item())
            logs["cuda_vram_allocated"].append(torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024)  # in GB
            logs["latency"].append(time.time() - start_t)
            if num_processed_tokens % 10 == 0:
                try:
                    pd.DataFrame(logs).to_csv(output_file, index=False)
                except KeyboardInterrupt as ex:
                    # If there's a KeyboardInterrupt, still write the file, then stop
                    pd.DataFrame(logs).to_csv(output_file, index=False)
                    raise ex

            num_processed_tokens += 1
            if num_tokens and num_processed_tokens >= num_tokens:
                break

        num_data_elements += 1

    # Final flush so the last few tokens are not lost between periodic saves
    pd.DataFrame(logs).to_csv(output_file, index=False)


def main_plot():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="./outputs")
    parser.add_argument("--title", type=str, default=None)
    parser.add_argument("--log_perplexity_limit", type=float, default=5.0)
    # Perplexity starts out a bit unstable, so we skip the first tokens
    parser.add_argument("--skip_first", type=int, default=100)
    args = parser.parse_args()

    figure = plot(
        output_dir=args.output_dir,
        title=args.title,
        perplexity_limit=args.log_perplexity_limit,
        skip_first=args.skip_first,
    )

    # Add your own code here if you'd like to change the figure
    save_path = f"{args.output_dir}/perplexity_plot.png"
    figure.savefig(save_path, dpi=600)
    print(f"plot saved to {save_path}")


def main():
    parser = argparse.ArgumentParser()
    # Name of this experiment; it is also used as the .csv file name
    parser.add_argument("--experiment", type=str, default="main")
    parser.add_argument("--cache_implementation", type=str, default="quantized")
    parser.add_argument("--nbits", type=int, default=4)  # only used when --cache_implementation is "quantized"

    # Model args
    # parser.add_argument("--model_name_or_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument("--trust_remote_code", action="store_true")

    # Dataset args
    parser.add_argument("--dataset_name", type=str, default="emozilla/pg19-test")
    parser.add_argument("--data_column", type=str, default="text")
    parser.add_argument("--task", type=str, default=None)
    parser.add_argument("--split", type=str, default="test", choices=["validation", "test"])
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--num_tokens", type=int, default=5000)
    parser.add_argument("--dtype", type=str, default="fp16")

    # Where to log
    parser.add_argument("--output_dir", type=str, default="./outputs")
    parser.add_argument("--overwrite", action="store_true")

    args = parser.parse_args()

    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "fp32":
        dtype = torch.float32
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    else:
        raise ValueError(f"Unknown dtype: {args.dtype}")

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        revision=args.revision,
        trust_remote_code=bool(args.trust_remote_code),
        attn_implementation="eager",
        torch_dtype=dtype,
        device_map="auto",
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code))

    # Set up the dataset (streaming, so only the requested samples are downloaded)
    dataset = load_dataset(args.dataset_name, args.task, split=args.split, streaming=True)

    compute_perplexity(
        model,
        tokenizer,
        dataset,
        args.experiment,
        args.cache_implementation,
        output_dir=args.output_dir,
        data_column=args.data_column,
        num_samples=args.num_samples,
        num_tokens=args.num_tokens,
        nbits=args.nbits,
        overwrite=args.overwrite,
    )


if __name__ == "__main__":
    main()
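
# Example: a minimal sketch of calling compute_perplexity() directly instead of through the CLI.
# It assumes the custom `quant` transformers branch mentioned in the docstring, a CUDA device
# (for the VRAM logging), and access to meta-llama/Llama-2-7b-chat-hf; adjust names as needed.
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16,
#       device_map="auto", attn_implementation="eager",
#   )
#   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
#   dataset = load_dataset("emozilla/pg19-test", split="test", streaming=True)
#   compute_perplexity(
#       model, tokenizer, dataset,
#       experiment="quantCacheInt4", cache_implementation="quantized", nbits=4,
#       num_samples=1, num_tokens=5000, output_dir="./outputs",
#   )
#   plot(output_dir="./outputs", perplexity_limit=5.0).savefig("./outputs/perplexity_plot.png", dpi=600)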