Skip to content

Instantly share code, notes, and snippets.

@alfredplpl
Last active April 2, 2024 01:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alfredplpl/4f195efa3dfd5475b786f2ab94377232 to your computer and use it in GitHub Desktop.
Save alfredplpl/4f195efa3dfd5475b786f2ab94377232 to your computer and use it in GitHub Desktop.
CALM2をJMMLUで評価してみたい人用のスクリプト
import torch
from transformers import AutoTokenizer,AutoModelForCausalLM
import pandas
# Evaluate CALM2 on one JMMLU subject: ask the model each multiple-choice
# question, let it generate a single token, and count how often that token is
# the gold answer letter.

model_name_or_path = "cyberagent/calm2-7b-chat"
# https://github.com/nlp-waseda/JMMLU/blob/main/JMMLU/college_computer_science.csv
CSV_PATH = "college_computer_science.csv"
CHOICES = "ABCD"


def build_prompt(row):
    """Return the chat prompt for one question row.

    ``row`` is indexed 0-4: the question text followed by the texts of
    choices A, B, C and D.
    """
    # The Japanese instruction reads: "Answer with only one of A, B, C, D
    # from the following choices."
    return f"""USER:{row[0]} 次の選択肢の中からA,B,C,Dのいずれかだけ答えなさい。
A. {row[1]}
B. {row[2]}
C. {row[3]}
D. {row[4]}
ASSISTANT:"""


def extract_choice(completion):
    """Return the answer letter the completion starts with, or ``None``.

    Leading whitespace is ignored, so ``" B."`` yields ``"B"``.  Anything
    that does not start with one of A/B/C/D (including an empty completion)
    yields ``None``.
    """
    stripped = completion.strip()
    if stripped and stripped[0] in CHOICES:
        return stripped[0]
    return None


def accuracy(correct, total):
    """Return ``correct / total``, or 0.0 when nothing was evaluated."""
    return correct / total if total else 0.0


def evaluate(rows, answer_fn):
    """Score ``answer_fn`` on every row and return ``(correct, total)``.

    ``rows`` is an iterable of rows indexable 0-5 (question, four choices,
    gold letter).  ``answer_fn`` maps a prompt string to the text the model
    generated for it (without the prompt).  Each prompt, completion and gold
    answer is printed as it is processed.
    """
    correct = 0
    total = 0
    for row in rows:
        prompt = build_prompt(row)
        completion = answer_fn(prompt)
        print(prompt + completion, f"ANSWER:{row[5]}")
        # Compare on the extracted letter, not on the last character of the
        # whole decoded sequence, so punctuation/whitespace/special tokens in
        # the generated token cannot hide a correct answer.
        if extract_choice(completion) == str(row[5]).strip():
            correct += 1
        total += 1
    return correct, total


def main():
    """Load the model and the CSV, run the evaluation, print the accuracy."""
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, device_map="cpu", torch_dtype=torch.float32
    )
    df = pandas.read_csv(CSV_PATH, header=None)

    def answer_fn(prompt):
        inputs = tokenizer(prompt, return_tensors="pt")
        # NOTE(review): do_sample=True means the score can vary between runs;
        # the generation parameters are kept exactly as in the original script.
        outputs = model.generate(
            **inputs,
            max_new_tokens=1,
            do_sample=True,
            top_p=0.95,
            temperature=0.2,
            repetition_penalty=1.1,
        )
        # generate() returns prompt ids followed by the new ids; keep only
        # the newly generated part.
        prompt_len = inputs["input_ids"].shape[1]
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    correct, total = evaluate((row for _, row in df.iterrows()), answer_fn)
    print("rate:", accuracy(correct, total))


if __name__ == "__main__":
    main()
# rate: 0.23232323232323232  (recorded before the answer-extraction fix; re-run to refresh)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment