Calculate the perplexity of Llama with different cache implementations
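The script feeds long PG-19 documents to Llama one token at a time and logs per-token negative log-likelihood, rolling perplexity (exp of the mean NLL so far), VRAM usage, and latency to one CSV per experiment, so a dynamic (full-precision) KV cache can be compared against 2- and 4-bit quantized caches.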
""" | |
Adapted from https://github.com/mit-han-lab/streaming-llm | |
Note: Although this script measures latency, it is not optimized whatsoever! | |
The latency is only tracked to see the impact of speed over time. | |
Usage: | |
python benchmark/perplexity.py --experiment dynamicCacheInt4 --cache_implementation dynamic | |
python benchmark/perplexity.py --experiment quantCacheInt4 --cache_implementation quantized --nbits 2 | |
python benchmark/perplexity.py --experiment quantCacheInt4 --cache_implementation quantized --nbits 4 | |
Plot perplexity after obtaining all .csv files by running main_plot(): | |
python benchmark/plot_perplexity.py | |
The script is tested on https://github.com/zucchini-nlp/transformers/tree/quant (commit_id 5f3046a) | |
Thanks for the script to [Clementine](https://huggingface.co/clefourrier) and [Joao](https://huggingface.co/joaogante) | |
""" | |
import argparse
import itertools
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import QuantCache
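
# Each experiment writes `<output_dir>/<experiment>.csv` with one row per scored token:
# data_idx, input_length, nll, ppl, overall_ppl, cuda_vram_allocated (GiB), latency (s).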

def plot(
    output_dir: str = "outputs",
    title: Optional[str] = None,
    perplexity_limit: Optional[float] = None,
    skip_first: int = 100,
):
    output_dir = Path(output_dir)

    fig, ax = plt.subplots()
    ax.set_xlabel("Input Sequence Length")

    for file in output_dir.glob("*.csv"):
        experiment = file.stem
        df = pd.read_csv(file)
        df = df.groupby(["input_length"]).mean()
        X = df.index[skip_first:]
        Y = df["overall_ppl"][skip_first:]
        # Plot in log space so the unstable early values do not dominate the axis
        Y = np.log(Y)
        ax.plot(X, Y, "-", label=f"{experiment} perplexity")

    ax.set_ylabel("Perplexity (log), lower is better")
    if perplexity_limit:
        ax.set_ylim(top=min(ax.get_ylim()[1], perplexity_limit))
    ax.legend(loc="upper right")
    ax.set_title(title.replace("\\n", "\n") if title else "Log perplexity as a function of input lengths")
    fig.tight_layout()

    return fig

def compute_perplexity(
    model,
    tokenizer,
    dataset,
    experiment: str,
    cache_implementation: str,
    output_dir: str = "outputs",
    data_column: str = "text",
    num_samples: Optional[int] = 1,
    num_tokens: Optional[int] = None,
    nbits: int = 4,
    overwrite: bool = False,
) -> None:
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{experiment}.csv"

    if output_file.exists() and not overwrite:
        raise ValueError(
            f"The {output_file!r} output file already exists - if you really want to override it, then use `--overwrite`."
        )

    logs = defaultdict(list)
    loss_fn = CrossEntropyLoss(reduction="none")
    num_data_elements = 0
    for data_element in itertools.islice(dataset, num_samples):
        encodings = tokenizer(data_element[data_column], return_tensors="pt")
        seq_len = encodings.input_ids.size(1)
        pbar = tqdm(range(0, seq_len - 1))
        num_processed_tokens = 0

        if cache_implementation == "quantized":
            past_key_values = QuantCache(nbits=nbits)
        else:
            past_key_values = None
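
        # The document is fed one token at a time: the logits at step `idx` give the
        # model's distribution over token `idx + 1`, and the KV cache grows each step,
        # so loss, memory, and latency can be tracked as the context gets longer.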
        for idx in pbar:
            start_t = time.time()
            input_ids = encodings.input_ids[:, idx : idx + 1].to(model.device)

            with torch.no_grad():
                outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
                logits = outputs.logits.view(-1, model.config.vocab_size)
                past_key_values = outputs.past_key_values
                label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1)
                neg_log_likelihood = loss_fn(logits, label)
                perplexity = neg_log_likelihood.exp()
            pbar.set_description(f"nll: {neg_log_likelihood.item():>5.2f}, ppl: {perplexity.item():>8.2f}")

            # Store data and save every 10 tokens
            logs["data_idx"].append(num_data_elements + 1)
            logs["input_length"].append(idx + 1)
            logs["nll"].append(neg_log_likelihood.item())
            logs["ppl"].append(perplexity.item())
logs["overall_ppl"].append(torch.tensor(logs["nll"]).mean().exp().item()) | |
logs["cuda_vram_allocated"].append(torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024) # in GB | |
logs["latency"].append(time.time() - start_t) | |
            if num_processed_tokens % 10 == 0:
                try:
                    pd.DataFrame(logs).to_csv(output_file, index=False)
                except KeyboardInterrupt as ex:
                    # If there's a Keyboard Interrupt, still write the file, and then stop
                    pd.DataFrame(logs).to_csv(output_file, index=False)
                    raise ex

            num_processed_tokens += 1
            if num_tokens and num_processed_tokens >= num_tokens:
                break

        num_data_elements += 1

def main_plot():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="./outputs")
    parser.add_argument("--title", type=str, default=None)
    parser.add_argument("--log_perplexity_limit", type=float, default=5.0)
    # Perplexity starts a bit unstable, so we skip the start
    parser.add_argument("--skip_first", type=int, default=100)
    args = parser.parse_args()

    figure = plot(
        output_dir=args.output_dir,
        title=args.title,
        perplexity_limit=args.log_perplexity_limit,
        skip_first=args.skip_first,
    )
    # Add your own code here if you'd like to change the figure
    save_path = f"{args.output_dir}/plot_perplexity.png"
    figure.savefig(save_path, dpi=600)
    print(f"plot saved to {save_path}")

def main():
    parser = argparse.ArgumentParser()

    # How to call this experiment?
    parser.add_argument("--experiment", type=str, default="main")
    parser.add_argument("--cache_implementation", type=str, default="quantized")
    parser.add_argument("--nbits", type=int, default=4)

    # Model args
    # parser.add_argument("--model_name_or_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument("--trust_remote_code", action="store_true")

    # Dataset args
    parser.add_argument("--dataset_name", type=str, default="emozilla/pg19-test")
    parser.add_argument("--data_column", type=str, default="text")
    parser.add_argument("--task", type=str, default=None)
    parser.add_argument("--split", type=str, default="test", choices=["validation", "test"])
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--num_tokens", type=int, default=5000)
    parser.add_argument("--dtype", type=str, default="fp16")

    # Where to log
    parser.add_argument("--output_dir", type=str, default="/home/raushan/perplexity/outputs")
    parser.add_argument("--overwrite", action="store_true")

    args = parser.parse_args()
if args.dtype == "fp16": | |
dtype = torch.float16 | |
elif args.dtype == "fp32": | |
dtype = torch.float32 | |
elif args.dtype == "bf16": | |
dtype = torch.bfloat16 | |
else: | |
raise ValueError(f"Unknown dtype: {args.dtype}") | |
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        revision=args.revision,
        trust_remote_code=bool(args.trust_remote_code),
        attn_implementation="eager",
        torch_dtype=dtype,
        device_map="auto",
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code))

    # Set up the dataset
    dataset = load_dataset(args.dataset_name, args.task, split=args.split, streaming=True)
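    # The dataset is streamed: `itertools.islice` in compute_perplexity fetches only
    # the first `num_samples` documents instead of downloading the full split.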
    compute_perplexity(
        model,
        tokenizer,
        dataset,
        args.experiment,
        args.cache_implementation,
        output_dir=args.output_dir,
        data_column=args.data_column,
        num_samples=args.num_samples,
        num_tokens=args.num_tokens,
        nbits=args.nbits,
        overwrite=args.overwrite,
    )
if __name__ == "__main__": | |
main() |