Tidepool RAG Prompting Blog Source
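Source for the Tidepool blog post on RAG prompting. The script below loads the verified web split of TriviaQA, builds a retrieval-augmented prompt for each of 100 fixed questions (with the supporting web documents placed either before or after the question), asks gpt-3.5-turbo-16k to answer in JSON form, and prints an accuracy summary plus details for every missed question.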
import json
import os
from collections import defaultdict
from typing import List, Set

from openai import OpenAI
from pydantic import BaseModel, Field

file_dir_path = os.path.dirname(os.path.realpath(__file__))
##############################################################################
# OpenAI Setup
##############################################################################

client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    max_retries=3,
    timeout=30,
)

MODEL_NAME = 'gpt-3.5-turbo-16k'


def call_model(prompt_body):
    # print(prompt_body)
    chat_completion = client.chat.completions.create(
        temperature=0.7,
        messages=[{"role": "user", "content": prompt_body}],
        model=MODEL_NAME,
    )
    return chat_completion.choices[0].message.content
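
# call_model sends a single user message at temperature 0.7 and returns the raw
# completion text. A quick sanity check (hypothetical input, not from the dataset):
#     print(call_model("Reply with the single word: pong"))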

##############################################################################
# Trivia QA Data Loading
##############################################################################

MAX_WEB_RESULTS = 3

# This should be a directory containing the dataset available here: https://nlp.cs.washington.edu/triviaqa/
TRIVIA_QA_DIR = '/path/to/data/triviaqa/'
web_file = os.path.join(TRIVIA_QA_DIR, 'qa/verified-web-dev.json')
web_evidence_dir = os.path.join(TRIVIA_QA_DIR, 'evidence/web/')


class TriviaEntry(BaseModel):
    question: str = Field(default='')
    canonical_answer: str = Field(default='')
    answers: Set[str] = Field(default_factory=set)
    web_filenames: List[str] = Field(default_factory=list)


trivia_entries = defaultdict(TriviaEntry)
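# defaultdict(TriviaEntry) lazily creates an empty TriviaEntry the first time a
# QuestionId key is touched, so the loading loop below can fill in fields without
# first checking whether the entry exists.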
with open(web_file, 'r') as f:
    data = json.load(f)

for info in data['Data']:
    if len(info['SearchResults']) == 0:
        continue
    trivia_entry = trivia_entries[info['QuestionId']]
    trivia_entry.question = info['Question']
    for search_result in info['SearchResults'][:MAX_WEB_RESULTS]:
        trivia_entry.web_filenames.append(
            os.path.join(web_evidence_dir, search_result['Filename']))
    for alias in info['Answer']['Aliases']:
        trivia_entry.answers.add(alias)
    for normalized_alias in info['Answer']['NormalizedAliases']:
        trivia_entry.answers.add(normalized_alias)
    trivia_entry.answers.add(info['Answer']['Value'])
    trivia_entry.canonical_answer = info['Answer']['Value']
    # To be safe, an explicit round of normalizing all answers to lowercase
    trivia_entry.answers = set([x.lower() for x in trivia_entry.answers])

# Use a fixed, randomly generated list of 100 trivia question ids.
# For the purposes of this exercise, these have been filtered to ones
# where the supporting documents will fit within a 16k token context.
with open(os.path.join(file_dir_path, 'trivia_qa_test_ids.json'), 'r') as f:
    trivia_ids = json.load(f)
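# trivia_qa_test_ids.json ships alongside this script in the gist; it is a flat
# JSON list of TriviaQA question ids, e.g. ["qf_1346", "qz_77", ...].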

##############################################################################
# Prompts
##############################################################################

def parse_answer(text):
    try:
        return json.loads(text)['answer']
    except Exception:
        return "Didn't Produce Parseable Answer"


def load_and_format_documents(trivia_entry):
    web_content = []
    for fn in trivia_entry.web_filenames:
        with open(fn, 'r') as f:
            web_content.append(f.read())
    return "\n\n\n".join(web_content)

def make_prompt_new(trivia_entry):
    documents = load_and_format_documents(trivia_entry)
    return f"""
Please answer the following question, using the following documents.

Documents:
{documents}

Question:
{trivia_entry.question}

Write your answer in the json form:
{{
    "answer": "your answer"
}}

Make sure your answer is just the answer in json form, with no commentary.

Start!"""

def make_prompt_old(trivia_entry):
    documents = load_and_format_documents(trivia_entry)
    return f"""
Please answer the following question, using the following documents.

Question:
{trivia_entry.question}

Documents:
{documents}

Write your answer in the json form:
{{
    "answer": "your answer"
}}

Make sure your answer is just the answer in json form, with no commentary.

Start!"""


make_prompt = make_prompt_new
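# The two prompt builders differ only in ordering: make_prompt_new places the
# documents before the question, make_prompt_old places the question before the
# documents. Point make_prompt at make_prompt_old above to run the comparison.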

##############################################################################
# Iterate!
##############################################################################

def grade(candidate, answers):
    lowered_candidate = candidate.lower()
    for answer in answers:
        if lowered_candidate and (lowered_candidate in answer or answer in lowered_candidate):
            return True
    return False
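
# grade() is a lenient, case-insensitive substring match in either direction
# against the known aliases. Hypothetical example: a candidate of
# "the eiffel tower" matches the alias "eiffel tower" (and vice versa).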

answers = {}
correctness = {}
bad_keys = []
for trivia_id in trivia_ids:
    trivia_entry = trivia_entries[trivia_id]
    prompt = make_prompt(trivia_entry)
    resp = call_model(prompt)
    # print("\n\nModel Resp:\n")
    # print(resp)
    answer = parse_answer(resp)
    answers[trivia_id] = answer
    correctness[trivia_id] = grade(answer, trivia_entry.answers)

##############################################################################
# Print Report
##############################################################################

total = len(correctness)
num_correct = len([x for x in correctness.values() if x])
num_incorrect = len([x for x in correctness.values() if not x])
accuracy = num_correct / total

performance_summary = {
    'total': total,
    'num_correct': num_correct,
    'num_incorrect': num_incorrect,
    'accuracy': f"{accuracy:.2f}",
}

incorrect_summaries = {}
for trivia_id, is_correct in correctness.items():
    if not is_correct:
        incorrect_summaries[trivia_id] = {
            'question': trivia_entries[trivia_id].question,
            'llm_answer': answers[trivia_id],
            'correct_answer': trivia_entries[trivia_id].canonical_answer,
            'correct_answer_aliases': list(trivia_entries[trivia_id].answers),
            'provided_documents': trivia_entries[trivia_id].web_filenames,
        }

print(json.dumps(incorrect_summaries, indent=4))
print(json.dumps(performance_summary, indent=4))
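
# The final summary printed above has this shape (illustrative numbers only,
# not real results):
#     {
#         "total": 100,
#         "num_correct": 85,
#         "num_incorrect": 15,
#         "accuracy": "0.85"
#     }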
Pinned dependencies:

openai==1.3.7
pydantic==2.5.2
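These versions match the gist; installing them (for example, pip install openai==1.3.7 pydantic==2.5.2), setting the OPENAI_API_KEY environment variable, and filling in TRIVIA_QA_DIR should be enough to run the script.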
Contents of trivia_qa_test_ids.json, the fixed list of 100 question ids loaded by the script:
[
    "qf_1346", "qz_77", "qb_7836", "jp_1823", "sfq_16679", "sfq_19250", "qb_4673", "tc_1516", "sfq_7904", "wh_4111",
    "jp_2221", "qw_3959", "sfq_4360", "odql_10794", "bb_2238", "sfq_10664", "bb_7027", "sfq_7702", "qf_1371", "sfq_3917",
    "tc_280", "bt_4555", "dpql_525", "qg_307", "qw_2252", "sfq_18532", "bt_4546", "wh_2660", "sfq_18423", "sfq_12016",
    "sfq_18876", "qg_2301", "sfq_11200", "bb_5721", "sfq_5452", "sfq_15080", "sfq_1299", "odql_351", "qb_6583", "sfq_18100",
    "sfq_5366", "tc_1535", "odql_2323", "qf_441", "qw_2583", "bb_1257", "tc_2687", "qb_9820", "sfq_24876", "qw_6866",
    "sfq_9324", "qz_2397", "tc_1156", "odql_9441", "sfq_21166", "sfq_11761", "qb_7942", "sfq_346", "jp_478", "sfq_23282",
    "wh_1933", "dpql_5441", "sfq_8549", "tc_261", "tc_69", "qz_4056", "qw_1116", "qz_2444", "qw_2865", "qb_4444",
    "qb_6048", "odql_7686", "wh_2442", "sfq_1979", "wh_2629", "qb_7858", "jp_419", "bt_1627", "qb_9094", "qb_2764",
    "odql_6739", "qg_4089", "wh_2003", "bb_7920", "odql_3849", "bb_4336", "sfq_1110", "odql_424", "dpql_5884", "qb_5532",
    "bb_200", "qg_3158", "qz_5083", "sfq_23994", "qw_459", "sfq_18420", "sfq_7997", "qg_3452", "qb_4849", "qz_1794"
]