"""Test MMLU | |
Command: | |
1. python -m huggingface_test_gemma_base_mmlu --model_name="google/gemma-2-9b" | |
--> all 0.7057399230878793 | |
2. python -m huggingface_test_gemma_base_mmlu --model_name="google/gemma-2-9b-it" | |
--> all 0.6387266771115225 | |
3. python -m huggingface_test_gemma_base_mmlu --model_name="google/gemma-2-27b-it" | |
--> all 0.7518159806295399 | |
4. python -m huggingface_test_gemma_base_mmlu --model_name="google/gemma-2-27b" | |
--> all 0.7517447657028913 | |
""" | |
from collections import defaultdict
import os
import tarfile

import pandas as pd
import requests
import torch
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
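
# Assumed environment (not pinned in the gist): a CUDA GPU with enough memory
# for the chosen Gemma-2 checkpoint in bfloat16, plus roughly
#   pip install torch transformers pandas requests tqdm fire python-dotenv
# load_dotenv() at the bottom is presumably there so the Hugging Face access
# token for the gated Gemma weights can be kept in a local .env file.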


def get_original_dataloader(split: str = "test", num_few_shot: int = 5):
    """Yield MMLU examples with few-shot prompts built from the dev split."""
    extracted_path = "mmlu"
    # Download and extract the original MMLU data if it is not already on disk.
    if not os.path.exists(os.path.join(extracted_path, "data")):
        url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
        response = requests.get(url, stream=True)
        tar_file_path = url.split("/")[-1]
        with open(tar_file_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    file.write(chunk)
        with tarfile.open(tar_file_path, "r") as tar:
            tar.extractall(path=extracted_path)
        os.remove(tar_file_path)

    # First get all the subjects.
    subjects = sorted(
        [f.split("_test.csv")[0] for f in os.listdir(os.path.join(extracted_path, "data", "test")) if "_test.csv" in f]
    )
    # Debug listing of the per-subject runs.
    for s in subjects:
        print(f"mmlu-fewshot5/chatformat-gemma/model-gemma-2-9b/{s}_test")

    for num_subject, subject in enumerate(subjects):
        tqdm.tqdm.write(f"Processing {subject}... {num_subject + 1}/{len(subjects)}")
        dev_file_path = os.path.join(extracted_path, "data", "dev", f"{subject}_dev.csv")
        dev_df = pd.read_csv(dev_file_path, header=None)
        test_file_path = os.path.join(extracted_path, "data", split, f"{subject}_{split}.csv")
        test_df = pd.read_csv(test_file_path, header=None)
        for i in tqdm.tqdm(range(test_df.shape[0])):
            curr = num_few_shot
            while True:
                if curr < 0:
                    # Don't build a prompt with a negative few-shot count. This shouldn't happen.
                    raise ValueError(f"num_few_shot dropped below zero for {subject} example {i}.")
                prompt_end = format_example(test_df, i, include_answer=False)
                subject_prompt = gen_prompt(dev_df, subject, num_few_shot=curr)
                full = subject_prompt + prompt_end
                yield {"subject": subject, "text_input": full, "answer": test_df.iloc[i, test_df.shape[1] - 1]}
                break


def format_subject(subject):
    return " ".join(subject.split("_"))


def format_example(df, idx, include_answer=True):
    CHOICES = ["A", "B", "C", "D"]
    prompt = df.iloc[idx, 0]
    k = df.shape[1] - 2
    for j in range(k):
        prompt += "\n{}. {}".format(CHOICES[j], df.iloc[idx, j + 1])
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(df.iloc[idx, k + 1])
    return prompt


def gen_prompt(train_df, subject, num_few_shot: int = 0):
    prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(format_subject(subject))
    for i in range(num_few_shot):
        prompt += format_example(train_df, i)
    return prompt
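
# For illustration, a (shortened, hypothetical) prompt assembled by
# gen_prompt(dev_df, subject, 5) + format_example(test_df, i, include_answer=False)
# looks like:
#
#   The following are multiple choice questions (with answers) about abstract algebra.
#
#   <dev question 1>
#   A. <choice>
#   B. <choice>
#   C. <choice>
#   D. <choice>
#   Answer: B
#
#   ... four more dev examples ...
#
#   <test question>
#   A. <choice>
#   B. <choice>
#   C. <choice>
#   D. <choice>
#   Answer:
#
# The model is then scored on whichever of "A"/"B"/"C"/"D" it assigns the highest
# probability immediately after the trailing "Answer:".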


def run(model_name, split: str = "test", num_few_shot: int = 5):
    data_loader = get_original_dataloader(split, num_few_shot)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Candidate answer tokens, both without and with a leading space; the argmax
    # index modulo 4 maps each variant back onto A/B/C/D.
    tokens_of_interest = ["A", "B", "C", "D", " A", " B", " C", " D"]
    tokens_of_interest_ids = [ids[0] for ids in tokenizer(tokens_of_interest, add_special_tokens=False)["input_ids"]]
    correct = defaultdict(int)
    total = defaultdict(int)
    with torch.inference_mode():
        for num_ex, example in enumerate(data_loader):
            text = example["text_input"]
            answer = example["answer"]
            subject = example["subject"]
            if isinstance(text, list):
                text = text[0]
                answer = answer[0]
                subject = subject[0]
            if model_name.endswith("-it"):
                # Instruction-tuned checkpoints get the chat template.
                messages = [
                    {"role": "user", "content": text},
                ]
                text_tokenized = tokenizer.apply_chat_template(
                    messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
                )
                # text = f"<start_of_turn>user\n{text}<end_of_turn>\n<start_of_turn>model\n"
            else:
                text_tokenized = tokenizer(text, return_tensors="pt", add_special_tokens=True)
            text_tokenized = text_tokenized.to(model.device)
            outputs = model(**text_tokenized, return_dict=True)
            # Score with the logits at the final position only.
            logits = outputs.logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1).squeeze()
            token_probs = probs[tokens_of_interest_ids]
            argmax_value = torch.argmax(token_probs).item() % 4
            argmax_answer = tokens_of_interest[argmax_value]
            correct[subject] += int(argmax_answer == answer)
            total[subject] += 1
            correct["all"] += int(argmax_answer == answer)
            total["all"] += 1
            if num_ex % 250 == 0 and num_ex > 0:
                tqdm.tqdm.write(f"Step {num_ex}")
                for key in sorted(correct.keys()):
                    tqdm.tqdm.write(f"{key}: {correct[key] / total[key]}")
    for key in sorted(correct.keys()):
        print(key, correct[key] / total[key])
if __name__ == "__main__": | |
import fire | |
import dotenv | |
dotenv.load_dotenv() | |
fire.Fire(run) |