Created November 12, 2023 14:06
Save assafelovic/579822cd42d52d80db1e1c1ff82ffffd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This Gist demonstrates how to use the latest OpenAI Assistants API with internet access (via Tavily search).
# Step 1: Upgrade the OpenAI Python SDK to v1.2 or later: pip install --upgrade openai
# Step 2: Install the Tavily Python SDK: pip install tavily-python
# Step 3: Build an OpenAI assistant following the Python SDK documentation: https://platform.openai.com/docs/assistants/overview
# Step 4: Set the OPENAI_API_KEY and TAVILY_API_KEY environment variables before running this script.
import json
import os
import time

from openai import OpenAI
from tavily import TavilyClient

# API clients. Both keys are read from the environment; os.environ[...] raises
# KeyError when a variable is unset, so a missing key stops the script at startup.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
tavily_client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])

# System instructions for the assistant: answer finance questions only from
# web-search results and cite the source URLs. Each entry is one line of the
# prompt; the trailing "" makes the prompt end with a newline.
assistant_prompt_instruction = "\n".join(
    (
        "You are a finance expert.",
        "Your goal is to provide answers based on information from the internet.",
        "You must use the provided Tavily search API function to find relevant online information.",
        "You should never use your own knowledge to answer questions.",
        "Please include relevant url sources in the end of your answers.",
        "",
    )
)
# Function to perform a Tavily search
def tavily_search(query):
    """Run a Tavily web search for ``query`` and return the search context.

    The search uses ``search_depth="advanced"``. ``max_tokens=8000`` limits the
    size of the context Tavily builds (per the Tavily SDK's parameter).
    """
    return tavily_client.get_search_context(
        query,
        search_depth="advanced",
        max_tokens=8000,
    )
# Run statuses at which polling stops. 'requires_action' is included so the
# caller can submit tool outputs; the others are final states of a run, and
# stopping on all of them avoids polling forever on a run that will never
# reach 'completed' (e.g. one that was cancelled or expired).
TERMINAL_RUN_STATUSES = frozenset(
    {"completed", "failed", "requires_action", "cancelled", "expired", "incomplete"}
)


# Function to wait for a run to complete
def wait_for_run_completion(thread_id, run_id, poll_interval=1):
    """Poll a run until its status is in TERMINAL_RUN_STATUSES, then return it.

    Args:
        thread_id: ID of the thread the run belongs to.
        run_id: ID of the run to poll.
        poll_interval: Seconds to sleep before each status check (default 1).

    Returns:
        The run object as last retrieved from the API.
    """
    while True:
        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run_id)
        print(f"Current run status: {run.status}")
        if run.status in TERMINAL_RUN_STATUSES:
            return run
# Function to handle tool output submission
def _execute_tool_call(function_name, function_args):
    """Run one function call requested by the assistant and return its output.

    Never returns an empty output, because the API rejects a submission that
    leaves any requested tool call without one. Problems (unknown function,
    malformed arguments, empty search result) are reported back to the model
    as text instead of silently dropping the call.
    """
    if function_name != "tavily_search":
        return f"Error: unknown function '{function_name}'."
    try:
        query = json.loads(function_args)["query"]
    except (json.JSONDecodeError, KeyError, TypeError) as err:
        # Model-generated arguments are not guaranteed to be valid JSON with a "query" key.
        return f"Error: invalid arguments for tavily_search: {err!r}"
    output = tavily_search(query=query)
    return output if output else "No results found."


def submit_tool_outputs(thread_id, run_id, tools_to_call):
    """Execute the tool calls a run is waiting on and submit all of their outputs.

    Args:
        thread_id: ID of the thread the run belongs to.
        run_id: ID of the run that is in the 'requires_action' state.
        tools_to_call: The run's requested tool calls; each has ``id`` and
            ``function.name`` / ``function.arguments`` (a JSON string).

    Returns:
        The run object returned by the submit-tool-outputs API call.
    """
    # One entry per requested call: the run cannot continue unless every call gets an output.
    tool_output_array = [
        {
            "tool_call_id": tool.id,
            "output": _execute_tool_call(tool.function.name, tool.function.arguments),
        }
        for tool in tools_to_call
    ]
    return client.beta.threads.runs.submit_tool_outputs(
        thread_id=thread_id,
        run_id=run_id,
        tool_outputs=tool_output_array,
    )
# Function to print messages from a thread
def print_messages_from_thread(thread_id, order="desc"):
    """Print every message in a thread as ``role: text``.

    Args:
        thread_id: ID of the thread whose messages are printed.
        order: Listing order passed to the messages API. "desc" (the default,
            which is also the API's default) prints newest first; "asc" prints
            the conversation chronologically.
    """
    messages = client.beta.threads.messages.list(thread_id=thread_id, order=order)
    for msg in messages:
        # A message may hold several content parts, and not all are text
        # (e.g. image files); indexing content[0].text would crash on those.
        parts = [
            part.text.value if part.type == "text" else f"[{part.type}]"
            for part in msg.content
        ]
        text = "\n".join(parts)
        print(f"{msg.role}: {text}")
# Function-tool schema that exposes tavily_search to the assistant; the model
# must supply a single string argument named "query".
_tavily_search_tool = {
    "type": "function",
    "function": {
        "name": "tavily_search",
        "description": "Get information on recent events from the web.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query to use. For example: 'Latest news on Nvidia stock performance'",
                },
            },
            "required": ["query"],
        },
    },
}

# Create an assistant whose only tool is the web search above
assistant = client.beta.assistants.create(
    instructions=assistant_prompt_instruction,
    model="gpt-4-1106-preview",
    tools=[_tavily_search_tool],
)
assistant_id = assistant.id
print(f"Assistant ID: {assistant_id}")

# Create the thread that holds the whole conversation
thread = client.beta.threads.create()
print(f"Thread: {thread}")
# Ongoing conversation loop: type 'exit' (or send end-of-input, e.g. Ctrl-D) to quit.
while True:
    try:
        user_input = input("You: ")
    except EOFError:
        # End of input (Ctrl-D or a closed pipe) ends the session like 'exit'.
        break
    if user_input.strip().lower() == 'exit':
        break

    # Add the user's message to the thread
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=user_input,
    )

    # Start a run of the assistant on the thread
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant_id,
    )
    print(f"Run ID: {run.id}")
    run = wait_for_run_completion(thread.id, run.id)

    # A run may request tool calls several times in a row, so keep answering
    # until it leaves the 'requires_action' state.
    while run.status == 'requires_action':
        run = submit_tool_outputs(thread.id, run.id, run.required_action.submit_tool_outputs.tool_calls)
        run = wait_for_run_completion(thread.id, run.id)

    if run.status == 'failed':
        # The Run object reports failures in `last_error`; it has no `error` attribute.
        print(run.last_error)
        continue
    if run.status != 'completed':
        print(f"Run ended with status: {run.status}")
        continue

    # Print messages from the thread
    print_messages_from_thread(thread.id)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.