Created
September 12, 2024 23:44
-
-
Save pvlbzn/0ebf5f42a86cf6b94fef8d1cef86015e to your computer and use it in GitHub Desktop.
Code for Deep Guide to Large Language Models Function Calling and Structured Output
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
from pprint import pp
from typing import Any, Callable, List, Optional

from openai import OpenAI
from openai.types.chat import ChatCompletion
# Chat model used for every completion request in this module.
GPT_MODEL: str = "gpt-4o"
# Shared API client, created once at import time.
# NOTE(review): OpenAI() presumably reads credentials from the environment
# (OPENAI_API_KEY) and may fail at import if they are missing — confirm.
client: OpenAI = OpenAI()
def get_prompt() -> str:
    """Return the system prompt that opens every conversation."""
    system_prompt = (
        "You are a helpful assistant. "
        "Use provided functions if response is not clear."
    )
    return system_prompt
def get_stock_price(ticker: str) -> str:
    """Return the current price of ``ticker`` as a display string.

    The ticker comes from the model, so the lookup tolerates the variations it
    commonly produces: surrounding whitespace, lowercase letters, and a
    leading ``^`` (as in ``^DJI``), which the tool description asks the model
    to omit but does not guarantee.

    Args:
        ticker: Ticker symbol such as ``"MSFT"`` or ``"DJI"``.

    Returns:
        The price formatted as a string, e.g. ``"421.53"``.

    Raises:
        KeyError: If the (normalized) ticker is not known.
    """
    # Static ticker-to-price data. In a real application this would be an API
    # call, SQL query, or some other data-fetching call.
    data = {"DJI": "40,345.41", "MSFT": "421.53", "AAPL": "225.89"}
    normalized = ticker.strip().lstrip("^").upper()
    try:
        return data[normalized]
    except KeyError:
        # Keep the KeyError contract, but name the offending input clearly.
        raise KeyError(f"Unknown ticker: {ticker!r}") from None
def get_llm_functions() -> list[dict[str, Any]]:
    """Return the tool schemas advertised to the model.

    Each entry follows the chat-completions ``tools`` format: a ``"type"`` of
    ``"function"`` plus a ``"function"`` object with a name, description and
    JSON-Schema ``"parameters"``. The values are nested dicts/lists, not plain
    strings. A fresh list is built on every call, so callers may mutate it.

    Returns:
        A list containing one schema, for ``get_stock_price``.
    """
    return [
        {
            "type": "function",
            "function": {
                "name": "get_stock_price",
                # NOTE(review): get_stock_price also serves single stocks
                # (MSFT, AAPL), not only indices; consider widening this text.
                "description": "Get current stock index price",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "ticker": {
                            "type": "string",
                            # PART OF THE PROMPT! Careful what you write here
                            "description": "stock index ticker in format of TICKER, without ^",
                        }
                    },
                    "required": ["ticker"],
                },
            },
        }
    ]
def get_completion(messages: List[dict[str, str]], tools=None) -> ChatCompletion:
    """Send ``messages`` to ``GPT_MODEL`` through the shared ``client``.

    ``tools`` is forwarded as-is, ``None`` included. The raw
    ``ChatCompletion`` object is returned without inspection.
    """
    return client.chat.completions.create(
        model=GPT_MODEL,
        messages=messages,
        tools=tools,
    )
def _run_tool_call(tool_call: Any, functions: dict[str, Callable]) -> dict[str, Any]:
    """Execute one model-requested tool call and build its ``tool`` message."""
    name = tool_call.function.name
    if name not in functions:
        raise KeyError(f"No implementation registered for tool {name!r}")
    args = json.loads(tool_call.function.arguments)
    result = functions[name](**args)
    # Tool message content must be text; JSON-encode non-string results.
    content = result if isinstance(result, str) else json.dumps(result)
    # The ID ties this result to the specific call the model made.
    return {"role": "tool", "tool_call_id": tool_call.id, "content": content}


def controller(
    user_input: str, functions: Optional[dict[str, Callable]] = None
) -> Optional[str]:
    """Answer ``user_input``, executing any tool calls the model requests.

    The system prompt, the user's message and the schemas from
    ``get_llm_functions`` are sent to the model. If it answers directly, that
    answer is returned. If it requests tool calls, every requested function
    is executed locally. Each result is sent back as a ``tool`` message and
    the model is queried once more for the final answer.

    Args:
        user_input: The user's question.
        functions: Mapping from tool name (as declared in the schema) to the
            Python callable implementing it. May be ``None`` only if the model
            does not request a tool call.

    Returns:
        The text content of the final assistant message (``None`` if the
        model returned no text).

    Raises:
        KeyError: If the model requests a tool missing from ``functions``.
        json.JSONDecodeError: If the model's tool arguments are not valid JSON.
    """
    # The prompt travels as the system message; no separate prompt argument.
    messages = [
        {"role": "system", "content": get_prompt()},
        {"role": "user", "content": user_input},
    ]

    completion = get_completion(messages=messages, tools=get_llm_functions())
    message = completion.choices[0].message

    # `tool_calls` is None (or empty) when the model answered directly.
    if not message.tool_calls:
        return message.content

    # The assistant message carrying the tool calls must precede their results.
    messages.append(message)
    available = functions or {}
    # The model may issue several calls at once (e.g. two tickers). The
    # follow-up request needs a tool message for *every* tool_call_id, so
    # answer all of them, not just the first.
    for tool_call in message.tool_calls:
        messages.append(_run_tool_call(tool_call, available))

    # Run completion again so the model can phrase the answer from the results.
    tool_completion = get_completion(messages=messages)
    # Print the full message history (demo/debug output).
    pp(messages)
    return tool_completion.choices[0].message.content
if __name__ == "__main__":
    # Map each tool name declared to the model onto its Python implementation.
    available_functions = {"get_stock_price": get_stock_price}
    answer = controller(
        "What is the price of Dow Jones today?",
        functions=available_functions,
    )
    pp(answer)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment