{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "A100",
"authorship_tag": "ABX9TyN1fpUfmXg34dJpJFj7/JQw",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/2b024d70a7dc39c2504f0235aa2fd2b8/gpt-3-5-few-shot-query-rewriting.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Load HotpotQA data from HuggingFace\n",
"\n",
"To keep things simple, we will only load a subset of the following\n",
"- queries\n",
"- corpus\n",
"- qrels (relevance judgments)"
],
"metadata": {
"id": "mJISkfHtCyzq"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KaoryniQ2cZj"
},
"outputs": [],
"source": [
"!pip install datasets"
]
},
{
"cell_type": "code",
"source": [
"from datasets import load_dataset"
],
"metadata": {
"id": "GYwzsd962i9Q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Load the HotpotQA queries from BEIR\n",
"queries = load_dataset(\"BeIR/hotpotqa\", 'queries', split='queries')"
],
"metadata": {
"id": "uBAAhQ3a9k_2"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Load a subset of the HotpotQA corpus from BEIR\n",
"corpus = load_dataset(\"BeIR/hotpotqa\", 'corpus', split='corpus[:100000]')"
],
"metadata": {
"id": "4lKTFJCy_M7X"
},
"execution_count": null,
"outputs": []
},
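{
"cell_type": "code",
"source": [
"# Peek at one corpus row (illustrative) to see the schema used later: _id, title, text.\n",
"print(corpus[0])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},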
{
"cell_type": "code",
"source": [
"# Load the HotpotQA qrels from BEIR\n",
"qrels = load_dataset(\"BeIR/hotpotqa-qrels\")"
],
"metadata": {
"id": "p6TQuLBGDGbD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Extract IDs from queries and corpus\n",
"query_ids = set(queries['_id'])\n",
"corpus_ids = set(corpus['_id'])\n",
"\n",
"# Filter out qrels that we do not have queries and corpus for\n",
"filtered_qrels = qrels.filter(lambda x: x['query-id'] in query_ids and str(x['corpus-id']) in corpus_ids)\n",
"print(filtered_qrels)"
],
"metadata": {
"id": "3LOKpfYM_cPD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Select the first 1000 indices\n",
"indices = list(range(1000))\n",
"\n",
"# Keep only the first 1000 rows in the train dataset\n",
"filtered_qrels['train'] = filtered_qrels['train'].select(indices)"
],
"metadata": {
"id": "EJIsP35tZ6Zk"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(filtered_qrels)"
],
"metadata": {
"id": "fV52LkB1aI4n"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Collecting unique IDs from qrels\n",
"unique_query_ids = set(filtered_qrels['train']['query-id'])\n",
"unique_corpus_ids = set(str(id) for id in filtered_qrels['train']['corpus-id'])\n",
"\n",
"# Filtering corpus and queries based on collected IDs\n",
"filtered_corpus = corpus.filter(lambda x: x['_id'] in unique_corpus_ids)\n",
"filtered_queries = queries.filter(lambda x: x['_id'] in unique_query_ids)"
],
"metadata": {
"id": "eD75uzpuIMbM"
},
"execution_count": null,
"outputs": []
},
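{
"cell_type": "code",
"source": [
"# Sanity check (illustrative): how many queries and corpus docs survived the filtering.\n",
"print(f\"queries: {len(filtered_queries)}\")\n",
"print(f\"corpus docs: {len(filtered_corpus)}\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},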
{
"cell_type": "markdown",
"source": [
"# Persist corpus in vector DB"
],
"metadata": {
"id": "vfzSguBuCwEa"
}
},
{
"cell_type": "code",
"source": [
"!pip install langchain langchain_openai chromadb"
],
"metadata": {
"id": "7VL4g11lCbPw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
],
"metadata": {
"id": "td8Z72ijEUMV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from langchain_community.vectorstores import Chroma\n",
"from langchain_core.documents import Document\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"embeddings = OpenAIEmbeddings()"
],
"metadata": {
"id": "06SBUyiVEeTd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Create Documents\n",
"documents = []\n",
"for corpus in filtered_corpus:\n",
" documents.append(\n",
" Document(\n",
" page_content=corpus['text'],\n",
" metadata={'title': corpus['title'], 'id': corpus['_id']},\n",
" )\n",
" )"
],
"metadata": {
"id": "UGJ4XYbCDMKJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Store documents in vector DB\n",
"vectorstore = Chroma.from_documents(documents, embeddings)"
],
"metadata": {
"id": "WQPekVUSEwk1"
},
"execution_count": null,
"outputs": []
},
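{
"cell_type": "code",
"source": [
"# Quick retrieval check (illustrative): query the vector store directly before evaluating.\n",
"# The query text here is arbitrary; results depend on which corpus rows were loaded above.\n",
"sample_docs = vectorstore.similarity_search('Who developed the iPhone?', k=3)\n",
"for doc in sample_docs:\n",
"    print(doc.metadata['id'], '-', doc.metadata['title'])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},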
{
"cell_type": "markdown",
"source": [
"# Query rewriting with GPT"
],
"metadata": {
"id": "TsRKQ3ebUmiZ"
}
},
{
"cell_type": "code",
"source": [
"# Helper function to get num tokens in a phrase\n",
"def get_num_tokens(phrase: str) -> float:\n",
" words = phrase.split()\n",
" word_count = len(words)\n",
"\n",
" # Multiplying the number of words by 1.3 to get the total number of tokens\n",
" return word_count * 1.3"
],
"metadata": {
"id": "ne4Oa1FWWuhg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"prompt = \"\"\"\n",
"You are an AI trained to improve the effectiveness of information retrieval from a vector database.\n",
"Your task is to rewrite queries from their original form into a more detailed version that includes additional context or clarification.\n",
"This helps in enhancing the accuracy of search results in terms of NDCG, Recall@K, and MRR@K metrics.\n",
"Below are a few examples of how queries can be rewritten for better understanding and retrieval performance.\n",
"\n",
"Example 1:\n",
"Original Query: \"Who is the president mentioned in relation to the healthcare law?\"\n",
"Rewritten Query: \"Who is the U.S. president mentioned in discussions about the Affordable Care Act (Obamacare)?\"\n",
"\n",
"Example 2:\n",
"Original Query: \"What film did both actors star in?\"\n",
"Rewritten Query: \"In which film did Leonardo DiCaprio and Kate Winslet both have starring roles?\"\n",
"\n",
"Example 3:\n",
"Original Query: \"When was the company founded that created the iPhone?\"\n",
"Rewritten Query: \"What is the founding year of Apple Inc., the company that developed the iPhone?\"\n",
"\n",
"Given a user query, rewrite it to add more context and clarity, similar to the examples provided above.\n",
"Ensure the rewritten query is specific and detailed to improve the retrieval of relevant information from a database.\n",
"\"\"\""
],
"metadata": {
"id": "5sgI49cAX1bJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from openai import OpenAI\n",
"\n",
"client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
"\n",
"def rewrite_query_openai(query: str, prompt: str, model: str) -> str:\n",
" response = client.chat.completions.create(\n",
" model=model,\n",
" temperature=0,\n",
" seed=42,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": prompt},\n",
" {\"role\": \"user\", \"content\": query},\n",
" ]\n",
" )\n",
" # Get the rewritten query from response\n",
" rewritten_query = response.choices[0].message.content\n",
" return rewritten_query"
],
"metadata": {
"id": "_jXjwG97UjHq"
},
"execution_count": null,
"outputs": []
},
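{
"cell_type": "code",
"source": [
"# Optional sanity check (illustrative): rewrite a single example query.\n",
"# Assumes the OpenAI key was entered above; the model name mirrors the one used in the\n",
"# evaluation loop below. Uncomment to run (it makes one API call).\n",
"# example = rewrite_query_openai(\n",
"#     query='What film did both actors star in?',\n",
"#     prompt=prompt,\n",
"#     model='gpt-3.5-turbo-0125',\n",
"# )\n",
"# print(example)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},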
{
"cell_type": "markdown",
"source": [
"# Compute query rewriting performance metrics"
],
"metadata": {
"id": "v-Zts5PdT-6T"
}
},
{
"cell_type": "code",
"source": [
"'''\n",
"Map the HugginFace DataSet into a simple python dict so that we can loop through it more quickly:\n",
"key: query-id\n",
"value: list of corpus-id\n",
"'''\n",
"qrels_dict = {}\n",
"\n",
"# Loop through each qrel in filtered_qrels\n",
"for qrel in filtered_qrels['train']:\n",
" query_id = qrel['query-id'] # Extract the query-id from qrel\n",
"\n",
" # Initialize the list of corpus-ids for the current query-id in qrels_dict\n",
" if query_id not in qrels_dict:\n",
" qrels_dict[query_id] = []\n",
"\n",
" # Loop through filtered_qrels again to find all matches for the current query-id\n",
" for x in filtered_qrels['train']:\n",
" if x['query-id'] == query_id:\n",
" # Add the corpus-id to the list for the current query-id, ensuring it's a string\n",
" qrels_dict[query_id].append(str(x['corpus-id']))\n"
],
"metadata": {
"id": "_CuRPpLXKys0"
},
"execution_count": null,
"outputs": []
},
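{
"cell_type": "code",
"source": [
"# Quick sanity check (illustrative): peek at one entry of qrels_dict.\n",
"# The exact IDs depend on which subset of HotpotQA was loaded above.\n",
"sample_query_id = next(iter(qrels_dict))\n",
"print(sample_query_id, qrels_dict[sample_query_id])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},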
{
"cell_type": "code",
"source": [
"def compute_recall(retrieved_docs, relevant_docs):\n",
" if not relevant_docs:\n",
" return 0 # Avoid division by zero if there are no relevant docs\n",
" relevant_count = sum(1 for doc in retrieved_docs if doc.page_content in relevant_docs)\n",
" return relevant_count / len(relevant_docs)\n",
"\n",
"def compute_ndcg(retrieved_docs, relevant_docs, k=5):\n",
" def dcg(retrieved_docs, relevant_docs):\n",
" # DCG is the sum of the relevance scores (logarithmically discounted)\n",
" return sum((doc.page_content in relevant_docs) / np.log2(i + 2) for i, doc in enumerate(retrieved_docs[:k]))\n",
"\n",
" def idcg(relevant_docs):\n",
" # iDCG is the DCG of the ideal ranking (all relevant documents at the top)\n",
" return sum(1 / np.log2(i + 2) for i in range(min(len(relevant_docs), k)))\n",
"\n",
" return dcg(retrieved_docs, relevant_docs) / idcg(relevant_docs)\n",
"\n",
"def compute_mrr(retrieved_docs, relevant_docs, k=5):\n",
" for i, doc in enumerate(retrieved_docs[:k], start=1):\n",
" if doc.page_content in relevant_docs:\n",
" return 1 / i\n",
" return 0 # Return 0 if no relevant document is found in the top k"
],
"metadata": {
"id": "o3E_KLbaUiF8"
},
"execution_count": null,
"outputs": []
},
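{
"cell_type": "code",
"source": [
"# Toy example (illustrative): sanity-check the metric helpers on made-up documents.\n",
"# With two relevant texts and a top-3 list containing both, recall and MRR should be 1.0\n",
"# and nDCG should be just below 1.0 (the second relevant doc sits at rank 3).\n",
"from langchain_core.documents import Document\n",
"\n",
"toy_relevant = ['doc A text', 'doc B text']\n",
"toy_retrieved = [\n",
"    Document(page_content='doc A text'),\n",
"    Document(page_content='unrelated text'),\n",
"    Document(page_content='doc B text'),\n",
"]\n",
"\n",
"print(compute_recall(toy_retrieved, toy_relevant))     # expected: 1.0\n",
"print(compute_ndcg(toy_retrieved, toy_relevant, k=3))  # expected: ~0.92\n",
"print(compute_mrr(toy_retrieved, toy_relevant, k=3))   # expected: 1.0"
],
"metadata": {},
"execution_count": null,
"outputs": []
},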
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import time\n",
"\n",
"# Assuming you have these loaded:\n",
"# queries - a dictionary or list of query texts where keys are query IDs\n",
"# qrels - a dictionary where keys are query IDs and values are lists of relevant document IDs\n",
"# vector_db - an object representing your vector database with a method `search(query_text, top_k)` that returns the top_k document IDs for a given query\n",
"\n",
"k = 3\n",
"max_iterations = 100\n",
"\n",
"# Store metrics for each query\n",
"recall_values = []\n",
"ndcg_values = []\n",
"mrr_values = []\n",
"\n",
"for index, query in enumerate(filtered_queries):\n",
" if index == max_iterations:\n",
" break\n",
"\n",
" question = query['text']\n",
" question_id = query['_id']\n",
"\n",
" # Rewrite the query with OpenAI\n",
" rewritten_query = rewrite_query_openai(\n",
" query=question,\n",
" prompt=prompt,\n",
" model='gpt-3.5-turbo-0125',\n",
" )\n",
"\n",
" print(f\"Question {index + 1}\")\n",
" print(f\"Original query: {question}\")\n",
" print(f\"Rewritten query: {rewritten_query}\")\n",
" print()\n",
"\n",
" # Perform the search in the vector DB for the current query\n",
" top_k_docs = vectorstore.similarity_search(rewritten_query, k)\n",
"\n",
" # Get relevant doc IDs for the current query\n",
" relevant_doc_ids = qrels_dict.get(question_id, [])\n",
"\n",
" # Fetch the texts of relevant docs\n",
" relevant_docs = [doc['text'] for doc in filtered_corpus if doc['_id'] in relevant_doc_ids]\n",
"\n",
" # Compute metrics for the current query\n",
" recall_score = compute_recall(top_k_docs, relevant_docs)\n",
" ndcg_score = compute_ndcg(top_k_docs, relevant_docs)\n",
" mrr_score = compute_mrr(top_k_docs, relevant_docs)\n",
"\n",
" # Append the metrics for this query to the lists\n",
" recall_values.append(recall_score)\n",
" ndcg_values.append(ndcg_score)\n",
" mrr_values.append(mrr_score)\n",
"\n",
" # Wait for 1 seconds before the next iteration to avoid rate limiting\n",
" # time.sleep(1)"
],
"metadata": {
"id": "iEdGJtncNuzj"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Compute the final metric values\n",
"final_recall = np.mean(recall_values)\n",
"final_ndcg = np.mean(ndcg_values)\n",
"final_mrr = np.mean(mrr_values)\n",
"\n",
"print(f\"Final Recall@{k}: {round(final_recall * 100, 2)}\")\n",
"print(f\"Final nDCG@{k}: {round(final_ndcg * 100, 2)}\")\n",
"print(f\"Final MRR@{k}: {round(final_mrr * 100, 2)}\")"
],
"metadata": {
"id": "_1Cv9bmTSTLf"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "ZLT625mf0gUk"
},
"execution_count": null,
"outputs": []
}
]
}