@csarron · Created April 5, 2023
Zero-shot and few-shot closed/open-book QA evaluation for causal LMs, with optional LoRA adapters.
import datetime
import json
import random
import re
import string
import time
import unicodedata

import fire
import torch
from loguru import logger
from peft import PeftModelForCausalLM
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

log = logger.info
def remove_accents(input_str):
    # Decompose characters (NFKD) and drop combining marks, e.g. "é" -> "e".
    nfkd_form = unicodedata.normalize("NFKD", input_str)
    return "".join(c for c in nfkd_form if not unicodedata.combining(c))
def normalize_answer(text: str) -> str:
    # SQuAD-style normalization: strip accents, lowercase, remove punctuation
    # and articles, and collapse whitespace.
    text = remove_accents(text)
    text = text.lower()
    text = "".join(c for c in text if c not in frozenset(string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    text = " ".join(text.split())
    return text
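# A quick sanity check of the normalization above (expected values assume the
# implementation as written):
#   normalize_answer("The Éiffel Tower!")  -> "eiffel tower"
#   normalize_answer("an Apple, Inc.")     -> "apple inc"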
def generate(tokenizer, prompt, model, max_new_tokens=10, temperature=0.8, top_p=0.95):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # do_sample=True is needed for temperature/top_p to take effect; otherwise
    # generate() falls back to greedy decoding and ignores both.
    outputs = model.generate(input_ids=input_ids, do_sample=True, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Drop the echoed prompt, keeping only the generated continuation.
    return decoded[len(prompt):]
def setup_model(model_path, tokenizer_path, lora_path=None):
    log("loading model...")
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16)
    log("loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    added_tokens = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"})
    if added_tokens > 0:
        # Grow the embedding matrix to cover the newly added special tokens.
        model.resize_token_embeddings(len(tokenizer))
    if lora_path is not None:
        log("loading lora model...")
        model = PeftModelForCausalLM.from_pretrained(model, lora_path, device_map="auto", torch_dtype=torch.float16)
        model.to(dtype=torch.float16)
    log(f"Mem needed: {model.get_memory_footprint() / 1024 / 1024 / 1024:.2f} GB")
    return model, tokenizer
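# A minimal usage sketch (paths below are hypothetical placeholders; the
# special tokens added above assume a LLaMA-style tokenizer):
#   model, tokenizer = setup_model(
#       "path/to/llama-7b",               # base model checkpoint
#       "path/to/tokenizer",              # tokenizer directory
#       lora_path="path/to/lora-adapter", # optional PEFT adapter
#   )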
def extract_answer(text):
    # TODO: still needs to properly extract answers
    is_list_item = False
    if text.startswith("1."):
        is_list_item = True
        text = text.replace("1. ", "")
    # Truncate at the first newline, period, or comma.
    end_idx = len(text)
    for char in ["\n", ".", ","]:
        idx = text.find(char)
        if idx != -1:
            end_idx = min(end_idx, idx)
    answer = text[:end_idx]
    # If the model started a numbered list, drop a trailing "2" that can leak
    # in from the start of the next list item.
    if is_list_item and answer.endswith("2"):
        answer = answer[:-1].strip()
    return answer
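# Intended behavior, per the truncation and list handling above (this helper
# is currently unused by the evaluation loops below):
#   extract_answer("Paris, France")  -> "Paris"
#   extract_answer("1. Paris 2")     -> "Paris"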
def zero_shot_close_qa(dataset_file, model_path, tokenizer_path, lora_path=None, max_new_tokens=30, temperature=0.8, top_p=0.95):
    model, tokenizer = setup_model(model_path, tokenizer_path, lora_path)
    log(f"loading data from {dataset_file}...")
    qa_data = [json.loads(x) for x in open(dataset_file)]
    start_time = time.time()
    correct_count = 0
    p_bar = tqdm(qa_data)
    for qa_item in p_bar:
        question = qa_item["question"]
        answers = qa_item["answers"]
        prompt = f"Answer these questions: \nQ: {question}\nA: "
        pred_text = generate(tokenizer, prompt, model, max_new_tokens, temperature, top_p)
        # Stricter alternative: exact match on the extracted answer span.
        # pred_ans = extract_answer(pred_text)
        # is_correct = normalize_answer(pred_ans) in frozenset(normalize_answer(ans) for ans in answers)
        # Count as correct if any normalized gold answer is a substring of the
        # normalized prediction.
        pred_ans = normalize_answer(pred_text)
        is_correct = any(normalize_answer(ans) in pred_ans for ans in answers)
        correct_count += int(is_correct)
        p_bar.set_description(f"q={question}, {answers=}, correct={is_correct}, num_correct={correct_count}, {pred_text=}".replace("\n", ""))
    duration = time.time() - start_time
    duration_str = datetime.timedelta(seconds=duration)
    acc = correct_count / len(qa_data) * 100
    log(f"processed {len(qa_data)} examples, all done in {duration_str}, {acc=:.2f}!")
def zero_shot_open_qa(dataset_file, model_path, tokenizer_path, lora_path=None, top_k=5, max_new_tokens=30, temperature=0.8, top_p=0.95):
    model, tokenizer = setup_model(model_path, tokenizer_path, lora_path)
    log(f"loading data from {dataset_file}...")
    qa_data = [json.loads(x) for x in open(dataset_file)]
    start_time = time.time()
    correct_count = 0
    p_bar = tqdm(qa_data)
    for qa_item in p_bar:
        question = qa_item["question"]
        answers = qa_item["answers"]
        # Prepend the top_k retrieved passages to the prompt.
        contexts = qa_item["ctxs"][:top_k]
        passages = [c["text"] for c in contexts]
        psg_text = "\n".join(passages)
        prompt = f"Given the following passages: \n{psg_text}\nAnswer the question: {question}\nThe answer is "
        pred_text = generate(tokenizer, prompt, model, max_new_tokens, temperature, top_p)
        pred_ans = normalize_answer(pred_text)
        is_correct = any(normalize_answer(ans) in pred_ans for ans in answers)
        correct_count += int(is_correct)
        p_bar.set_description(f"q={question}, {answers=}, correct={is_correct}, num_correct={correct_count}, {pred_text=}".replace("\n", ""))
    duration = time.time() - start_time
    duration_str = datetime.timedelta(seconds=duration)
    acc = correct_count / len(qa_data) * 100
    log(f"processed {len(qa_data)} examples, all done in {duration_str}, {acc=:.2f}!")
def few_shot_close_qa(dataset_file, train_file, model_path, tokenizer_path, lora_path=None, shot=5, seed=0, max_new_tokens=30, temperature=0.8, top_p=0.95):
    model, tokenizer = setup_model(model_path, tokenizer_path, lora_path)
    log(f"loading data from {dataset_file}...")
    qa_data = [json.loads(x) for x in open(dataset_file)]
    # Sample `shot` in-context examples from the training data.
    train_data = [json.loads(x) for x in open(train_file)]
    random.seed(seed)
    sample_train = random.sample(train_data, shot)
    sample_text = "\n".join(f'Q: {x["question"]}\nA: {x["answers"][0]}' for x in sample_train)
    log(f"{shot}-shot examples: {sample_text}")
    start_time = time.time()
    correct_count = 0
    p_bar = tqdm(qa_data)
    for qa_item in p_bar:
        question = qa_item["question"]
        answers = qa_item["answers"]
        prompt = f"Answer these questions: \n{sample_text}\nQ: {question}\nA: "
        pred_text = generate(tokenizer, prompt, model, max_new_tokens, temperature, top_p)
        pred_ans = normalize_answer(pred_text)
        is_correct = any(normalize_answer(ans) in pred_ans for ans in answers)
        correct_count += int(is_correct)
        p_bar.set_description(f"q={question}, {answers=}, correct={is_correct}, num_correct={correct_count}, {pred_text=}".replace("\n", ""))
    duration = time.time() - start_time
    duration_str = datetime.timedelta(seconds=duration)
    acc = correct_count / len(qa_data) * 100
    log(f"processed {len(qa_data)} examples, all done in {duration_str}, {acc=:.2f}!")
def few_shot_open_qa(dataset_file, train_file, model_path, tokenizer_path, lora_path=None, top_k=5, shot=5, seed=0, max_new_tokens=30, temperature=0.8, top_p=0.95):
    model, tokenizer = setup_model(model_path, tokenizer_path, lora_path)
    log(f"loading data from {dataset_file}...")
    qa_data = [json.loads(x) for x in open(dataset_file)]
    # Sample `shot` in-context examples from the training data, each carrying
    # its own top_k retrieved passages.
    train_data = [json.loads(x) for x in open(train_file)]
    random.seed(seed)
    sample_train = random.sample(train_data, shot)
    sample_texts = []
    for item in sample_train:
        psg_text = "\n".join(c["text"] for c in item["ctxs"][:top_k])
        sample_texts.append(f'{psg_text}\nQ: {item["question"]}\nA: {item["answers"][0]}')
    sample_text = "\n".join(sample_texts)
    log(f"{shot}-shot examples: {sample_text}")
    start_time = time.time()
    correct_count = 0
    p_bar = tqdm(qa_data)
    for qa_item in p_bar:
        question = qa_item["question"]
        answers = qa_item["answers"]
        prompt = f"\n{sample_text}\nQ: {question}\nA: "
        pred_text = generate(tokenizer, prompt, model, max_new_tokens, temperature, top_p)
        pred_ans = normalize_answer(pred_text)
        is_correct = any(normalize_answer(ans) in pred_ans for ans in answers)
        correct_count += int(is_correct)
        p_bar.set_description(f"q={question}, {answers=}, correct={is_correct}, num_correct={correct_count}, {pred_text=}".replace("\n", ""))
    duration = time.time() - start_time
    duration_str = datetime.timedelta(seconds=duration)
    acc = correct_count / len(qa_data) * 100
    log(f"processed {len(qa_data)} examples, all done in {duration_str}, {acc=:.2f}!")
if __name__ == "__main__":
    fire.Fire()
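# Example invocations via python-fire (the script name and data paths are
# hypothetical placeholders):
#   python eval_qa.py zero_shot_close_qa --dataset_file=nq-test.jsonl \
#       --model_path=path/to/model --tokenizer_path=path/to/tokenizer
#   python eval_qa.py few_shot_open_qa --dataset_file=nq-test.jsonl \
#       --train_file=nq-train.jsonl --model_path=path/to/model \
#       --tokenizer_path=path/to/tokenizer --lora_path=path/to/lora --shot=5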