{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/9f41b2cd1e8f65e672127999ad443f15/rag-reranking-gpt-colbert-mistral.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Install dependencies"
],
"metadata": {
"id": "S2mGQxA958dW"
}
},
{
"cell_type": "code",
"source": [
"!pip install openai"
],
"metadata": {
"id": "2bY0NapN_z98"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lEQQJHH9gufm"
},
"outputs": [],
"source": [
"!pip install chromadb"
]
},
{
"cell_type": "code",
"source": [
"!pip install langchain"
],
"metadata": {
"id": "ygccK6lm54VT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pip install tiktoken"
],
"metadata": {
"id": "K5KyVC5O7Elw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pip install pypdf"
],
"metadata": {
"id": "_o1MOUo07GBO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"# Set your OpenAI API key\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()"
],
"metadata": {
"id": "UD__5soa4iEH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Download and prepare SEC filing"
],
"metadata": {
"id": "sz639zFf6JoK"
}
},
{
"cell_type": "code",
"source": [
"from langchain.document_loaders import PyPDFLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"# Load $ABNB's financial report. This may take 1-2 minutes since the PDF is large\n",
"sec_filing_pdf = \"https://d18rn0p25nwr6d.cloudfront.net/CIK-0001559720/8a9ebed0-815a-469a-87eb-1767d21d8cec.pdf\"\n",
"\n",
"# Create your PDF loader\n",
"loader = PyPDFLoader(sec_filing_pdf)\n",
"\n",
"# Load the PDF document\n",
"documents = loader.load()\n",
"\n",
"# Chunk the financial report\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)"
],
"metadata": {
"id": "rIO5t-j7611h"
},
"execution_count": null,
"outputs": []
},
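{
"cell_type": "markdown",
"source": [
"Before indexing, it can help to sanity-check the split. The next cell is an illustrative inspection sketch (not part of the original pipeline); it assumes `documents` and `docs` from the cell above."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Inspect the chunking output (assumes `documents` and `docs` from above)\n",
"print(f\"Loaded {len(documents)} pages, split into {len(docs)} chunks\")\n",
"\n",
"# Preview the first chunk and its source-page metadata\n",
"print(docs[0].page_content[:300])\n",
"print(docs[0].metadata)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},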
{
"cell_type": "markdown",
"source": [
"# Load the SEC filing into vector store"
],
"metadata": {
"id": "iaYSqxiMLUGb"
}
},
{
"cell_type": "code",
"source": [
"from langchain_community.vectorstores import Chroma\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"\n",
"# Load the document into Chroma\n",
"embedding_function = OpenAIEmbeddings()\n",
"db = Chroma.from_documents(docs, embedding_function)"
],
"metadata": {
"id": "QVZevdc-Md4N"
},
"execution_count": null,
"outputs": []
},
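{
"cell_type": "markdown",
"source": [
"As a quick sanity check that the index was populated, the cell below counts the embedded chunks. Note this is illustrative and reaches into the underlying chromadb collection through the private `_collection` attribute."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Count the embedded chunks in the Chroma store (uses a private attribute; illustrative only)\n",
"print(db._collection.count())"
],
"metadata": {},
"execution_count": null,
"outputs": []
},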
{
"cell_type": "markdown",
"source": [
"# Query the vector store"
],
"metadata": {
"id": "m8HqBNyYrDHb"
}
},
{
"cell_type": "code",
"source": [
"query = \"What are the specific factors contributing to Airbnb's increased operational expenses in the last fiscal year?\"\n",
"docs = db.similarity_search(query)"
],
"metadata": {
"id": "3qZTrAtXLPl1"
},
"execution_count": null,
"outputs": []
},
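{
"cell_type": "markdown",
"source": [
"`similarity_search` returns the top-k chunks by embedding similarity (k defaults to 4). Previewing them before reranking makes the rerankers' orderings easier to judge; this inspection cell is an illustrative addition."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Preview the retrieved chunks (assumes `docs` from similarity_search above)\n",
"for i, doc in enumerate(docs):\n",
"    print(f\"--- Chunk {i} (page {doc.metadata.get('page')}) ---\")\n",
"    print(doc.page_content[:200], \"\\n\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},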
{
"cell_type": "markdown",
"source": [
"# Rerank using GPT-4"
],
"metadata": {
"id": "0UMU-ogKM6w8"
}
},
{
"cell_type": "code",
"source": [
"from openai import OpenAI\n",
"import time\n",
"import json\n",
"\n",
"start = time.time()\n",
"client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
"response = client.chat.completions.create(\n",
" model='gpt-4-1106-preview',\n",
" response_format={\"type\": \"json_object\"},\n",
" temperature=0,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are an expert relevance ranker. Given a list of documents and a query, your job is to determine how relevant each document is for answering the query. Your output is JSON, which is a list of documents. Each document has two fields, content and score. relevance_score is from 0.0 to 100.0. Higher relevance means higher score.\"},\n",
" {\"role\": \"user\", \"content\": f\"Query: {query} Docs: {docs}\"}\n",
" ]\n",
" )\n",
"\n",
"print(f\"Took {time.time() - start} seconds to re-rank documents with GPT-4.\")"
],
"metadata": {
"id": "Z83h16UuMlMt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Sort the scores by highest to lowest and print\n",
"scores = json.loads(response.choices[0].message.content)[\"documents\"]\n",
"sorted_data = sorted(scores, key=lambda x: x['score'], reverse=True)\n",
"print(json.dumps(sorted_data, indent=2))"
],
"metadata": {
"id": "8VZMWffzm0-i"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Rerank using ColBERT"
],
"metadata": {
"id": "wnViL4XQg1FE"
}
},
{
"cell_type": "code",
"source": [
"!pip install --quiet transformers torch"
],
"metadata": {
"id": "fXPdarpEiN65"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from transformers import AutoTokenizer, AutoModel\n",
"\n",
"# Load the tokenizer and the model\n",
"tokenizer = AutoTokenizer.from_pretrained(\"colbert-ir/colbertv2.0\")\n",
"model = AutoModel.from_pretrained(\"colbert-ir/colbertv2.0\")"
],
"metadata": {
"id": "4jE_MXfelFyv"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"start = time.time()\n",
"scores = []\n",
"\n",
"# Function to compute MaxSim\n",
"def maxsim(query_embedding, document_embedding):\n",
" # Expand dimensions for broadcasting\n",
" # Query: [batch_size, query_length, embedding_size] -> [batch_size, query_length, 1, embedding_size]\n",
" # Document: [batch_size, doc_length, embedding_size] -> [batch_size, 1, doc_length, embedding_size]\n",
" expanded_query = query_embedding.unsqueeze(2)\n",
" expanded_doc = document_embedding.unsqueeze(1)\n",
"\n",
" # Compute cosine similarity across the embedding dimension\n",
" sim_matrix = torch.nn.functional.cosine_similarity(expanded_query, expanded_doc, dim=-1)\n",
"\n",
" # Take the maximum similarity for each query token (across all document tokens)\n",
" # sim_matrix shape: [batch_size, query_length, doc_length]\n",
" max_sim_scores, _ = torch.max(sim_matrix, dim=2)\n",
"\n",
" # Average these maximum scores across all query tokens\n",
" avg_max_sim = torch.mean(max_sim_scores, dim=1)\n",
" return avg_max_sim\n",
"\n",
"# Encode the query\n",
"query_encoding = tokenizer(query, return_tensors='pt')\n",
"query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)\n",
"\n",
"# Get score for each document\n",
"for document in docs:\n",
" document_encoding = tokenizer(document.page_content, return_tensors='pt', truncation=True, max_length=512)\n",
" document_embedding = model(**document_encoding).last_hidden_state\n",
"\n",
" # Calculate MaxSim score\n",
" score = maxsim(query_embedding.unsqueeze(0), document_embedding)\n",
" scores.append({\n",
" \"score\": score.item(),\n",
" \"document\": document.page_content,\n",
" })\n",
"\n",
"print(f\"Took {time.time() - start} seconds to re-rank documents with ColBERT.\")"
],
"metadata": {
"id": "u84ePIKtjrtg"
},
"execution_count": null,
"outputs": []
},
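{
"cell_type": "markdown",
"source": [
"To build intuition for `maxsim`, here is a small shape check on random tensors (illustrative values only): 5 query tokens scored against 12 document tokens should reduce to a single score per batch item."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sanity-check maxsim shapes on random embeddings (hypothetical sizes, for illustration)\n",
"q = torch.randn(1, 5, 128)   # [batch, query_tokens, embedding_size]\n",
"d = torch.randn(1, 12, 128)  # [batch, doc_tokens, embedding_size]\n",
"print(maxsim(q, d).shape)  # expected: torch.Size([1]), one score per batch item"
],
"metadata": {},
"execution_count": null,
"outputs": []
},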
{
"cell_type": "code",
"source": [
"# Sort the scores by highest to lowest and print\n",
"sorted_data = sorted(scores, key=lambda x: x['score'], reverse=True)\n",
"print(json.dumps(sorted_data, indent=2))"
],
"metadata": {
"id": "jBotW5sxjT6W"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Rerank using Mistral"
],
"metadata": {
"id": "dhqtmLXrUQFz"
}
},
{
"cell_type": "code",
"source": [
"!pip install mistralai"
],
"metadata": {
"id": "cYy332j3cMbt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Set your Mistral API key\n",
"os.environ[\"MISTRAL_API_KEY\"] = getpass.getpass()"
],
"metadata": {
"id": "rcPNaNTR4leC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import json\n",
"from mistralai.client import MistralClient\n",
"from mistralai.models.chat_completion import ChatMessage\n",
"\n",
"start = time.time()\n",
"client = MistralClient(api_key=os.environ[\"MISTRAL_API_KEY\"])\n",
"response = client.chat(\n",
" model=\"mistral-medium\",\n",
" messages=[\n",
" ChatMessage(role=\"system\", content=\"You are an expert relevance ranker. Given a list of documents and a query, your job is to determine how relevant each document is for answering the query. Your output is JSON, which is a list of documents. Each document has two fields, content and score. relevance_score is from 0.0 to 100.0. Higher relevance means higher score.\"),\n",
" ChatMessage(role=\"user\", content=f\"Query: {query} Docs: {docs}\")\n",
" ]\n",
")\n",
"\n",
"print(f\"Took {time.time() - start} seconds to re-rank documents with mistral-medium.\")"
],
"metadata": {
"id": "Z3ZQalMlUUb4"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"response.choices[0].message.content"
],
"metadata": {
"id": "imAL6_eqUtds"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Sort the scores by highest to lowest and print\n",
"scores = json.loads(response.choices[0].message.content)\n",
"sorted_data = sorted(scores, key=lambda x: x['score'], reverse=True)\n",
"print(json.dumps(sorted_data, indent=2))"
],
"metadata": {
"id": "SIwh8vTBUlir"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4,
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}