@virattt
Last active March 20, 2024 07:54
query-rewriting-claude-gpt.ipynb
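The notebook prices each call by multiplying the token counts from the API's `usage` field by hard-coded per-token rates: 0.0000008 and 0.0000024 for `claude-instant-1.2`, and 0.0000005 and 0.0000015 for `gpt-3.5-turbo-0125`. Those rates equal $0.80 / $2.40 and $0.50 / $1.50 per million input / output tokens.

Below is a minimal sketch of that arithmetic. It assumes those rates are still current, and provider pricing changes, so check it before reusing them. `PRICES_PER_MILLION` and `request_cost` are hypothetical names and are not part of either SDK.

```python
# Rates implied by the notebook's per-token constants, in USD per million tokens.
# Assumption: these were the list prices when the notebook ran; verify before reuse.
PRICES_PER_MILLION = {
    "claude-instant-1.2": {"input": 0.80, "output": 2.40},
    "gpt-3.5-turbo-0125": {"input": 0.50, "output": 1.50},
}


def request_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Return the USD cost of one request from its input/output token counts."""
    price = PRICES_PER_MILLION[model]
    return (input_tokens * price["input"] + output_tokens * price["output"]) / 1_000_000


# Illustrative token counts (hypothetical, not measured): 300 in, 20 out.
print(f"{request_cost('gpt-3.5-turbo-0125', 300, 20):.8f}")  # prints 0.00018000
```

For the two token counts, pass the values the notebook already reads:
- Anthropic: `usage.input_tokens` and `usage.output_tokens`.
- OpenAI: `usage.prompt_tokens` and `usage.completion_tokens`.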
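Both benchmark cells time each request with timestamps taken around the API call, and they exclude the 1-second pause between calls. A mean alone hides tail latency.

Below is a minimal sketch of a reusable timing helper. It uses `time.perf_counter()`, a monotonic clock intended for measuring intervals, and also reports the median and a nearest-rank p95. `time_calls` is a hypothetical helper; the notebook does not define it.

```python
import math
import statistics
import time
from typing import Callable, Dict, List


def time_calls(fn: Callable[[], object], n: int = 100, pause: float = 0.0) -> Dict[str, float]:
    """Call fn() n times and return latency statistics in seconds.

    Only fn() itself is timed; the optional pause between calls
    (like the notebook's time.sleep(1)) is not counted.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    latencies: List[float] = []
    for i in range(n):
        start = time.perf_counter()  # monotonic, high-resolution interval clock
        fn()
        latencies.append(time.perf_counter() - start)
        if pause > 0 and i < n - 1:
            time.sleep(pause)  # spacing between requests; excluded from latencies

    latencies.sort()
    p95_rank = math.ceil(n * 95 / 100)  # nearest-rank 95th percentile (1-based)
    return {
        "mean": statistics.mean(latencies),
        "median": statistics.median(latencies),
        "p95": latencies[p95_rank - 1],
        "max": latencies[-1],
    }


# Demo with a local stand-in for an API call (no network needed).
stats = time_calls(lambda: sum(range(10_000)), n=20)
print({name: f"{value * 1000:.3f} ms" for name, value in stats.items()})
```

To mirror the notebook's loop:
1. Wrap the existing `client.messages.create(...)` or `client.chat.completions.create(...)` call in a zero-argument lambda.
2. Pass `n=100, pause=1.0`.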
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMHwD/0P0Ek81UQwWlP4hoM",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/cfbec216766834e6092bce8d099aee25/query-rewriting-claude-gpt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QXiI_u16ak7W"
},
"outputs": [],
"source": [
"%pip install anthropic openai"
]
},
{
"cell_type": "code",
"source": [
"query = \"What's going on with Airbnb's numbers?\"\n",
"prompt = \"\"\"\n",
"You are an AI trained to improve the effectiveness of information retrieval from a vector database.\n",
"Your task is to rewrite queries from their original form into a more detailed version that includes additional context or clarification.\n",
"This helps in enhancing the accuracy of search results in terms of NDCG, Recall@K, and MRR@K metrics.\n",
"Below are a few examples of how queries can be rewritten for better understanding and retrieval performance.\n",
"\n",
"Example 1:\n",
"Original Query: \"Who is the president mentioned in relation to the healthcare law?\"\n",
"Rewritten Query: \"Who is the U.S. president mentioned in discussions about the Affordable Care Act (Obamacare)?\"\n",
"\n",
"Example 2:\n",
"Original Query: \"What film did both actors star in?\"\n",
"Rewritten Query: \"In which film did Leonardo DiCaprio and Kate Winslet both have starring roles?\"\n",
"\n",
"Example 3:\n",
"Original Query: \"When was the company founded that created the iPhone?\"\n",
"Rewritten Query: \"What is the founding year of Apple Inc., the company that developed the iPhone?\"\n",
"\n",
"Given a user query, rewrite it to add more context and clarity, similar to the examples provided above.\n",
"Ensure the rewritten query is specific and detailed to improve the retrieval of relevant information from a database.\n",
"\"\"\""
],
"metadata": {
"id": "oihhnP75cpe7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Measure query-rewriting latency and cost of Claude Instant"
],
"metadata": {
"id": "u2rZNrR0bqKf"
}
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"# Set your Anthropic API key\n",
"os.environ[\"ANTHROPIC_API_KEY\"] = getpass.getpass(\"Anthropic API key: \")"
],
"metadata": {
"id": "z1LhTFxnbICa"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import time\n",
"from anthropic import Anthropic\n",
"\n",
"client = Anthropic(api_key=os.environ.get(\"ANTHROPIC_API_KEY\"))\n",
"\n",
"total_time = 0.0\n",
"total_cost = 0.0\n",
"num_iterations = 100\n",
"rewritten_queries = []\n",
"\n",
"for _ in range(num_iterations):\n",
"    # Time only the API call\n",
"    start_time = time.perf_counter()\n",
"    claude_response = client.messages.create(\n",
"        model=\"claude-instant-1.2\",\n",
"        temperature=0.0,\n",
"        max_tokens=100,\n",
"        system=prompt,\n",
"        messages=[\n",
"            {\"role\": \"user\", \"content\": query},\n",
"        ],\n",
"    )\n",
"    end_time = time.perf_counter()\n",
"\n",
"    # Accumulate API latency (the sleep below is excluded)\n",
"    total_time += end_time - start_time\n",
"\n",
"    # response.content is a list of content blocks; keep the text of the first one\n",
"    rewritten_queries.append(claude_response.content[0].text)\n",
"\n",
"    # Accumulate cost at $0.80 / $2.40 per million input / output tokens\n",
"    input_cost = claude_response.usage.input_tokens * 0.0000008\n",
"    output_cost = claude_response.usage.output_tokens * 0.0000024\n",
"    total_cost += input_cost + output_cost\n",
"\n",
"    # Wait 1 second before the next call\n",
"    time.sleep(1)\n",
"\n",
"# Average latency and cost per call\n",
"avg_time = total_time / num_iterations\n",
"avg_cost = total_cost / num_iterations\n",
"\n",
"print(f\"Took {avg_time} seconds to rewrite the query and cost ${format(avg_cost, '.8f')}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wrfj19ctbNcB",
"outputId": "9b4fefea-f6fd-4f74-f490-f6edd44b24be"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Took 1.7971826553344727 seconds to rewrite the query and cost $0.00042720\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Print the rewritten queries\n",
"for rewritten_query in rewritten_queries[:5]:\n",
"    print(rewritten_query)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8_lyuDW-ed5t",
"outputId": "2490b282-3e5b-4311-f65e-2001eb939235"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[ContentBlock(text='Here is a rewritten version of the query to provide more context:\\n\\n\"What have been Airbnb\\'s recent financial results and business metrics? Specifically, I\\'m looking for information on Airbnb\\'s revenue, earnings, bookings numbers, or other financial or operational data that would provide insight into the company\\'s current performance and how its business metrics have been trending recently.\"', type='text')]\n",
"[ContentBlock(text='Here is a rewritten version of the query to provide more context:\\n\\n\"What have been Airbnb\\'s recent financial results and business metrics? Specifically, I\\'m looking for information on Airbnb\\'s revenue, earnings, bookings numbers, or other financial or operational data that would provide insight into the company\\'s current performance and how its business metrics have been trending recently.\"', type='text')]\n",
"[ContentBlock(text='Here is a rewritten version of the query to provide more context:\\n\\n\"What have been Airbnb\\'s recent financial results and business metrics? Specifically, I\\'m looking for information on Airbnb\\'s revenue, earnings, bookings numbers, or other financial or operational data that would provide insight into the company\\'s current performance and how its business metrics have been trending recently.\"', type='text')]\n",
"[ContentBlock(text='Here is a rewritten version of the query to provide more context:\\n\\n\"What have been Airbnb\\'s recent financial results and business metrics? Specifically, I\\'m looking for information on Airbnb\\'s revenue, earnings, bookings numbers, or other financial or operational data that would provide insight into the company\\'s current performance and how its business metrics have been trending recently.\"', type='text')]\n",
"[ContentBlock(text='Here is a rewritten version of the query to provide more context:\\n\\n\"What have been Airbnb\\'s recent financial results and business metrics? Specifically, I\\'m looking for information on Airbnb\\'s revenue, earnings, bookings numbers, or other financial or operational data that would provide insight into the company\\'s current performance and how its business metrics have been trending recently.\"', type='text')]\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Measure query-rewriting latency and cost of GPT-3.5 Turbo"
],
"metadata": {
"id": "pHaZZoyobiVN"
}
},
{
"cell_type": "code",
"source": [
"# Set your OpenAI API key\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API key: \")"
],
"metadata": {
"id": "7lixGQbhcWvx"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from openai import OpenAI\n",
"import time\n",
"\n",
"client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
"\n",
"total_time = 0.0\n",
"total_cost = 0.0\n",
"num_iterations = 100\n",
"rewritten_queries = []\n",
"\n",
"for _ in range(num_iterations):\n",
"    # Time only the API call\n",
"    start_time = time.perf_counter()\n",
"    gpt_response = client.chat.completions.create(\n",
"        model='gpt-3.5-turbo-0125',\n",
"        temperature=0.0,\n",
"        max_tokens=100,\n",
"        messages=[\n",
"            {\"role\": \"system\", \"content\": prompt},\n",
"            {\"role\": \"user\", \"content\": query},\n",
"        ],\n",
"    )\n",
"    end_time = time.perf_counter()\n",
"\n",
"    # Accumulate API latency (the sleep below is excluded)\n",
"    total_time += end_time - start_time\n",
"\n",
"    # Keep the text of the first choice\n",
"    rewritten_queries.append(gpt_response.choices[0].message.content)\n",
"\n",
"    # Accumulate cost at $0.50 / $1.50 per million input / output tokens\n",
"    input_cost = gpt_response.usage.prompt_tokens * 0.0000005\n",
"    output_cost = gpt_response.usage.completion_tokens * 0.0000015\n",
"    total_cost += input_cost + output_cost\n",
"\n",
"    # Wait 1 second before the next call\n",
"    time.sleep(1)\n",
"\n",
"# Average latency and cost per call\n",
"avg_time = total_time / num_iterations\n",
"avg_cost = total_cost / num_iterations\n",
"\n",
"print(f\"Took {avg_time} seconds to rewrite the query and cost ${format(avg_cost, '.8f')}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QvMK9nh6bkMp",
"outputId": "500af3d5-b3dd-4d90-c23b-e8ce4a7a814f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Took 0.6536604881286621 seconds to rewrite the query and cost $0.00006300\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Print the rewritten queries\n",
"for rewritten_query in rewritten_queries[:5]:\n",
"    print(rewritten_query)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pTSNvcYkfuVU",
"outputId": "7dc9ea8f-2a03-49b4-e186-974b35a873b3"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Please provide an analysis of the current financial performance and key metrics of Airbnb.\n",
"Please provide an analysis of the current financial performance and key metrics of Airbnb.\n",
"Please provide an analysis of the current financial performance and key metrics of Airbnb.\n",
"Please provide an analysis of the current financial performance and key metrics of Airbnb.\n",
"Please provide an analysis of the current financial performance and key metrics of Airbnb.\n"
]
}
]
}
]
}
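The two models format their rewrites differently:
- Claude Instant's text begins with "Here is a rewritten version of the query to provide more context:" and wraps the query in double quotes.
- GPT-3.5 Turbo returns the bare query.

Before sending a rewrite to a vector database, you would likely want only the query text. Below is a minimal sketch that extracts it, with two assumptions:
- It uses the Anthropic Messages API shape, where `response.content` is a list of blocks and text blocks have `type == "text"` and a `.text` attribute.
- It assumes the preamble starts with "Here is a rewritten" or "Here is the rewritten" and ends at the first colon. Real outputs may vary.

`response_text`, `clean_rewrite` and `PREAMBLE` are hypothetical names.

```python
import re
from types import SimpleNamespace

# Hypothetical pattern for the preamble seen in the Claude Instant output.
PREAMBLE = re.compile(r"^\s*here is (?:a|the) rewritten[^:\n]*:\s*", re.IGNORECASE)


def response_text(content_blocks) -> str:
    """Join the text of every text-type block in an Anthropic response.content list."""
    return "".join(
        block.text for block in content_blocks if getattr(block, "type", None) == "text"
    )


def clean_rewrite(text: str) -> str:
    """Strip a leading 'Here is a rewritten ...:' preamble and one pair of wrapping quotes."""
    text = PREAMBLE.sub("", text, count=1).strip()
    if len(text) >= 2 and text[0] == text[-1] == '"':
        text = text[1:-1].strip()
    return text


# Demo with stand-in objects shaped like Anthropic text blocks (no API call).
blocks = [
    SimpleNamespace(
        type="text",
        text="Here is a rewritten version of the query to provide more context:\n\n"
        '"What have been Airbnb\'s recent financial results and business metrics?"',
    )
]
print(clean_rewrite(response_text(blocks)))
```

Text without the preamble, such as the GPT-3.5 Turbo output, passes through unchanged.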