Skip to content

Instantly share code, notes, and snippets.

@seanchatmangpt
Last active April 14, 2024 20:11
Show Gist options
  • Save seanchatmangpt/6a897ae294c556cd220bc7cc617191ab to your computer and use it in GitHub Desktop.
Convert your prompt into a pydantic instance.
import ast
import logging
import inspect
from typing import Type, TypeVar
from dspy import Assert, Module, ChainOfThought, Signature, InputField, OutputField
from pydantic import BaseModel, ValidationError
# Module-level logger; raised to ERROR so only validation/generation failures
# are emitted by default.
logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)
def eval_dict_str(dict_str: str) -> dict:
    """Parse a Python dict literal string into a dict without executing code.

    Relies on ``ast.literal_eval``, which accepts only literal structures,
    so it is safe on model-generated strings (unlike ``eval``).
    """
    parsed = ast.literal_eval(dict_str)
    return parsed
class PromptToPydanticInstanceSignature(Signature):
    """Synthesize the prompt into the kwargs fit the model"""

    # NOTE(review): in dspy, a Signature's docstring is sent to the LM as the
    # task instruction, so its wording is behavior-affecting and left unchanged.

    # Name of the root pydantic class the generated kwargs must instantiate.
    root_pydantic_model_class_name = InputField(
        desc="The class name of the pydantic model to receive the kwargs"
    )
    # Source code of the root model (and any child models), shown to the LM
    # so it knows the exact fields to produce.
    pydantic_model_definitions = InputField(
        desc="Pydantic model class definitions as a string"
    )
    prompt = InputField(desc="The prompt to be synthesized into data")
    # The LM's answer: a Python dict literal (as a string) suitable for
    # ast.literal_eval followed by model_validate.
    root_model_kwargs_dict = OutputField(
        prefix="kwargs_dict = ",
        desc="Generate a Python dictionary as a string with minimized whitespace that only contains json valid values.",
    )
class PromptToPydanticInstanceErrorSignature(PromptToPydanticInstanceSignature):
    """Synthesize the prompt into the kwargs fit the model"""

    # NOTE(review): docstring doubles as the LM instruction in dspy — unchanged.
    # Extends the base signature with the previous attempt's validation error
    # so the LM can correct its output on the retry pass.
    error = InputField(desc="Error message to fix the kwargs")
# Type variable bound to BaseModel so GenPydanticInstance methods can be
# typed generically over the root model they produce.
T = TypeVar('T', bound=BaseModel)
class GenPydanticInstance(Module):
    """Generate and validate a Pydantic model instance from a prompt.

    Usage:
        Instantiate with the desired root Pydantic model (and optional child
        models whose source should also be shown to the LM), then call the
        ``forward`` method with a prompt to obtain a validated instance of
        the root model.
    """

    def __init__(
        self, root_model: Type[T], child_models: list[Type[BaseModel]] = None
    ):
        """
        Args:
            root_model: Pydantic model class the generated kwargs must instantiate.
            child_models: Optional nested model classes whose source code is
                included in the LM context alongside the root model.

        Raises:
            TypeError: If any supplied class is not a pydantic.BaseModel subclass.
        """
        super().__init__()

        if not issubclass(root_model, BaseModel):
            raise TypeError("root_model must inherit from pydantic.BaseModel")

        self.models = [root_model]  # Always include root_model in models list

        if child_models:
            # Validate that each child_model inherits from BaseModel
            for model in child_models:
                if not issubclass(model, BaseModel):
                    raise TypeError(
                        "All child_models must inherit from pydantic.BaseModel"
                    )
            self.models.extend(child_models)

        self.output_key = "root_model_kwargs_dict"
        self.root_model = root_model

        # Concatenate source code of models for use in generation/correction logic
        self.model_sources = "\n".join(
            inspect.getsource(model) for model in self.models
        )

        # Two ChainOfThought modules: one for the initial generation, one whose
        # signature additionally accepts the validation error for a retry pass.
        self.generate = ChainOfThought(PromptToPydanticInstanceSignature)
        self.correct_generate = ChainOfThought(PromptToPydanticInstanceErrorSignature)

    def validate_root_model(self, output: str) -> bool:
        """Return True iff ``output`` parses into a valid root-model instance."""
        try:
            model_inst = self.root_model.model_validate(eval_dict_str(output))
            return isinstance(model_inst, self.root_model)
        except (ValidationError, ValueError, TypeError, SyntaxError):
            # Any parse or validation failure means the output is unusable.
            return False

    def validate_output(self, output) -> T:
        """Assert that ``output`` is valid and return the root-model instance."""
        Assert(
            self.validate_root_model(output),
            f"""You need to create a kwargs dict for {self.root_model.__name__}""",
        )
        return self.root_model.model_validate(eval_dict_str(output))

    def forward(self, prompt) -> T:
        """
        Takes a prompt as input and generates a Python dictionary that represents
        an instance of the root Pydantic model. On validation failure, one
        correction attempt is made with the error-aware signature before the
        exception propagates to the caller.
        """
        output = self.generate(
            prompt=prompt,
            root_pydantic_model_class_name=self.root_model.__name__,
            pydantic_model_definitions=self.model_sources,
        )[self.output_key]

        try:
            return self.validate_output(output)
        except (AssertionError, ValueError, TypeError) as error:
            logger.error("Error %s\nOutput:\n%s", error, output)

            # BUG FIX: the correction attempt must use the error-aware module
            # (self.correct_generate). The original called self.generate, whose
            # signature has no `error` input field, so correct_generate was
            # constructed but never used.
            corrected_output = self.correct_generate(
                prompt=prompt,
                root_pydantic_model_class_name=self.root_model.__name__,
                pydantic_model_definitions=self.model_sources,
                error=str(error),
            )[self.output_key]

            return self.validate_output(corrected_output)
