Skip to content

Instantly share code, notes, and snippets.

@daaniyaan
Forked from jxnl/example_args.py
Created June 16, 2023 16:27
Show Gist options
Save daaniyaan/e9b6e4667eb450241a134a344f0ec7f1 to your computer and use it in GitHub Desktop.
Power of pydantic
"""Example: expose a plain typed function to the model via ``openai_function``."""
import openai

# `openai_function` is a class inside the `openai_function` module; a bare
# `import openai_function` would bind the module, and using a module as a
# decorator fails with "TypeError: 'module' object is not callable".
from openai_function import openai_function


# The name `sum` shadows the builtin, but openai_function uses the function's
# __name__ as the schema name and the prompts below ask for the `sum` function,
# so renaming it would also require updating those prompts.
@openai_function
def sum(a: int, b: int) -> int:
    """Sum description adds a + b"""
    return a + b


# NOTE(review): `openai.ChatCompletion` is the pre-1.0 openai SDK interface —
# confirm the SDK version this example is pinned to.
completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo-0613",
    temperature=0,
    # Advertise the function's name/description/parameter schema to the model.
    functions=[sum.schema],
    messages=[
        {
            "role": "system",
            "content": "You must use the `sum` function instead of adding yourself.",
        },
        {
            "role": "user",
            "content": "What is 6+3 use the `sum` function",
        },
    ],
)

# Parse the model's function call and run `sum` with the validated arguments.
sum.from_response(completion)
"""Example: describe a function's parameters with a pydantic model."""
import openai
from pydantic import BaseModel, Field

# Without this import the `@openai_function` decorator below is a NameError.
from openai_function import openai_function


class Arguments(BaseModel):
    """Arguments to add"""

    # Field descriptions end up in the parameter schema shown to the model.
    a: int = Field(..., description="First argument")
    b: int = Field(..., description="Second argument")


# The name `sum` shadows the builtin, but openai_function uses the function's
# __name__ as the schema name and the prompts below ask for the `sum` function.
@openai_function
def sum(args: Arguments) -> int:
    """Sum description adds a + b"""
    return args.a + args.b


# NOTE(review): `openai.ChatCompletion` is the pre-1.0 openai SDK interface —
# confirm the SDK version this example is pinned to.
completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo-0613",
    temperature=0,
    functions=[sum.schema],
    messages=[
        {
            "role": "system",
            "content": "You must use the `sum` function instead of adding yourself.",
        },
        {
            # NOTE(review): the prompt has three addends but `Arguments` only
            # holds two — confirm this mismatch is intentional.
            "role": "user",
            "content": "What is 6+3+2 use the `sum` function",
        },
    ],
)

# Parse the model's function call and run `sum` with the validated arguments.
sum.from_response(completion)
from pydantic import validate_arguments
import json
from functools import wraps
from typing import Any, Callable
class FunctionCallError(ValueError, AssertionError):
    """Raised when a chat completion lacks the expected function call.

    Inherits from ``AssertionError`` so callers that caught the errors formerly
    produced by ``assert`` statements in ``from_response`` keep working, and
    from ``ValueError`` because the problem is unexpected input data.
    """


class openai_function:
    """Decorator that exposes a typed Python function as an OpenAI function.

    The decorated function's arguments are validated by pydantic's
    ``validate_arguments``. ``schema`` is a ``{"name", "description",
    "parameters"}`` dict meant to be passed in the ``functions=`` list of a chat
    completion request, and ``from_response`` runs the function with the
    arguments the model returned.
    """

    def __init__(self, func: Callable) -> None:
        # Copy __name__, __doc__, __module__, ... onto this instance so the
        # decorated object still introspects like the original function. Done
        # first so the attributes assigned below cannot be overwritten by it.
        wraps(func)(self)
        self.func = func
        self.validate_func = validate_arguments(func)
        # pydantic model generated from the function's signature.
        self.model = self.validate_func.model
        self.schema = {
            "name": func.__name__,
            "description": func.__doc__,
            "parameters": self.model.schema(),
        }

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the function, validating/coercing arguments through pydantic."""
        return self.validate_func(*args, **kwargs)

    def from_response(self, completion, throw_error=True):
        """Execute the function from the response of an openai chat completion.

        Args:
            completion: chat completion response; only
                ``completion.choices[0].message`` is read.
            throw_error: when True, raise if the message has no function call
                or calls a different function. When False, return None if
                there is no function call, and skip the name check.

        Returns:
            The function's return value, or None when ``throw_error`` is False
            and the message contains no function call.

        Raises:
            FunctionCallError: (``throw_error=True``) no function call, or the
                called name differs from ``self.schema["name"]``.
        """
        message = completion.choices[0].message
        # Explicit checks instead of `assert`, which `python -O` strips.
        if "function_call" not in message:
            if throw_error:
                raise FunctionCallError("No function call detected")
            return None
        function_call = message["function_call"]
        if throw_error and function_call["name"] != self.schema["name"]:
            raise FunctionCallError("Function name does not match")
        # The model returns its arguments as a JSON-encoded string.
        arguments = json.loads(function_call["arguments"])
        return self.validate_func(**arguments)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment