"""Test MMLU
Commands:
1. python -m huggingface_test_gemma_base_mmlu --model_name="google/gemma-2-9b"
--> all 0.7057399230878793
2. python -m huggingface_test_gemma_base_mmlu --model_name="google/gemma-2-9b-it"
--> all 0.6387266771115225
3. python -m huggingface_test_gemma_base_mmlu --model_name="google/gemma-2-27b-it"
--> all 0.7518159806295399
4. python -m huggingface_test_gemma_base_mmlu --model_name="google/gemma-2-27b"
--> all 0.7517447657028913
"""
from collections import defaultdict
import os
import tarfile
import pandas as pd
import requests
import torch
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_original_dataloader(split: str = "test", num_few_shot: int = 5):
    """Download the MMLU data if needed and yield one few-shot prompt per test example."""
    extracted_path = "mmlu"
    if not os.path.exists(os.path.join(extracted_path, "data")):
        url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
        response = requests.get(url, stream=True)
        tar_file_path = url.split("/")[-1]
        with open(tar_file_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    file.write(chunk)
        with tarfile.open(tar_file_path, "r") as tar:
            tar.extractall(path=extracted_path)
        os.remove(tar_file_path)

    # First get all the subjects.
    subjects = sorted(
        [f.split("_test.csv")[0] for f in os.listdir(os.path.join(extracted_path, "data", "test")) if "_test.csv" in f]
    )
    for s in subjects:
        print(f"mmlu-fewshot5/chatformat-gemma/model-gemma-2-9b/{s}_test")

    for num_subject, subject in enumerate(subjects):
        tqdm.tqdm.write(f"Processing {subject}... {num_subject + 1}/{len(subjects)}")
        dev_file_path = os.path.join(extracted_path, "data", "dev", f"{subject}_dev.csv")
        dev_df = pd.read_csv(dev_file_path, header=None)
        test_file_path = os.path.join(extracted_path, "data", split, f"{subject}_{split}.csv")
        test_df = pd.read_csv(test_file_path, header=None)
        for i in tqdm.tqdm(range(test_df.shape[0])):
            curr = num_few_shot
            while True:
                if curr < 0:
                    # Don't do it if num_few_shot is less than 0. This shouldn't happen, though...
                    raise ValueError("num_few_shot dropped below 0")
                prompt_end = format_example(test_df, i, include_answer=False)
                subject_prompt = gen_prompt(dev_df, subject, num_few_shot=curr)
                full = subject_prompt + prompt_end
                yield {"subject": subject, "text_input": full, "answer": test_df.iloc[i, test_df.shape[1] - 1]}
                break


def format_subject(subject):
    return " ".join(subject.split("_"))


def format_example(df, idx, include_answer=True):
    CHOICES = ["A", "B", "C", "D"]
    prompt = df.iloc[idx, 0]
    k = df.shape[1] - 2
    for j in range(k):
        prompt += "\n{}. {}".format(CHOICES[j], df.iloc[idx, j + 1])
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(df.iloc[idx, k + 1])
    return prompt


def gen_prompt(train_df, subject, num_few_shot: int = 0):
    prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(format_subject(subject))
    for i in range(num_few_shot):
        prompt += format_example(train_df, i)
    return prompt


def run(model_name, split: str = "test", num_few_shot: int = 5):
    data_loader = get_original_dataloader(split, num_few_shot)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Token ids for the answer letters, with and without a leading space.
    tokens_of_interest = ["A", "B", "C", "D", " A", " B", " C", " D"]
    tokens_of_interest_ids = [ids[0] for ids in tokenizer(tokens_of_interest, add_special_tokens=False)["input_ids"]]
    correct = defaultdict(int)
    total = defaultdict(int)
    with torch.inference_mode():
        for num_ex, example in enumerate(data_loader):
            text = example["text_input"]
            answer = example["answer"]
            subject = example["subject"]
            if isinstance(text, list):
                text = text[0]
                answer = answer[0]
                subject = subject[0]
            if model_name.endswith("-it"):
                messages = [
                    {"role": "user", "content": text},
                ]
                text_tokenized = tokenizer.apply_chat_template(
                    messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
                )
                # text = f"<start_of_turn>user\n{text}<end_of_turn>\n<start_of_turn>model\n"
            else:
                text_tokenized = tokenizer(text, return_tensors="pt", add_special_tokens=True)
            text_tokenized = text_tokenized.to(model.device)
            outputs = model(**text_tokenized, return_dict=True)
            logits = outputs.logits
            # Score only the next-token distribution at the last position, restricted to the choice letters.
            logits = logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1).squeeze()
            token_probs = probs[tokens_of_interest_ids]
            argmax_value = torch.argmax(token_probs).item()
            # Fold " A".." D" back onto "A".."D".
            argmax_value = argmax_value % 4
            argmax_answer = tokens_of_interest[argmax_value]
            correct[subject] += int(argmax_answer == answer)
            total[subject] += 1
            correct["all"] += int(argmax_answer == answer)
            total["all"] += 1
            if num_ex % 250 == 0 and num_ex > 0:
                tqdm.tqdm.write(f"Step {num_ex}")
                for subject in sorted(correct.keys()):
                    tqdm.tqdm.write(f"{subject}: {correct[subject] / total[subject]}")
    for subject in sorted(correct.keys()):
        print(subject, correct[subject] / total[subject])


if __name__ == "__main__":
    import fire
    import dotenv

    dotenv.load_dotenv()
    fire.Fire(run)