import dspy
from dspy import Module
from pydantic import BaseModel, Field
from typing import List, Optional
from rdddy.generators.gen_pydantic_instance import GenPydanticInstance
class GraphNode(BaseModel):
    """A single reasoning step (node) in a graph-of-thought."""

    id: str = Field(..., description="Unique identifier for the node")
    content: str = Field(
        ..., description="Content or question associated with the node"
    )
    # Optional because a node may be generated before its step is answered.
    answer: Optional[str] = Field(
        None, description="Answer or result of the node's reasoning step"
    )
class GraphEdge(BaseModel):
    """A directed link between two GraphNode ids, labeled with the reasoning relation."""

    source_id: str = Field(..., description="Source node ID")
    target_id: str = Field(..., description="Target node ID")
    relationship: str = Field(
        ..., description="Description of the relationship or reasoning link"
    )
class GraphOfThoughtModel(BaseModel):
    """Root model: the full reasoning graph, nodes plus the edges linking them."""

    nodes: List[GraphNode] = Field(..., description="List of nodes in the graph")
    edges: List[GraphEdge] = Field(..., description="List of edges linking the nodes")
class GraphOfThought(Module):
    """DSPy module that synthesizes a GraphOfThoughtModel instance from a prompt."""

    def __init__(self):
        super().__init__()
        # Build the generator once here rather than on every forward() call:
        # GenPydanticInstance.__init__ runs inspect.getsource over all three
        # models, which there is no reason to repeat per invocation.
        self.generator = GenPydanticInstance(
            root_model=GraphOfThoughtModel, child_models=[GraphNode, GraphEdge]
        )

    def forward(self, prompt) -> GraphOfThoughtModel:
        """Generate a validated GraphOfThoughtModel for ``prompt``."""
        return self.generator.forward(prompt)
def main():
    """Configure the language model and print a generated graph-of-thought."""
    language_model = dspy.OpenAI(max_tokens=1000)
    dspy.settings.configure(lm=language_model)

    # Alternative demo prompts:
    #   "BPMN for ordering a sandwich"
    #   "Explain the water cycle step by step."
    demo_prompt = "Decision Model Notation for cancer diagnosis"

    graph = GraphOfThought().forward(demo_prompt)
    print(graph)


if __name__ == "__main__":
    main()
from rdddy.generators.gen_pydantic_instance import (
GenPydanticInstance,
)
import pytest
from unittest.mock import patch, MagicMock
from dspy import settings, OpenAI, DSPyAssertionError
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field, ValidationError
class APIEndpoint(BaseModel):
    """Pydantic model used as the generation target throughout these tests."""

    method: str = Field(..., description="HTTP method of the API endpoint")
    url: str = Field(..., description="URL of the API endpoint")
    description: str = Field(
        ..., description="Description of what the API endpoint does"
    )
    response: str = Field(..., description="Response from the API endpoint")
    # Optional so endpoints without query parameters validate cleanly.
    query_params: Optional[Dict[str, Any]] = Field(None, description="Query parameters")
VALID_PYDANTIC_MODEL_STRING = """{
"method": "GET",
"url": "/forecast/today",
"description": "API endpoint for retrieving meteorological conditions",
"response": "Structured summary of weather conditions",
"query_params": {"geographical_area": "string"}
}"""
VALID_PROMPT = """
Imagine a digital portal where users can inquire about meteorological conditions.
This portal is accessible through a web interface that interacts with a backend service.
The service is invoked by sending a request to a specific endpoint.
This request is crafted using a standard protocol for web communication.
The endpoint's location is a mystery, hidden within the path '/forecast/today'.
Users pose their inquiries by specifying a geographical area of interest,
though the exact format of this specification is left to the user's imagination.
Upon successful request processing, the service responds with a structured
summary of the weather, encapsulating details such as temperature, humidity,
and wind speed. However, the structure of this response and the means of
accessing the weather summary are not explicitly defined.
"""
VALID_PYDANTIC_MODEL_DICT = {
"method": "GET",
"url": "/forecast/today",
"description": "API endpoint for retrieving meteorological conditions",
"response": "Structured summary of weather conditions",
"query_params": {"geographical_area": "string"},
}
INVALID_STR = "{ 'name': 'Alice', 'age': 30, 'city': 'Wonderland' }"
@pytest.fixture
def gen_pydantic_model():
    """Yield a GenPydanticInstance targeting APIEndpoint, with LM setup stubbed.

    Both dspy settings configuration and the OpenAI client constructor are
    patched so the fixture never touches the network.
    """
    with patch.object(settings, "configure"):
        with patch.object(OpenAI, "__init__", return_value=None):
            # Replace APIEndpoint with your Pydantic model
            yield GenPydanticInstance(APIEndpoint)
@patch("dspy.predict.Predict.forward")
@patch("rdddy.generators.gen_module.ChainOfThought")
@patch("ast.literal_eval")
def test_forward_success(
    mock_literal_eval, mock_chain_of_thought, mock_predict, gen_pydantic_model
):
    """Happy path: forward() should return a validated APIEndpoint instance."""
    # NOTE(review): the patch target module "rdddy.generators.gen_module"
    # differs from the import path used above
    # ("rdddy.generators.gen_pydantic_instance") — confirm the ChainOfThought
    # patch actually hits the module under test.
    # Mock responses for a successful forward pass
    mock_predict.return_value.get.return_value = (
        VALID_PYDANTIC_MODEL_STRING  # Replace with a valid string for your model
    )
    mock_chain_of_thought.return_value.get.return_value = VALID_PYDANTIC_MODEL_STRING
    # NOTE(review): GenPydanticInstance.forward indexes the prediction with
    # [self.output_key], not .get(); these `.get` mocks may be inert — verify.
    mock_literal_eval.return_value = (
        VALID_PYDANTIC_MODEL_DICT  # Replace with a valid dict for your model
    )
    # Call the method
    result = gen_pydantic_model.forward(
        prompt=VALID_PROMPT
    )  # Replace with a valid prompt
    assert isinstance(
        result, APIEndpoint
    )  # Replace APIEndpoint with your Pydantic model class
@patch("dspy.predict.Predict.forward")
@patch("rdddy.generators.gen_module.ChainOfThought")
@patch("ast.literal_eval", side_effect=SyntaxError)
def test_forward_syntax_error(
    mock_literal_eval, mock_chain_of_thought, mock_predict, gen_pydantic_model
):
    """Unparseable output: forward() should raise DSPyAssertionError after the
    correction attempt also fails to yield a valid dict string."""
    # NOTE(review): patch target "rdddy.generators.gen_module" vs the imported
    # "rdddy.generators.gen_pydantic_instance" — confirm it hits the real module.
    # Setup mock responses for a syntax error case
    mock_predict.return_value.get.return_value = INVALID_STR
    # First mock feeds the initial generation, second the correction pass.
    mock_chain_of_thought.side_effect = [
        MagicMock(get=MagicMock(return_value=INVALID_STR)),  # initial call
        MagicMock(get=MagicMock(return_value=INVALID_STR)),  # correction call
    ]
    # Call the method and expect an error
    with pytest.raises(DSPyAssertionError):
        gen_pydantic_model.forward(prompt="///")  # Replace with an invalid prompt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment