Created
February 24, 2024 19:55
-
-
Save virattt/530402cd8a57077ece72e88b723d9e3e to your computer and use it in GitHub Desktop.
hotpotqa-vector_db-metrics.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyN5RrhboUHDaXJ8tExxec59", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/virattt/530402cd8a57077ece72e88b723d9e3e/hotpotqa-vector_db-metrics.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Load HotpotQA data from HuggingFace\n", | |
"\n", | |
"To keep things simple, we will only load a subset of the following\n", | |
"- queries\n", | |
"- corpus\n", | |
"- qrels (relevance judgments)" | |
], | |
"metadata": { | |
"id": "mJISkfHtCyzq" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "KaoryniQ2cZj" | |
}, | |
"outputs": [], | |
"source": [ | |
"!pip install datasets" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from datasets import load_dataset" | |
], | |
"metadata": { | |
"id": "GYwzsd962i9Q" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load the HotpotQA queries from BEIR\n", | |
"queries = load_dataset(\"BeIR/hotpotqa\", 'queries', split='queries[:10000]')" | |
], | |
"metadata": { | |
"id": "uBAAhQ3a9k_2" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load the HotpotQA corpus from BEIR\n", | |
"corpus = load_dataset(\"BeIR/hotpotqa\", 'corpus', split='corpus[:10000]')" | |
], | |
"metadata": { | |
"id": "4lKTFJCy_M7X" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load the HotpotQA qrels from BEIR\n", | |
"qrels = load_dataset(\"BeIR/hotpotqa-qrels\")" | |
], | |
"metadata": { | |
"id": "p6TQuLBGDGbD" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Extract IDs from queries and corpus\n", | |
"query_ids = set(queries['_id'])\n", | |
"corpus_ids = set(corpus['_id'])\n", | |
"\n", | |
"# Filter out qrels that we do not have queries and corpus for\n", | |
"filtered_qrels = qrels.filter(lambda x: x['query-id'] in query_ids and str(x['corpus-id']) in corpus_ids)\n", | |
"print(filtered_qrels)" | |
], | |
"metadata": { | |
"id": "3LOKpfYM_cPD" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Collecting unique IDs from qrels\n", | |
"unique_query_ids = set(filtered_qrels['train']['query-id'])\n", | |
"unique_corpus_ids = set(str(id) for id in filtered_qrels['train']['corpus-id'])\n", | |
"\n", | |
"# Filtering corpus and queries based on collected IDs\n", | |
"filtered_corpus = corpus.filter(lambda x: x['_id'] in unique_corpus_ids)\n", | |
"filtered_queries = queries.filter(lambda x: x['_id'] in unique_query_ids)" | |
], | |
"metadata": { | |
"id": "eD75uzpuIMbM" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Persist corpus in vector DB" | |
], | |
"metadata": { | |
"id": "vfzSguBuCwEa" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install langchain langchain_openai chromadb" | |
], | |
"metadata": { | |
"id": "7VL4g11lCbPw" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import getpass\n", | |
"import os\n", | |
"\n", | |
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")" | |
], | |
"metadata": { | |
"id": "td8Z72ijEUMV" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from langchain_community.vectorstores import Chroma\n", | |
"from langchain_core.documents import Document\n", | |
"from langchain_openai import OpenAIEmbeddings\n", | |
"\n", | |
"embeddings = OpenAIEmbeddings()" | |
], | |
"metadata": { | |
"id": "06SBUyiVEeTd" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Create LangChain Documents from the filtered corpus.\n", | |
"# The loop variable is named `record` (not `corpus`) so it does not\n", | |
"# shadow and clobber the module-level `corpus` dataset loaded earlier —\n", | |
"# the original overwrote `corpus`, leaving stale state after this cell.\n", | |
"documents = []\n", | |
"for record in filtered_corpus:\n", | |
"    documents.append(\n", | |
"        Document(\n", | |
"            page_content=record['text'],\n", | |
"            metadata={'title': record['title'], 'id': record['_id']},\n", | |
"        )\n", | |
"    )" | |
], | |
"metadata": { | |
"id": "UGJ4XYbCDMKJ" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Store documents in vector DB\n", | |
"vectorstore = Chroma.from_documents(documents, embeddings)" | |
], | |
"metadata": { | |
"id": "WQPekVUSEwk1" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Query the vector DB and measure performance of retrieval" | |
], | |
"metadata": { | |
"id": "uvSjm_a4FkZD" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"'''\n", | |
"Create a simplified dictionary of qrels where:\n", | |
"key: query-id\n", | |
"value: list of corpus-id (as strings, matching the corpus `_id` field)\n", | |
"'''\n", | |
"qrels_dict = {}\n", | |
"\n", | |
"# Single pass over the qrels. The previous version re-scanned the whole\n", | |
"# qrels split for every row (O(n^2)) and appended each corpus-id once per\n", | |
"# qrel sharing the same query-id, filling the lists with duplicates.\n", | |
"for qrel in filtered_qrels['train']:\n", | |
"    query_id = qrel['query-id']\n", | |
"    # str() so the ids compare equal to the string `_id`s in the corpus\n", | |
"    qrels_dict.setdefault(query_id, []).append(str(qrel['corpus-id']))\n" | |
], | |
"metadata": { | |
"id": "_CuRPpLXKys0" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def compute_recall(retrieved_docs, relevant_docs):\n", | |
"    \"\"\"Fraction of relevant_docs whose text appears among retrieved_docs.\"\"\"\n", | |
"    if not relevant_docs:\n", | |
"        return 0  # Avoid division by zero if there are no relevant docs\n", | |
"    relevant_count = sum(1 for doc in retrieved_docs if doc.page_content in relevant_docs)\n", | |
"    return relevant_count / len(relevant_docs)\n", | |
"\n", | |
"def compute_ndcg(retrieved_docs, relevant_docs, k=5):\n", | |
"    \"\"\"nDCG@k with binary relevance (membership of doc text in relevant_docs).\"\"\"\n", | |
"    def dcg(retrieved_docs, relevant_docs):\n", | |
"        # DCG is the sum of the relevance scores (logarithmically discounted)\n", | |
"        return sum((doc.page_content in relevant_docs) / np.log2(i + 2) for i, doc in enumerate(retrieved_docs[:k]))\n", | |
"\n", | |
"    def idcg(relevant_docs):\n", | |
"        # iDCG is the DCG of the ideal ranking (all relevant documents at the top)\n", | |
"        return sum(1 / np.log2(i + 2) for i in range(min(len(relevant_docs), k)))\n", | |
"\n", | |
"    ideal = idcg(relevant_docs)\n", | |
"    if ideal == 0:\n", | |
"        return 0  # No relevant docs: avoid ZeroDivisionError (original crashed here)\n", | |
"    return dcg(retrieved_docs, relevant_docs) / ideal\n", | |
"\n", | |
"def compute_mrr(retrieved_docs, relevant_docs, k=5):\n", | |
"    \"\"\"Reciprocal rank of the first relevant doc within the top k (0 if none).\"\"\"\n", | |
"    for i, doc in enumerate(retrieved_docs[:k], start=1):\n", | |
"        if doc.page_content in relevant_docs:\n", | |
"            return 1 / i\n", | |
"    return 0  # Return 0 if no relevant document is found in the top k" | |
], | |
"metadata": { | |
"id": "o3E_KLbaUiF8" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"k = 5\n", | |
"\n", | |
"# Build an id -> text lookup ONCE. The original scanned filtered_corpus\n", | |
"# inside the query loop, i.e. O(num_queries * corpus_size) work.\n", | |
"corpus_text_by_id = {doc['_id']: doc['text'] for doc in filtered_corpus}\n", | |
"\n", | |
"# Store metrics for each query\n", | |
"recall_values = []\n", | |
"ndcg_values = []\n", | |
"mrr_values = []\n", | |
"\n", | |
"for query in filtered_queries:\n", | |
"    question = query['text']\n", | |
"    question_id = query['_id']\n", | |
"\n", | |
"    # Perform the search in the vector DB for the current query\n", | |
"    top_k_docs = vectorstore.similarity_search(question, k)\n", | |
"\n", | |
"    # Get relevant doc IDs for the current query\n", | |
"    relevant_doc_ids = qrels_dict.get(question_id, [])\n", | |
"\n", | |
"    # Fetch relevant doc texts via the precomputed lookup. set() dedupes\n", | |
"    # ids so duplicated qrel rows cannot inflate len(relevant_docs).\n", | |
"    relevant_docs = [corpus_text_by_id[doc_id] for doc_id in set(relevant_doc_ids) if doc_id in corpus_text_by_id]\n", | |
"\n", | |
"    # Compute and record metrics for the current query\n", | |
"    recall_values.append(compute_recall(top_k_docs, relevant_docs))\n", | |
"    ndcg_values.append(compute_ndcg(top_k_docs, relevant_docs))\n", | |
"    mrr_values.append(compute_mrr(top_k_docs, relevant_docs))" | |
], | |
"metadata": { | |
"id": "iEdGJtncNuzj" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Compute the final metric values\n", | |
"final_recall = np.mean(recall_values)\n", | |
"final_ndcg = np.mean(ndcg_values)\n", | |
"final_mrr = np.mean(mrr_values)\n", | |
"\n", | |
"print(f\"Final Recall@{k}: {round(final_recall * 100, 2)}\")\n", | |
"print(f\"Final nDCG@{k}: {round(final_ndcg * 100, 2)}\")\n", | |
"print(f\"Final MRR@{k}: {round(final_mrr * 100, 2)}\")" | |
], | |
"metadata": { | |
"id": "_1Cv9bmTSTLf" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment