Created
May 20, 2024 13:36
-
-
Save jeffometer/daece8cd2a0424c96ed7c9a26b5f9d00 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fructose import Fructose | |
from typing import Optional, Annotated, Tuple | |
from pydantic import BaseModel, Field, model_validator, field_validator, AfterValidator, ValidationInfo | |
from typing_extensions import Self | |
import instructor | |
from openai import OpenAI | |
def call_graph_db(query: str) -> str:
    """Execute ``query`` against the graph database and return its result.

    On failure, the error, the offending query and a retry notice are printed
    to stdout, and the original exception is re-raised unchanged.

    NOTE(review): ``memgraph`` is not defined or imported anywhere in this
    file as shown; presumably it is provided elsewhere — TODO confirm.
    """
    try:
        db_output = memgraph(query, raise_on_error=True)
    except Exception as exc:
        diagnostics = (
            f"Error running query: {exc}",
            f"Original query: {query}",
            "Trying again",
        )
        for line in diagnostics:
            print(line)
        # Optional self-healing: attaching the failing query to the exception
        # gives the next attempt more context. Leaving it off (retrying from
        # scratch) has sometimes worked better in practice.
        # exc.add_note(f" Original query: {query}")
        raise
    return db_output
# Graph schema text sent as the system message in ask_the_graph, so the model
# writes queries against the real labels/relationships.
# Placeholder: replace "<schema here>" with the actual schema.
SCHEMA = """
<schema here>
"""

# Shared Fructose instance. Functions decorated with @ai below have no body of
# their own; presumably Fructose answers them with an LLM using each function's
# docstring and signature as the prompt — TODO confirm against Fructose docs.
ai = Fructose()
# LLM-backed summariser: the @ai decorator supplies the implementation, and the
# function itself contains only a docstring.
# NOTE(review): the docstring presumably serves as the Fructose prompt, so it is
# runtime text — edit with care.
# In this file it is called with a GraphDatabaseResult instance (which carries
# the rephrased question, the query and the raw result). Its return value is
# stored in that model's ``response`` field.
# ``results`` is unannotated; TODO confirm how Fructose serialises it.
@ai
def interpret_results(results) -> str:
    """Summarize the result to answer the rephrased question in 1-2 sentences."""
# LLM-backed evaluator: the @ai decorator supplies the implementation, and the
# function itself contains only a docstring.
# NOTE(review): the docstring and the ``tuple[str, float]`` return annotation
# presumably define the Fructose prompt and output shape, so they are runtime
# behavior — edit with care.
# Returns a (critique, score) pair. The caller in this file unpacks it as
# ``reason, score`` and treats ``score < 0.5`` as an inadequate answer.
# The validator method GraphDatabaseResult.self_critique shares this name.
# Inside that method's body the bare name resolves to THIS module-level
# function, because class scope is skipped during name lookup.
@ai
def self_critique(original_question: str, query: str, result: str) -> tuple[str, float]:
    """You are an evaluator. Critique how well the result answers the question and if the query seems appropriate. Give a score in the range 0-1."""
def confirm_rephrased_question(value):
    """Ask the user to approve the LLM's rephrased question.

    Used as a pydantic ``AfterValidator``. An empty reply (just <enter>)
    accepts ``value`` and returns it unchanged. Any other reply is raised as a
    ``ValueError`` whose message is the user's feedback, so validation fails
    and the feedback can be sent back to the model.
    """
    print("\nRephrased question:", value)
    question = "\nDo you like the rephrased question? (press <enter> to accept, or provide feedback):\n\n>>> "
    feedback = input(question)
    if not feedback:
        return value
    raise ValueError(feedback)
class GraphDatabaseResult(BaseModel):
    """Structured LLM answer to a natural-language question about the graph.

    The LLM fills in ``rephrased_question``, ``reasoning`` and ``query``.
    Validation then has side effects:

    1. The user is asked to confirm the rephrased question (field validator).
    2. The query is executed and the raw result summarised (``try_to_run_query``).
    3. The answer is self-critiqued and scored (``self_critique``).

    Pydantic runs after-validators in definition order, so ``try_to_run_query``
    fills ``response`` before ``self_critique`` reads it. Any ``ValueError``
    raised along the way surfaces as a pydantic ``ValidationError``, which
    instructor can feed back to the LLM as a retry.
    """

    rephrased_question: Annotated[str, AfterValidator(confirm_rephrased_question)] = Field(..., description="Rephrase the question in terms of the graph schema.")
    reasoning: str = Field(..., description="Think step by step on how to make a graph database query that can answer this question.")
    query: str = Field(..., description="The query to run")
    # The fields below are filled in by the validators, not by the LLM.
    result: str = Field(description="leave this blank", default="Pending...")
    response: str = Field(description="leave this blank", default="Pending...")
    # In pydantic v2, Optional[...] alone does not make a field optional.
    # Without default=None the LLM would be *required* to emit these keys.
    certainty: Optional[float] = Field(description="leave this blank", default=None)
    critique: Optional[str] = Field(description="leave this blank", default=None)

    @model_validator(mode='after')
    def try_to_run_query(self) -> Self:
        """Run the generated query and store its raw result and a summary.

        Raises:
            ValueError: if the database call fails. A ValueError becomes a
                ValidationError, which lets instructor retry with the error
                message.
        """
        try:
            self.result = call_graph_db(self.query)
        except ValueError:
            raise
        except Exception as err:
            # Pydantic v2 only converts ValueError/AssertionError raised inside
            # validators into ValidationError. Any other exception escapes
            # validation entirely, and the instructor retry loop never runs.
            raise ValueError(f"Error running query: {err}") from err
        self.response = interpret_results(self)
        return self

    @model_validator(mode='after')
    def self_critique(self, info: ValidationInfo) -> Self:
        """Score the answer and reject it if the score is below 0.5.

        On success, the score and the critique text are stored in
        ``certainty`` and ``critique``.

        Raises:
            ValueError: when the score is below 0.5. The critique is included
                in the message so a retry can address it.
        """
        # The validation context is absent when the model is validated without
        # one (e.g. plain model_validate), so guard against None.
        original_question = (info.context or {}).get('original_question')
        # Class scope is skipped during name lookup, so this calls the
        # module-level LLM function ``self_critique``, not this validator.
        reason, score = self_critique(original_question, self.query, self.response)
        if score < 0.5:
            print(f"Original result: {self.result}")
            print(f"The result was found to be inadequate. Score: {score} Reason: {reason}")
            print("Trying again")
            raise ValueError(f"The answer did not address the question sufficiently for the following reason: {reason}")
        self.certainty = score
        self.critique = reason
        return self
def ask_the_graph(
    user_question: str,
    *,
    model: str = "gpt-4-turbo",
    max_retries: int = 3,
) -> GraphDatabaseResult:
    """Answer ``user_question`` by having an LLM write and run a graph query.

    The graph schema is sent as the system message. The LLM's structured
    output is parsed into ``GraphDatabaseResult``, whose validators execute
    the query and critique the answer.

    Args:
        user_question: Natural-language question from the user.
        model: OpenAI chat model name. The default is the previously
            hard-coded model.
        max_retries: How many times instructor may re-ask the LLM after a
            validation failure.

    Returns:
        The validated ``GraphDatabaseResult``.
    """
    client = instructor.from_openai(OpenAI())
    return client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SCHEMA},
            {"role": "user", "content": user_question},
        ],
        response_model=GraphDatabaseResult,
        max_retries=max_retries,
        # Read by the GraphDatabaseResult.self_critique validator through
        # ValidationInfo.context.
        validation_context={"original_question": user_question},
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment