Created
April 29, 2024 00:08
-
-
Save lukacat10/1cbf1ce23750a5aa03a7b68b4220e5b0 to your computer and use it in GitHub Desktop.
Automatically generate an OpenAI function-calling schema from a Python function, using the function's docstring and type annotations. It does not yet handle every possible edge case.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect
import json
from enum import Enum, EnumType
from types import NoneType, UnionType
from typing import Any, Dict, Union, get_args, get_origin

from docstring_parser import DocstringParam, parse
PARAM_TYPE_RESOLUTION_TABLE = { | |
str: "string", | |
int: "integer", | |
float: "float", | |
EnumType: "string", | |
NoneType: "null" | |
} | |
def _annotation_schema(annotation: Any) -> Dict[str, Any]:
    """
    Build the JSON Schema fragment for a single parameter annotation.

    :param annotation: An annotation taken from ``inspect.signature``.
    :return: A dict holding the ``type`` keyword, plus ``items`` for lists or
        ``enum`` for Enum classes. Empty when the parameter is unannotated.
    :raises KeyError: If the annotation (or a union member) has no entry in
        ``PARAM_TYPE_RESOLUTION_TABLE``.
    """
    if annotation is inspect.Parameter.empty:
        # No annotation: leave the type unconstrained instead of failing.
        return {}
    if annotation is list or get_origin(annotation) is list:
        schema: Dict[str, Any] = {"type": "array"}
        item_args = get_args(annotation)
        # A bare ``list`` has no item type (get_args returns ()).
        if item_args:
            schema["items"] = _annotation_schema(item_args[0])
        return schema
    if isinstance(annotation, EnumType):
        return {
            "type": "string",
            "enum": [
                str(member.value) for member in annotation.__members__.values()
            ],
        }
    # Only real unions become a type list, e.g. ``int | None`` ->
    # ["integer", "null"]; other generics such as dict[str, int] must not.
    if get_origin(annotation) in (Union, UnionType):
        return {
            "type": [PARAM_TYPE_RESOLUTION_TABLE[arg] for arg in get_args(annotation)]
        }
    return {"type": PARAM_TYPE_RESOLUTION_TABLE[annotation]}


def get_function_description(func: Any) -> Dict[str, Any]:
    """
    Generate an OpenAI function-calling JSON description for ``func``.

    The top-level and per-parameter descriptions come from the function's
    docstring (parsed with ``docstring_parser.parse``). Parameter types come
    from the signature's annotations.

    :param func: The function for which to generate the description.
    :return: A dictionary with ``name``, ``description`` and a JSON Schema
        ``parameters`` object containing ``properties`` and ``required``.
    :raises ValueError: If ``func`` has no docstring.
    :raises KeyError: If a parameter annotation has no entry in
        ``PARAM_TYPE_RESOLUTION_TABLE``.
    """
    func_name = func.__name__
    sig = inspect.signature(func)

    docstring = inspect.getdoc(func)
    if docstring is None:
        raise ValueError("Function must have a docstring.")
    parsed_docstring = parse(docstring)

    properties: Dict[str, Any] = {}
    required: list[str] = []
    for param in sig.parameters.values():
        # *args / **kwargs cannot be expressed as named JSON properties.
        if param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue

        param_schema: Dict[str, Any] = {}
        doc_param: DocstringParam | None = next(
            (dp for dp in parsed_docstring.params if dp.arg_name == param.name),
            None,
        )
        # A documented param may carry no description text; omit it rather
        # than emitting "description": null.
        if doc_param is not None and doc_param.description is not None:
            param_schema["description"] = doc_param.description
        param_schema.update(_annotation_schema(param.annotation))
        properties[param.name] = param_schema

        if param.default is inspect.Parameter.empty:
            required.append(param.name)

    # Return type is intentionally not emitted: the function-calling schema
    # only describes inputs.
    return {
        "name": func_name,
        "description": parsed_docstring.description,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": required,
        },
    }
# Example functions
class Unit(Enum):
    """Temperature unit accepted by the example ``get_weather`` function."""

    # The member values ("c", "f") are what the generated schema lists under
    # "enum" for a parameter annotated with this class.
    Celsius = "c"
    # NOTE(review): member name is misspelled ("Fahrenheit"); renaming would
    # change the public enum, so it is left as-is.
    Farenheit = "f"
def get_weather(location: str, unit: Unit = Unit.Celsius, fuck: int | None = None):
    """
    Determine weather in my location
    :param location: The city and state e.g. San Francisco, CA
    :param unit: The temperature unit to report in.
    """
    # Example stub for schema generation only: no implementation, returns None.
    # The default is a real ``Unit`` member (previously the bare string "c",
    # which did not match the annotation).
    # NOTE(review): ``fuck`` is a placeholder name for an optional int; it
    # appears verbatim in the generated schema, so consider renaming it.
def calculate_area(length: float, width: float) -> float:
    """
    Calculate the area of a rectangle.
    :param length: The length of the rectangle.
    :param width: The width of the rectangle.
    :return: The area of the rectangle.
    """
    # The docstring above is parsed at runtime into the generated schema, so
    # its text is kept exactly. Area is the product of the two side lengths.
    rectangle_area = length * width
    return rectangle_area
def find_max(numbers: list[int]) -> int:
    """
    Find the maximum element in a list.
    :param numbers: A list of integers.
    :return: The maximum element in the list.
    """
    # Delegates to the builtin; an empty list raises ValueError, as max() does.
    largest = max(numbers)
    return largest
def reverse_string(s: str) -> str:
    """
    Reverse a given string.
    :param s: The input string.
    :return: The reversed string.
    """
    # A step of -1 walks the string from its last character to its first.
    backwards = slice(None, None, -1)
    return s[backwards]
def is_palindrome(s: str) -> bool:
    """
    Check if a given string is a palindrome.
    :param s: The input string.
    :return: True if the string is a palindrome, False otherwise.
    """
    # Case- and whitespace-sensitive: compares the exact characters.
    mirrored = s[::-1]
    return mirrored == s
# Example usage
functions = [get_weather, calculate_area, find_max, reverse_string, is_palindrome]

if __name__ == "__main__":
    # Print the demo schemas only when run as a script, so importing
    # get_function_description from this module has no stdout side effects.
    for func in functions:
        description = get_function_description(func)
        print(json.dumps(description, indent=2))
        print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I plan to make this a proper tool. For now, it's just a script.