@quinnhj
Last active January 16, 2024 23:30
Tidepool RAG Prompting Blog Source
import os
from collections import defaultdict
from openai import OpenAI
from pydantic import BaseModel, Field
from typing import Set, List
import json
file_dir_path = os.path.dirname(os.path.realpath(__file__))
##############################################################################
# OpenAI Setup
##############################################################################
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    max_retries=3,
    timeout=30,
)
MODEL_NAME = 'gpt-3.5-turbo-16k'
def call_model(prompt_body):
    # print(prompt_body)
    chat_completion = client.chat.completions.create(
        temperature=0.7,
        messages=[{"role": "user", "content": prompt_body}],
        model=MODEL_NAME,
    )
    return chat_completion.choices[0].message.content
##############################################################################
# Trivia QA Data Loading
##############################################################################
MAX_WEB_RESULTS = 3
# This should be a directory containing the dataset available here: https://nlp.cs.washington.edu/triviaqa/
TRIVIA_QA_DIR = '/path/to/data/triviaqa/'
web_file = os.path.join(TRIVIA_QA_DIR, 'qa/verified-web-dev.json')
web_evidence_dir = os.path.join(TRIVIA_QA_DIR, 'evidence/web/')
class TriviaEntry(BaseModel):
    question: str = Field(default='')
    canonical_answer: str = Field(default='')
    answers: Set[str] = Field(default_factory=set)
    web_filenames: List[str] = Field(default_factory=list)
trivia_entries = defaultdict(TriviaEntry)
with open(web_file, 'r') as f:
    data = json.load(f)

for info in data['Data']:
    if len(info['SearchResults']) == 0:
        continue
    trivia_entry = trivia_entries[info['QuestionId']]
    trivia_entry.question = info['Question']
    for search_result in info['SearchResults'][:MAX_WEB_RESULTS]:
        trivia_entry.web_filenames.append(os.path.join(web_evidence_dir, search_result['Filename']))
    for alias in info['Answer']['Aliases']:
        trivia_entry.answers.add(alias)
    for normalized_alias in info['Answer']['NormalizedAliases']:
        trivia_entry.answers.add(normalized_alias)
    trivia_entry.answers.add(info['Answer']['Value'])
    trivia_entry.canonical_answer = info['Answer']['Value']
    # To be safe, an explicit round of normalizing all answers to lowercase
    trivia_entry.answers = set([x.lower() for x in trivia_entry.answers])
# Use a fixed, randomly generated list of 100 trivia question ids
# For the purposes of this exercise, these have been filtered to ones
# where the supporting documents will fit within a 16k token context.
with open(os.path.join(file_dir_path, 'trivia_qa_test_ids.json'), 'r') as f:
    trivia_ids = json.load(f)
##############################################################################
# Prompts
##############################################################################
def parse_answer(text):
    try:
        return json.loads(text)['answer']
    except Exception:
        return "Didn't Produce Parseable Answer"
def load_and_format_documents(trivia_entry):
    web_content = []
    for fn in trivia_entry.web_filenames:
        with open(fn, 'r') as f:
            web_content.append(f.read())
    return "\n\n\n".join(web_content)
def make_prompt_new(trivia_entry):
    documents = load_and_format_documents(trivia_entry)
    return f"""
Please answer the following question, using the following documents.
Documents:
{documents}
Question:
{trivia_entry.question}
Write your answer in the json form:
{{
"answer": "your answer"
}}
Make sure your answer is just the answer in json form, with no commentary.
Start!"""
def make_prompt_old(trivia_entry):
    documents = load_and_format_documents(trivia_entry)
    return f"""
Please answer the following question, using the following documents.
Question:
{trivia_entry.question}
Documents:
{documents}
Write your answer in the json form:
{{
"answer": "your answer"
}}
Make sure your answer is just the answer in json form, with no commentary.
Start!"""
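# Choose which prompt ordering to evaluate; swap in make_prompt_old below to
# rerun with the question placed before the documents instead of after them.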
make_prompt = make_prompt_new
##############################################################################
# Iterate!
##############################################################################
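# Grading is a loose, case-insensitive substring match in either direction,
# e.g. grade('The Pacific Ocean', {'pacific ocean'}) returns True.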
def grade(candidate, answers):
    lowered_candidate = candidate.lower()
    for answer in answers:
        if lowered_candidate and (lowered_candidate in answer or answer in lowered_candidate):
            return True
    return False
answers = {}
correctness = {}
bad_keys = []
for trivia_id in trivia_ids:
    trivia_entry = trivia_entries[trivia_id]
    prompt = make_prompt(trivia_entry)
    resp = call_model(prompt)
    # print("\n\nModel Resp:\n")
    # print(resp)
    answer = parse_answer(resp)
    answers[trivia_id] = answer
    correctness[trivia_id] = grade(answer, trivia_entry.answers)
##############################################################################
# Print Report
##############################################################################
total = len(correctness)
num_correct = len([x for x in correctness.values() if x])
num_incorrect = len([x for x in correctness.values() if not x])
accuracy = num_correct / total
performance_summary = {
    'total': total,
    'num_correct': num_correct,
    'num_incorrect': num_incorrect,
    'accuracy': f"{accuracy:.2f}",
}
incorrect_summaries = {}
for trivia_id, is_correct in correctness.items():
    if not is_correct:
        incorrect_summaries[trivia_id] = {
            'question': trivia_entries[trivia_id].question,
            'llm_answer': answers[trivia_id],
            'correct_answer': trivia_entries[trivia_id].canonical_answer,
            'correct_answer_aliases': list(trivia_entries[trivia_id].answers),
            'provided_documents': trivia_entries[trivia_id].web_filenames,
        }
print(json.dumps(incorrect_summaries, indent=4))
print(json.dumps(performance_summary, indent=4))
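##############################################################################
# Optional: compare both prompt orderings (not part of the original gist)
##############################################################################
# A minimal sketch, reusing the helpers above, of how the two orderings could
# be scored back to back. The evaluate() helper is illustrative, not something
# from the blog post; it simply repeats the evaluation loop for a given
# prompt-building function.
def evaluate(make_prompt_fn):
    results = {}
    for trivia_id in trivia_ids:
        entry = trivia_entries[trivia_id]
        model_answer = parse_answer(call_model(make_prompt_fn(entry)))
        results[trivia_id] = grade(model_answer, entry.answers)
    return sum(1 for ok in results.values() if ok) / len(results)

# Example usage (each call issues one API request per trivia id):
# print("question-after-documents accuracy:", evaluate(make_prompt_new))
# print("question-before-documents accuracy:", evaluate(make_prompt_old))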
Pinned dependencies used by the script:

openai==1.3.7
pydantic==2.5.2

Contents of trivia_qa_test_ids.json (the fixed list of 100 question ids loaded above):

[
"qf_1346",
"qz_77",
"qb_7836",
"jp_1823",
"sfq_16679",
"sfq_19250",
"qb_4673",
"tc_1516",
"sfq_7904",
"wh_4111",
"jp_2221",
"qw_3959",
"sfq_4360",
"odql_10794",
"bb_2238",
"sfq_10664",
"bb_7027",
"sfq_7702",
"qf_1371",
"sfq_3917",
"tc_280",
"bt_4555",
"dpql_525",
"qg_307",
"qw_2252",
"sfq_18532",
"bt_4546",
"wh_2660",
"sfq_18423",
"sfq_12016",
"sfq_18876",
"qg_2301",
"sfq_11200",
"bb_5721",
"sfq_5452",
"sfq_15080",
"sfq_1299",
"odql_351",
"qb_6583",
"sfq_18100",
"sfq_5366",
"tc_1535",
"odql_2323",
"qf_441",
"qw_2583",
"bb_1257",
"tc_2687",
"qb_9820",
"sfq_24876",
"qw_6866",
"sfq_9324",
"qz_2397",
"tc_1156",
"odql_9441",
"sfq_21166",
"sfq_11761",
"qb_7942",
"sfq_346",
"jp_478",
"sfq_23282",
"wh_1933",
"dpql_5441",
"sfq_8549",
"tc_261",
"tc_69",
"qz_4056",
"qw_1116",
"qz_2444",
"qw_2865",
"qb_4444",
"qb_6048",
"odql_7686",
"wh_2442",
"sfq_1979",
"wh_2629",
"qb_7858",
"jp_419",
"bt_1627",
"qb_9094",
"qb_2764",
"odql_6739",
"qg_4089",
"wh_2003",
"bb_7920",
"odql_3849",
"bb_4336",
"sfq_1110",
"odql_424",
"dpql_5884",
"qb_5532",
"bb_200",
"qg_3158",
"qz_5083",
"sfq_23994",
"qw_459",
"sfq_18420",
"sfq_7997",
"qg_3452",
"qb_4849",
"qz_1794"
]