-
-
Save christo-olivier/032b278401f1f29fd52d553d8eba1c9d to your computer and use it in GitHub Desktop.
Function/tool calling with LLMs that don't support it in their API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
######################################################## | |
# Requires `pip install modelsmith` | |
######################################################## | |
import inspect | |
import json | |
from typing import Any | |
from modelsmith import Forge | |
from pydantic import BaseModel, Field, TypeAdapter | |
from vertexai.language_models import TextGenerationModel | |
def weather(city: str) -> str:
    """
    Provide the weather for a given city.
    :param city: The city to get the weather for.
    :output: The weather for the given city.
    """
    # This docstring is also the tool description shown to the LLM (it is read
    # with inspect.getdoc when FUNCTIONS_LLM is built), so edit it with care.
    # Stub implementation: the report is canned and only the city name varies.
    announcement = f"Running weather function with city: `{city}`"
    print(announcement)
    return (
        f"The weather in {city} is 28 degrees celsius "
        "and overcast with no rain."
    )
def news(city: str, category: str) -> str:
    """
    Provide the latest news for a given city.
    :param city: The city to get the news for.
    :param category: The category to get the news for.
    :output: The latest news for the given city.
    """
    # This docstring is also the tool description shown to the LLM (it is read
    # with inspect.getdoc when FUNCTIONS_LLM is built), so edit it with care.
    # Stub implementation: a canned story with the category and city filled in.
    print(
        f"Running news function with city: `{city}` "
        f"and category: `{category}`"
    )
    headline = (
        f"In {category} news for {city}, free chocolate was "
        "available at the local supermarket."
    )
    return headline + " It caused an absolute riot!"
# Dispatch table from the function name the LLM returns to the Python callable.
# Keys come from each function's own __name__, so the name a function is
# registered under always matches its definition.
FUNCTION_MAP = {fn.__name__: fn for fn in (weather, news)}
# Tool descriptions advertised to the LLM, one per entry of FUNCTION_MAP (same
# order). The docstring is the description, and the parameters are the JSON
# Schema that pydantic's TypeAdapter derives from the function signature.
# Building this from FUNCTION_MAP keeps advertised names and dispatch names in sync.
FUNCTIONS_LLM = [
    {
        "name": name,
        "description": inspect.getdoc(fn),
        "parameters": TypeAdapter(fn).json_schema(),
    }
    for name, fn in FUNCTION_MAP.items()
]
# Structured reply the LLM must return on every attempt. It contains either a
# final `message` for the user, or a `function` name plus the `arguments` to call
# it with.
# NOTE: there is intentionally no class docstring. Pydantic copies a model's
# docstring into its JSON Schema "description", which would presumably change
# the OUTPUT SCHEMA rendered into the prompt.
class LlmResponse(BaseModel):
    # Final answer for the user; None when a function call is requested instead.
    message: str | None = Field(
        description=(
            "The message to send to the user. This "
            "can be none if a function to call has been provided."
        ),
        default=None,
    )
    # Name of the function to call (looked up in FUNCTION_MAP by the caller).
    # Without a default, pydantic v2 treats this field as required even though
    # it is nullable. A direct answer that omits the key, as the prompt tells
    # the model to do, would then fail validation.
    function: str | None = Field(
        description=(
            "The name of the function to call. This can be "
            "none if no function call is required and a response is provided."
        ),
        default=None,
    )
    # Keyword arguments for `function`. Defaults to None for the same reason.
    arguments: dict[str, Any] | None = Field(
        description=(
            "The key is the name of the parameter and the value is the argument "
            "to pass to it. This can be none if no function call is required and a "
            "response is provided. All parameters must be provided."
        ),
        default=None,
    )
# Prompt template rendered for every request (Jinja syntax: {{ ... }}, {% ... %}).
# `functions` and, once a function has run, `context` are supplied through
# `prompt_values`. `user_input` and `response_model_json` are presumably filled
# in by Forge from generate()'s user_input and its response_model (TODO confirm
# against the modelsmith docs).
# The context section is labelled so the model can tell function output apart
# from the rest of the prompt and answer from it instead of calling the
# function again.
PROMPT = inspect.cleandoc("""
    You are a general information agent that has access to functions that provide
    additional data to you. Check the input that you receive from the user and
    extract any key entities you need.

    If you do not have enough information in the prompt then DO NOT make up an answer;
    instead, look at the list of functions you can call to provide the information you
    need to answer the user's question.

    You have access to the following list of functions that you can call:
    {{ functions }}

    You MUST provide your response in the JSON Schema below. You MUST take the
    types of the OUTPUT SCHEMA into account and adjust your provided text to fit the
    required types.

    Here is the OUTPUT SCHEMA:
    {{ response_model_json }}

    If you have enough information to answer the question, then provide the answer in
    the above JSON Schema, omitting the function name and arguments. DO NOT make up
    any answers. Your answer must come from the information provided in this prompt.

    {% if context is defined %}
    Additional information returned by previously called functions:
    {{ context }}
    {% endif %}

    User input: {{ user_input }}
""")
# Bind the prompt template and the expected response type to the Vertex AI
# text model. The loop below reads `.function`, `.arguments` and `.message`
# from each generate() result, i.e. it expects LlmResponse instances back.
forge = Forge(
    prompt=PROMPT,
    response_model=LlmResponse,
    model=TextGenerationModel.from_pretrained("text-bison"),
)
# Values rendered into PROMPT. "context" is added once a function call has
# produced output (or an error), so the next attempt can use it.
prompt_values = {
    "functions": json.dumps(FUNCTIONS_LLM),
}
context = None
user_input = input("Please ask me a question: ")

# Safety mechanism to avoid going into an infinite loop: the model gets at most
# this many attempts (with 2, one function call followed by one answer).
MAX_ATTEMPTS = 2
counter = 0
while counter < MAX_ATTEMPTS:
    if context:
        prompt_values["context"] = context
    response = forge.generate(
        user_input=user_input,
        prompt_values=prompt_values,
        model_settings={
            "candidate_count": 1,
            "max_output_tokens": 1024,
            "temperature": 0.0,
            "top_p": 1,
        },
    )
    if response.function:
        function = FUNCTION_MAP.get(response.function)
        # The model may send null arguments; treat that as "no arguments"
        # instead of crashing on `**None`.
        arguments = response.arguments or {}
        if function is None:
            # The model named a function that does not exist. Report this back
            # instead of crashing with KeyError, so the next attempt can recover.
            context = (
                f"There is no function named `{response.function}`. "
                f"Available functions: {', '.join(FUNCTION_MAP)}."
            )
        else:
            try:
                # Check the model-supplied arguments against the signature before
                # calling. Missing or unexpected arguments are reported back rather
                # than raising TypeError from the call itself.
                inspect.signature(function).bind(**arguments)
            except TypeError as err:
                context = f"Calling `{response.function}` failed: {err}"
            else:
                context = function(**arguments)
    elif response.message:
        print(response.message)
        break
    counter += 1
else:
    print(
        "I am having trouble answering your question. Please rephrase it or ask "
        "something else."
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment