Created
February 28, 2023 08:57
-
-
Save xkisu/dc12a2030b0886355fc42fbbd27be1ae to your computer and use it in GitHub Desktop.
Query a PDF's outline for information with GPT-3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from pypdf import PdfReader | |
import os.path | |
import camelot | |
import pandas as pd | |
import openai | |
import numpy as np | |
import pickle | |
import tiktoken | |
from transformers import GPT2TokenizerFast | |
logging.basicConfig(level=logging.INFO)  # surface the script's logging.info progress messages

# GPT-2 tokenizer from Hugging Face transformers.
# NOTE(review): nothing outside comments in this file calls it, but from_pretrained
# may download the "gpt2" tokenizer files the first time the module is loaded.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
def process_outline(outline, dfs=None, pdf_reader=None):
    """
    Flatten a PDF outline (bookmark tree) into one-row DataFrames.

    ``outline`` is either a single outline entry or a list of entries; lists are
    walked recursively, so nested sub-outlines are flattened in order. Each entry
    becomes a one-row DataFrame with the columns:

    - ``title``: the entry's ``/Title``
    - ``content``: also the title (the outline text is what gets embedded)
    - ``tokens``: token count of ``content`` under the module-level tiktoken
      ``encoding`` -- the same encoding ``separator_len`` is measured with, so the
      prompt budget in ``construct_prompt`` compares tokens with tokens
    - ``page``: the page index the entry points at, as returned by
      ``get_destination_page_number``

    :param outline: an outline entry, or a (possibly nested) list of them.
    :param dfs: list to append the rows to; a new list is created when omitted.
    :param pdf_reader: the ``PdfReader`` the outline belongs to; defaults to the
        module-level ``reader``.
    :return: ``dfs`` with one DataFrame appended per outline entry.
    """
    # Create a fresh list per call: a ``[]`` default is shared between calls and
    # would keep accumulating rows from earlier invocations.
    if dfs is None:
        dfs = []
    if pdf_reader is None:
        pdf_reader = reader

    if isinstance(outline, list):
        # A nested list is a sub-outline; an empty one contributes no rows.
        for outline_child in outline:
            process_outline(outline_child, dfs, pdf_reader)
        return dfs

    title = outline['/Title']  # Get the title of the outline item
    page_num = pdf_reader.get_destination_page_number(outline)
    # disallowed_special=() makes text such as "<|endoftext|>" count as ordinary
    # text instead of making tiktoken raise ValueError.
    token_count = len(encoding.encode(title, disallowed_special=()))
    dfs.append(pd.DataFrame(
        {'title': title, 'content': title, 'tokens': token_count, 'page': page_num},
        index=['title'],
    ))
    return dfs
# Model choices follow the OpenAI cookbook's question-answering-with-embeddings example:
# https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb
# NOTE(review): OpenAI has since deprecated text-davinci-003 -- confirm it is still served.
COMPLETIONS_MODEL: str = "text-davinci-003"  # writes the answer (Completions API)
EMBEDDING_MODEL: str = "text-embedding-ada-002"  # embeds both outline entries and questions
def get_embedding(text: str, model: str = EMBEDDING_MODEL) -> list[float]:
    """
    Embed *text* with the OpenAI Embeddings API.

    :param text: the string to embed.
    :param model: the embedding model name; defaults to ``EMBEDDING_MODEL``.
    :return: the embedding vector of the first (and only) input.
    """
    response = openai.Embedding.create(model=model, input=text)
    first_item = response["data"][0]
    return first_item["embedding"]
def compute_doc_embeddings(df: pd.DataFrame) -> dict[tuple[str], list[float]]:
    """
    Embed the ``content`` of every row of *df* with the OpenAI Embeddings API.

    Rows are embedded one API call at a time, in the frame's row order.

    :param df: a frame with a ``content`` column.
    :return: a dict mapping each row's index label to that row's embedding vector.
    """
    embeddings = {}
    for label, row in df.iterrows():
        embeddings[label] = get_embedding(row.content)
    return embeddings
def load_embeddings(fname: str) -> dict[tuple[str], list[float]]:
    """
    Read the document embeddings and their keys from a CSV.

    fname is anything ``pd.read_csv`` accepts (a path or URL). The first column
    holds the document key; its header is ignored, because ``DataFrame.to_csv``
    writes an empty header for an unnamed index, which ``read_csv`` turns into
    "Unnamed: 0". The remaining columns must be named "0", "1", ... up to the
    length of the embedding vectors.

    :return: a dict mapping each key to its embedding as a list, ordered by the
        numeric column names "0", "1", ....
    :raises ValueError: if the file has no embedding columns.
    """
    df = pd.read_csv(fname, header=0)
    # https://github.com/openai/openai-cookbook/issues/137#issuecomment-1434838657
    df = df.set_axis(['title'] + list(df.columns[1:]), axis=1)
    dims = [int(c) for c in df.columns if c != "title"]
    if not dims:
        raise ValueError(f"{fname!r} contains no embedding columns")
    max_dim = max(dims)
    # Select the vector columns explicitly in numeric order, then convert the whole
    # block at once instead of looking up every cell of every row in Python.
    vector_columns = [str(i) for i in range(max_dim + 1)]
    vectors = df[vector_columns].to_numpy().tolist()
    return dict(zip(df["title"], vectors))
def vector_similarity(x: list[float], y: list[float]) -> float:
    """
    Score how similar two embedding vectors are.

    OpenAI embeddings are normalised to length 1, so their dot product equals
    their cosine similarity; no further normalisation is done here.
    """
    left = np.array(x)
    right = np.array(y)
    return left.dot(right)
def order_document_sections_by_query_similarity(query: str, contexts: dict[(str, str), np.array]) -> list[
    (float, (str, str))]:
    """
    Rank pre-embedded document sections by how similar they are to *query*.

    The query is embedded once and scored against every vector in *contexts*.

    :param query: the text to compare against.
    :param contexts: mapping of section key to its pre-computed embedding.
    :return: ``(similarity, section_key)`` pairs, most similar first; ties are
        ordered by section key, descending.
    """
    query_vector = get_embedding(query)
    scored = [
        (vector_similarity(query_vector, section_vector), section_key)
        for section_key, section_vector in contexts.items()
    ]
    scored.sort(reverse=True)
    return scored
MAX_SECTION_LEN = 500  # token budget for the context sections in one prompt (see construct_prompt)
SEPARATOR = "\n* "  # written before every context section in the prompt
# tiktoken encoding of the completions model: text-davinci-003 uses "p50k_base"
# (tiktoken.encoding_for_model("text-davinci-003")), not "gpt2"/r50k_base.
ENCODING = "p50k_base"

encoding = tiktoken.get_encoding(ENCODING)
# Tokens each separator adds; construct_prompt charges this once per chosen section.
separator_len = len(encoding.encode(SEPARATOR))
def construct_prompt(question: str, context_embeddings: dict, df: pd.DataFrame) -> str:
    """
    Build a completion prompt that answers *question* from the most relevant sections of *df*.

    Sections are ranked by how similar their embedding is to the question's
    embedding. They are then added, most relevant first, until their ``tokens``
    plus one separator each would exceed ``MAX_SECTION_LEN``. The chosen index
    labels are printed as diagnostics.

    :param question: the question to answer; it is also embedded to rank sections.
    :param context_embeddings: embedding vector per ``df`` index label.
    :param df: the sections, indexed by the same labels, with ``content`` and
        ``tokens`` columns.
    :return: the instructions, the chosen context sections, then the question.
    """
    most_relevant_document_sections = order_document_sections_by_query_similarity(question, context_embeddings)

    chosen_sections = []
    chosen_sections_len = 0
    chosen_sections_indexes = []

    for _, section_index in most_relevant_document_sections:
        document_section = df.loc[section_index]
        # A label that occurs more than once in the index (e.g. two outline entries
        # with the same title) makes ``.loc`` return a DataFrame, whose ``tokens``
        # is a Series that cannot be used in the ``if`` below. The embeddings dict
        # holds only one vector per label anyway, so use the first matching row.
        if isinstance(document_section, pd.DataFrame):
            document_section = document_section.iloc[0]

        # Add contexts until we run out of space.
        chosen_sections_len += document_section.tokens + separator_len
        if chosen_sections_len > MAX_SECTION_LEN:
            break

        # str() guards against read_csv having parsed a numeric-looking value as a number.
        chosen_sections.append(SEPARATOR + str(document_section.content).replace("\n", " "))
        chosen_sections_indexes.append(str(section_index))

    # Useful diagnostic information
    print(f"Selected {len(chosen_sections)} document sections:")
    print("\n".join(chosen_sections_indexes))

    header = """Answer the question as truthfully as possible using the provided context, and if the answer is not contained within the text below, say "I don't know."\n\nContext:\n"""

    return header + "".join(chosen_sections) + "\n\n Q: " + question + "\n A:"
# Keyword arguments unpacked into openai.Completion.create by answer_query_with_context.
COMPLETIONS_API_PARAMS = dict(
    temperature=0.0,  # 0.0 gives the most predictable, factual answer
    max_tokens=300,  # upper bound on the length of the generated answer
    model=COMPLETIONS_MODEL,
)
def answer_query_with_context(
    query: str,
    df: pd.DataFrame,
    document_embeddings: dict[(str, str), np.array],
    show_prompt: bool = False
) -> str:
    """
    Answer *query* with the completions model, grounded in the sections of *df*.

    :param query: the question to answer.
    :param df: the sections to draw context from (passed to ``construct_prompt``).
    :param document_embeddings: embedding vector per ``df`` index label.
    :param show_prompt: print the generated prompt before sending it.
    :return: the first completion's text with leading/trailing spaces and newlines removed.
    """
    prompt = construct_prompt(query, document_embeddings, df)
    if show_prompt:
        print(prompt)

    completion = openai.Completion.create(prompt=prompt, **COMPLETIONS_API_PARAMS)
    best_choice = completion["choices"][0]
    return best_choice["text"].strip(" \n")
# Define path to PDF file
pdf_file = "stm32f205rb.pdf"

# Extract text from the PDF
logging.info('reading pdf contents')
reader = PdfReader(pdf_file)
# Get the page count of the PDF
page_count = len(reader.pages)

outline_dfs = []

# Load the data if cached, otherwise generate them.
# NOTE(review): the cache file names do not depend on pdf_file, so delete
# data.csv and embeddings.csv after switching to a different PDF.
if os.path.isfile('data.csv'):
    logging.info('loading cached outline data')
    outline_df = pd.read_csv("data.csv")
    outline_df = outline_df.set_index(["title"])
else:
    # https://github.com/openai/openai-cookbook/blob/2f5e350bbe66a418184899b0e12f182dbb46a156/examples/fine-tuned_qa/olympics-2-create-qa.ipynb
    logging.info('generating outline data')
    for outline_item in reader.outline:
        process_outline(outline_item, outline_dfs)
    if not outline_dfs:
        # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
        raise SystemExit(f"{pdf_file} has no outline entries to index")
    outline_df = pd.concat(outline_dfs, ignore_index=True)
    outline_df = outline_df.set_index(["title"])
    outline_df = outline_df.sort_values(by=['page'])
    outline_df.to_csv("data.csv", index=True)

# Calculate embeddings for the PDF's outline.
#
# This includes the table of contents, figure
# references, and any table references and
# the titles of those figures and tables.
#
# TODO: use tokenizer to extract relevant text
# https://github.com/openai/openai-cookbook/blob/2f5e350bbe66a418184899b0e12f182dbb46a156/examples/fine-tuned_qa/olympics-1-collect-data.ipynb
if os.path.isfile('embeddings.csv'):
    logging.info('loading outline embeddings')
    outline_embeddings = load_embeddings("embeddings.csv")
else:
    logging.info('creating outline embeddings')
    outline_embeddings = compute_doc_embeddings(outline_df)
    # Transposed so each title is one row with columns 0..N-1, the layout
    # load_embeddings reads back.
    # https://github.com/openai/openai-cookbook/issues/137
    pd.DataFrame(outline_embeddings).T.to_csv("embeddings.csv")

logging.info('running query')
# Interactive loop, e.g. "Does the STM32F20x have a LQFP100 layout?".
# A blank line or end of input (EOF) ends the session without calling the API.
while True:
    try:
        query = input("Question? ")
    except EOFError:
        break
    if not query.strip():
        break
    answer = answer_query_with_context(query, outline_df, outline_embeddings, show_prompt=True)
    print(f"\nQ: {query}\nA: {answer}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment