Skip to content

Instantly share code, notes, and snippets.

@virattt
Last active February 5, 2024 00:56
Show Gist options
  • Save virattt/7aee148b89b935aa0e7be03d60d72707 to your computer and use it in GitHub Desktop.
Save virattt/7aee148b89b935aa0e7be03d60d72707 to your computer and use it in GitHub Desktop.
cost-query-rewriting-gpt-mistral-cohere.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/7aee148b89b935aa0e7be03d60d72707/cost-query-rewriting-gpt-mistral-cohere.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Create Query and Prompt"
],
"metadata": {
"id": "m8HqBNyYrDHb"
}
},
{
"cell_type": "code",
"source": [
"# System prompt sent as the system message to the OpenAI and Mistral models below:\n",
"# it asks for a precise, retrieval-friendly version of the user's question.\n",
"prompt = \"\"\"\n",
"Rewrite the following user query into a clear, specific, and\n",
"formal request suitable for retrieving relevant information from a vector database.\n",
"Keep in mind that your rewritten query will be sent to a vector database, which\n",
"does similarity search for retrieving documents. Your output must be 100 tokens max.\n",
"\"\"\"\n",
"\n",
"# Informal, loosely worded user question that every model is asked to rewrite\n",
"query = \"What's going on with Airbnb's numbers?\""
],
"metadata": {
"id": "3qZTrAtXLPl1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Helper function to estimate the number of tokens in a phrase\n",
"def get_num_tokens(phrase: str | None, tokens_per_word: float = 1.3) -> float:\n",
"    \"\"\"Estimate how many tokens `phrase` occupies without calling a tokenizer.\n",
"\n",
"    Counts whitespace-separated words and scales the count by `tokens_per_word`.\n",
"    The default of 1.3 is a rough rule of thumb for English text, so the result\n",
"    only approximates what a provider's tokenizer would actually count.\n",
"\n",
"    Args:\n",
"        phrase: Text to measure. `None` or an empty/whitespace-only string\n",
"            counts as 0 tokens (e.g. a model reply with no text content).\n",
"        tokens_per_word: Multiplier applied to the word count.\n",
"\n",
"    Returns:\n",
"        The estimated token count (may be fractional).\n",
"    \"\"\"\n",
"    if not phrase:\n",
"        return 0.0\n",
"    return len(phrase.split()) * tokens_per_word"
],
"metadata": {
"id": "brMdLqEbetEC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# %pip installs into the running kernel's environment; the cells below use the\n",
"# v1 client (`from openai import OpenAI`), hence the lower bound.\n",
"%pip install -q 'openai>=1.0'"
],
"metadata": {
"id": "2bY0NapN_z98"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"import getpass\n",
"\n",
"# Prompt for the OpenAI API key (typed input is not echoed) and publish it\n",
"# through the environment, where the OpenAI client cell reads it.\n",
"os.environ.update(OPENAI_API_KEY=getpass.getpass())"
],
"metadata": {
"id": "tavToGb_MJrc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Use GPT-4 to rewrite the query"
],
"metadata": {
"id": "bPzoWQhVAmLt"
}
},
{
"cell_type": "code",
"source": [
"from openai import OpenAI\n",
"import time\n",
"import json\n",
"\n",
"client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
"\n",
"# Per-token prices (USD) used for the estimate: $10 / 1M input, $30 / 1M output tokens\n",
"GPT4_INPUT_PRICE = 0.00001\n",
"GPT4_OUTPUT_PRICE = 0.00003\n",
"\n",
"# Both the system prompt and the user query are sent on every call, so both\n",
"# count toward input tokens; the amount is the same each time, so compute it once.\n",
"input_cost = (get_num_tokens(phrase=prompt) + get_num_tokens(phrase=query)) * GPT4_INPUT_PRICE\n",
"\n",
"total_time = 0.0\n",
"total_cost = 0.0\n",
"num_iterations = 10\n",
"response = None\n",
"\n",
"for _ in range(num_iterations):\n",
"    # perf_counter is monotonic and high-resolution, which suits interval timing\n",
"    start_time = time.perf_counter()\n",
"    response = client.chat.completions.create(\n",
"        model='gpt-4-0125-preview',\n",
"        temperature=0,\n",
"        messages=[\n",
"            {\"role\": \"system\", \"content\": prompt},\n",
"            {\"role\": \"user\", \"content\": query},\n",
"        ]\n",
"    )\n",
"    # Accumulate only the API call's latency (the sleep below is excluded)\n",
"    total_time += time.perf_counter() - start_time\n",
"    # Accumulate the estimated cost of this call (content may be None)\n",
"    rewritten_query = response.choices[0].message.content or \"\"\n",
"    output_cost = get_num_tokens(phrase=rewritten_query) * GPT4_OUTPUT_PRICE\n",
"    total_cost += input_cost + output_cost\n",
"    # Pause 1 second between requests\n",
"    time.sleep(1)\n",
"\n",
"# Calculate the average execution time\n",
"avg_time = total_time / num_iterations\n",
"\n",
"print(f\"Took {avg_time} seconds to rewrite the query.\")"
],
"metadata": {
"id": "Z83h16UuMlMt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Show the rewrite returned by the last benchmark call (kept as `rewritten_query`)\n",
"print(rewritten_query := response.choices[0].message.content)"
],
"metadata": {
"id": "8VZMWffzm0-i"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Print the estimated cost summed over all benchmark iterations, in the same\n",
"# fixed-point format as the other providers (avoids float repr noise)\n",
"print(f\"Total cost of rewriting query: ${total_cost:.8f}\")"
],
"metadata": {
"id": "8KaBnljJeUan"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Use GPT-3.5 to rewrite the query"
],
"metadata": {
"id": "AdrLmbzAAsgX"
}
},
{
"cell_type": "code",
"source": [
"# Reuses the OpenAI `client` created in the GPT-4 section above.\n",
"# Per-token prices (USD) used for the estimate: $0.50 / 1M input, $1.50 / 1M output tokens\n",
"GPT35_INPUT_PRICE = 0.0000005\n",
"GPT35_OUTPUT_PRICE = 0.0000015\n",
"\n",
"# Both the system prompt and the user query are sent on every call, so both\n",
"# count toward input tokens; the amount is the same each time, so compute it once.\n",
"input_cost = (get_num_tokens(phrase=prompt) + get_num_tokens(phrase=query)) * GPT35_INPUT_PRICE\n",
"\n",
"total_time = 0.0\n",
"total_cost = 0.0\n",
"num_iterations = 10\n",
"response = None\n",
"\n",
"for _ in range(num_iterations):\n",
"    # perf_counter is monotonic and high-resolution, which suits interval timing\n",
"    start_time = time.perf_counter()\n",
"    response = client.chat.completions.create(\n",
"        model='gpt-3.5-turbo-0125',\n",
"        temperature=0,\n",
"        messages=[\n",
"            {\"role\": \"system\", \"content\": prompt},\n",
"            {\"role\": \"user\", \"content\": query},\n",
"        ]\n",
"    )\n",
"    # Accumulate only the API call's latency (the sleep below is excluded)\n",
"    total_time += time.perf_counter() - start_time\n",
"    # Accumulate the estimated cost of this call (content may be None)\n",
"    rewritten_query = response.choices[0].message.content or \"\"\n",
"    output_cost = get_num_tokens(phrase=rewritten_query) * GPT35_OUTPUT_PRICE\n",
"    total_cost += input_cost + output_cost\n",
"    # Pause 1 second between requests\n",
"    time.sleep(1)\n",
"\n",
"# Calculate the average execution time\n",
"avg_time = total_time / num_iterations\n",
"\n",
"print(f\"Took {avg_time} seconds to rewrite the query.\")"
],
"metadata": {
"id": "AdpynLvNAvww"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Show the rewrite returned by the last benchmark call (kept as `rewritten_query`)\n",
"print(rewritten_query := response.choices[0].message.content)"
],
"metadata": {
"id": "yFKleKcWD84S"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Estimated cost summed over every benchmark iteration, shown to 8 decimal places\n",
"print(f\"Total cost of rewriting query: ${total_cost:.8f}\")"
],
"metadata": {
"id": "a2IgO3Fvftus"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Use Mistral to rewrite the query"
],
"metadata": {
"id": "FMHkITrr-ru2"
}
},
{
"cell_type": "code",
"source": [
"# The Mistral cells use the pre-1.0 SDK (`mistralai.client.MistralClient`),\n",
"# which mistralai 1.0 removed, so pin below 1.0. %pip targets the kernel's env.\n",
"%pip install -q 'mistralai<1.0'"
],
"metadata": {
"id": "cYy332j3cMbt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Prompt for the Mistral API key (typed input is not echoed) and publish it\n",
"# through the environment, where the Mistral client cell reads it.\n",
"os.environ.update(MISTRAL_API_KEY=getpass.getpass())"
],
"metadata": {
"id": "rcPNaNTR4leC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from mistralai.client import MistralClient\n",
"from mistralai.models.chat_completion import ChatMessage\n",
"\n",
"# Use a dedicated name so this cell does not overwrite the OpenAI `client`\n",
"# that the GPT-3.5 section relies on (re-running that cell stays safe).\n",
"mistral_client = MistralClient(api_key=os.environ[\"MISTRAL_API_KEY\"])\n",
"\n",
"# Per-token prices (USD) used for the estimate: $2.70 / 1M input, $8.10 / 1M output tokens\n",
"MISTRAL_INPUT_PRICE = 0.0000027\n",
"MISTRAL_OUTPUT_PRICE = 0.0000081\n",
"\n",
"# Both the system prompt and the user query are sent on every call, so both\n",
"# count toward input tokens; the amount is the same each time, so compute it once.\n",
"input_cost = (get_num_tokens(phrase=prompt) + get_num_tokens(phrase=query)) * MISTRAL_INPUT_PRICE\n",
"\n",
"total_time = 0.0\n",
"total_cost = 0.0\n",
"num_iterations = 10\n",
"response = None\n",
"\n",
"for _ in range(num_iterations):\n",
"    # perf_counter is monotonic and high-resolution, which suits interval timing\n",
"    start_time = time.perf_counter()\n",
"    response = mistral_client.chat(\n",
"        model=\"mistral-medium\",\n",
"        messages=[\n",
"            ChatMessage(role=\"system\", content=prompt),\n",
"            ChatMessage(role=\"user\", content=query)\n",
"        ]\n",
"    )\n",
"    # Accumulate only the API call's latency (the sleep below is excluded)\n",
"    total_time += time.perf_counter() - start_time\n",
"    # Accumulate the estimated cost of this call (guard against empty content)\n",
"    rewritten_query = response.choices[0].message.content or \"\"\n",
"    output_cost = get_num_tokens(phrase=rewritten_query) * MISTRAL_OUTPUT_PRICE\n",
"    total_cost += input_cost + output_cost\n",
"    # Pause 1 second between requests\n",
"    time.sleep(1)\n",
"\n",
"# Calculate the average execution time\n",
"avg_time = total_time / num_iterations\n",
"\n",
"print(f\"Took {avg_time} seconds to rewrite the query.\")"
],
"metadata": {
"id": "Z3ZQalMlUUb4"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Show the rewrite returned by the last benchmark call (kept as `rewritten_query`)\n",
"print(rewritten_query := response.choices[0].message.content)"
],
"metadata": {
"id": "imAL6_eqUtds"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Estimated cost summed over every benchmark iteration, shown to 8 decimal places\n",
"print(f\"Total cost of rewriting query: ${total_cost:.8f}\")"
],
"metadata": {
"id": "HFGVbymlhiuE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Use Cohere to rewrite the query"
],
"metadata": {
"id": "DrgoAKInw-z8"
}
},
{
"cell_type": "code",
"source": [
"# The Cohere cells read `response.search_queries` items as dicts (`.get('text')`),\n",
"# which matches the 4.x SDK; 5.x returns typed objects, so pin below 5.\n",
"%pip install -q 'cohere<5'"
],
"metadata": {
"id": "uzEp-k7Xxfla"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Prompt for the Cohere API key (typed input is not echoed) and publish it\n",
"# through the environment, where the Cohere client cell reads it.\n",
"os.environ.update(COHERE_API_KEY=getpass.getpass())"
],
"metadata": {
"id": "V7X3rjrb4uAX"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import cohere\n",
"\n",
"# Get your cohere API key on: www.cohere.com\n",
"co = cohere.Client(os.environ[\"COHERE_API_KEY\"])\n",
"\n",
"# Per-token prices (USD) used for the estimate: $1 / 1M input, $2 / 1M output tokens\n",
"COHERE_INPUT_PRICE = 0.000001\n",
"COHERE_OUTPUT_PRICE = 0.000002\n",
"\n",
"# Unlike the other providers, Cohere is sent only the raw query (no system\n",
"# prompt), so only the query's tokens are counted as input. Same every call.\n",
"input_cost = get_num_tokens(phrase=query) * COHERE_INPUT_PRICE\n",
"\n",
"total_time = 0.0\n",
"total_cost = 0.0\n",
"num_iterations = 10\n",
"response = None\n",
"\n",
"for _ in range(num_iterations):\n",
"    # perf_counter is monotonic and high-resolution, which suits interval timing\n",
"    start_time = time.perf_counter()\n",
"    # With search_queries_only=True the rewritten queries are read from\n",
"    # `response.search_queries` below\n",
"    response = co.chat(\n",
"        message=query,\n",
"        search_queries_only=True,\n",
"    )\n",
"    # Accumulate only the API call's latency (the sleep below is excluded)\n",
"    total_time += time.perf_counter() - start_time\n",
"    # Distinct loop name avoids reusing `query`; tolerate a missing list and\n",
"    # entries without text. Join the texts (not the list repr) for the estimate.\n",
"    rewritten_queries = [search_query.get('text', None) for search_query in (response.search_queries or [])]\n",
"    output_text = \" \".join(text for text in rewritten_queries if text)\n",
"    output_cost = get_num_tokens(phrase=output_text) * COHERE_OUTPUT_PRICE\n",
"    # Accumulate with += like the other providers; a plain = would keep only\n",
"    # the last call's cost\n",
"    total_cost += input_cost + output_cost\n",
"    # Pause 1 second between requests\n",
"    time.sleep(1)\n",
"\n",
"# Calculate the average execution time\n",
"avg_time = total_time / num_iterations\n",
"\n",
"print(f\"Took {avg_time} seconds to rewrite the query.\")"
],
"metadata": {
"id": "Zc-w9vn_xAnY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Collect and print the queries Cohere generated on the last benchmark call\n",
"rewritten_queries = []\n",
"for search_query in response.search_queries:\n",
"    rewritten_queries.append(search_query.get('text'))\n",
"print(rewritten_queries)"
],
"metadata": {
"id": "zdpHKYZXBIx0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Print the estimated cost accumulated by the benchmark cell, to 8 decimal places\n",
"print(f\"Total cost of rewriting query: ${total_cost:.8f}\")"
],
"metadata": {
"id": "jv0eNVDiiYpt"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4,
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment