@virattt
Created February 19, 2024 22:45
query_expansion-cohere-gpt.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyO9f8+2MJPoRNjLkLbs2CcQ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/9099bf1a32ff2b99383b2fabba0ae763/query_expansion-cohere-gpt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HRA6TPl0lXqq"
},
"outputs": [],
"source": [
"!pip install cohere"
]
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"# Set your Cohere API key\n",
"os.environ[\"COHERE_API_KEY\"] = getpass.getpass()"
],
"metadata": {
"id": "4fJZ05n9lflZ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"query = \"What are the revenues of Airbnb, Booking, and Expedia?\""
],
"metadata": {
"id": "5UbW7pydl82v"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Use Cohere's `command` model for query expansion"
],
"metadata": {
"id": "025xZuDOlnT6"
}
},
{
"cell_type": "code",
"source": [
"import cohere\n",
"\n",
"# Get your Cohere API key at www.cohere.com\n",
"co = cohere.Client(os.environ[\"COHERE_API_KEY\"])\n",
"\n",
"# Call the model (`command` is selected explicitly to match the heading)\n",
"response = co.chat(\n",
" model='command',\n",
" message=query,\n",
" search_queries_only=True,\n",
")"
],
"metadata": {
"id": "b8-Ua66vlsWJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
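"# Print the rewritten search queries\n",
"# (older cohere SDKs return dicts here; newer ones return objects with a `.text` attribute)\n",
"rewritten_queries = [q['text'] if isinstance(q, dict) else q.text for q in response.search_queries or []]\n",
"print(rewritten_queries)"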
],
"metadata": {
"id": "QTXJByrxl3E_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Use Cohere's `command-light` model for query expansion"
],
"metadata": {
"id": "oYnx5WtEmXz7"
}
},
{
"cell_type": "code",
"source": [
"# Get your Cohere API key at www.cohere.com\n",
"co = cohere.Client(os.environ[\"COHERE_API_KEY\"])\n",
"\n",
"# Call the model\n",
"response = co.chat(\n",
" model='command-light',\n",
" message=query,\n",
" search_queries_only=True,\n",
")"
],
"metadata": {
"id": "lNXcXI9ymJZa"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
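"# Print the rewritten search queries\n",
"# (older cohere SDKs return dicts here; newer ones return objects with a `.text` attribute)\n",
"rewritten_queries = [q['text'] if isinstance(q, dict) else q.text for q in response.search_queries or []]\n",
"print(rewritten_queries)"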
],
"metadata": {
"id": "MaFJ6lrbmfmQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Use GPT-4 for query expansion"
],
"metadata": {
"id": "bPzoWQhVAmLt"
}
},
{
"cell_type": "code",
"source": [
"!pip install openai"
],
"metadata": {
"id": "D_LA58dWgVT3"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"# Set your OpenAI API key\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()"
],
"metadata": {
"id": "WluIGzCFgXlb"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"prompt = \"\"\"\n",
"You are a helpful assistant that expands a user query into sub-queries.\n",
"The sub-queries should be mutually exclusive and collectively exhaustive.\n",
"Your response will be a JSON object with a `queries` field, which is a list of `query` objects.\n",
"\"\"\""
],
"metadata": {
"id": "9Q4iyvUZgmKO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from openai import OpenAI\n",
"import json\n",
"\n",
"client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
"\n",
"# Call the model\n",
"response = client.chat.completions.create(\n",
" model='gpt-4-0125-preview',\n",
" response_format={\"type\": \"json_object\"},\n",
" temperature=0,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": prompt},\n",
" {\"role\": \"user\", \"content\": query},\n",
" ]\n",
")"
],
"metadata": {
"id": "Z83h16UuMlMt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Parse the JSON object returned by the model and pretty-print the sub-queries\n",
"rewritten_queries = json.loads(response.choices[0].message.content)\n",
"print(json.dumps(rewritten_queries, indent=2))"
],
"metadata": {
"id": "8VZMWffzm0-i"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Use GPT-3.5 for query expansion"
],
"metadata": {
"id": "AdrLmbzAAsgX"
}
},
{
"cell_type": "code",
"source": [
"# Call the model\n",
"response = client.chat.completions.create(\n",
" model='gpt-3.5-turbo-0125',\n",
" response_format={\"type\": \"json_object\"},\n",
" temperature=0,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": prompt},\n",
" {\"role\": \"user\", \"content\": query},\n",
" ]\n",
")"
],
"metadata": {
"id": "AdpynLvNAvww"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Parse the JSON object returned by the model and pretty-print the sub-queries\n",
"rewritten_queries = json.loads(response.choices[0].message.content)\n",
"print(json.dumps(rewritten_queries, indent=2))"
],
"metadata": {
"id": "yFKleKcWD84S"
},
"execution_count": null,
"outputs": []
}
]
}
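The GPT cells return a JSON string, but the system prompt only says the `queries` field holds "`query` objects", so the exact shape of each item is not pinned down. A minimal sketch of a parser that flattens the reply into plain strings, assuming each item is either a bare string or an object carrying its text under a `query` (or `text`) key. The function name `extract_sub_queries` and the sample reply are hypothetical, not part of the notebook or of the OpenAI SDK:

```python
import json
from typing import Any, List


def extract_sub_queries(raw: str) -> List[str]:
    """Flatten a JSON reply like {"queries": [...]} into sub-query strings.

    Assumption: each item in `queries` is either a plain string or an object
    holding its text under a "query" (or "text") key; other items are skipped.
    """
    data: Any = json.loads(raw)
    items = data.get("queries") if isinstance(data, dict) else None
    if not isinstance(items, list):
        return []  # missing or malformed `queries` field
    sub_queries: List[str] = []
    for item in items:
        if isinstance(item, str):
            text = item
        elif isinstance(item, dict):
            text = str(item.get("query") or item.get("text") or "")
        else:
            continue  # numbers, nested lists, etc. are ignored
        text = text.strip()
        if text:  # drop empty or whitespace-only entries
            sub_queries.append(text)
    return sub_queries


# Hypothetical model reply mixing both item shapes (not real output)
example = '{"queries": [{"query": "Airbnb revenue"}, "Booking revenue", {"query": " "}]}'
print(extract_sub_queries(example))
# → ['Airbnb revenue', 'Booking revenue']
```

In the notebook this would be called as `extract_sub_queries(response.choices[0].message.content)` after either GPT cell. Accepting both item shapes keeps the comparison between GPT-4 and GPT-3.5 from breaking when the two models interpret "`query` objects" differently.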
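The system prompt asks for sub-queries that are "mutually exclusive and collectively exhaustive". For the example query about Airbnb, Booking, and Expedia, one rough way to spot-check that is to count how many sub-queries mention each company: zero hits means a company was left out, more than one means the sub-queries overlap. The sketch below assumes plain case-insensitive substring matching and uses a hypothetical helper `check_entity_coverage` with made-up sub-queries, so treat it as a heuristic rather than a real MECE test:

```python
from typing import Dict, List


def check_entity_coverage(entities: List[str], sub_queries: List[str]) -> Dict[str, List[str]]:
    """Rough MECE check based on how many sub-queries mention each entity.

    "missing": entities found in no sub-query (not collectively exhaustive).
    "overlapping": entities found in more than one (not mutually exclusive).
    Matching is a case-insensitive substring test, so it is only a heuristic.
    """
    missing: List[str] = []
    overlapping: List[str] = []
    for entity in entities:
        # Count the sub-queries whose text contains the entity name
        hits = sum(entity.lower() in q.lower() for q in sub_queries)
        if hits == 0:
            missing.append(entity)
        elif hits > 1:
            overlapping.append(entity)
    return {"missing": missing, "overlapping": overlapping}


# Made-up sub-queries for illustration, not real model output
subs = ["Airbnb revenue", "Booking Holdings revenue", "Airbnb and Booking revenue comparison"]
print(check_entity_coverage(["Airbnb", "Booking", "Expedia"], subs))
# → {'missing': ['Expedia'], 'overlapping': ['Airbnb', 'Booking']}
```

Substring matching will misfire on generic words (for example "booking" in "booking a hotel"), so the output is a prompt for manual review of each model's sub-queries, not a verdict.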