@skrawcz
Created June 19, 2024 22:56
Gist for the Hamilton, Burr, FalkorDB blog post
import openai

from burr.core import State, action


@action(
    reads=["chat_history"],
    writes=["chat_history"],
)
def AI_generate_response(state: State, client: openai.Client) -> tuple[dict, State]:
    """AI step to generate the response."""
    messages = state["chat_history"]
    response = client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=messages,
    )  # get a new response from the model where it can see the function response
    response_message = response.choices[0].message
    new_state = state.append(chat_history=response_message.to_dict())
    return {"ai_response": response_message.content,
            "usage": response.usage.to_dict()}, new_state
@action(
    reads=["question", "chat_history"],
    writes=["chat_history", "tool_calls"],
)
def AI_create_cypher_query(state: State,
                           client: openai.Client) -> tuple[dict, State]:
    """AI step to create the cypher query."""
    messages = state["chat_history"]
    # Call the model, letting it decide whether to use the cypher-query tool
    response = client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=messages,
        tools=[run_cypher_query_tool_description],
        tool_choice="auto",
    )
    response_message = response.choices[0].message
    new_state = state.append(chat_history=response_message.to_dict())
    tool_calls = response_message.tool_calls
    if tool_calls:
        new_state = new_state.update(tool_calls=tool_calls)
    return {"ai_response": response_message.content,
            "usage": response.usage.to_dict()}, new_state
@action(reads=["count"], writes=["count"])
def counter(state: State) -> State:
    return state.update(count=state.get("count", 0) + 1)
from burr.core import ApplicationBuilder, default, expr

app = (
    ApplicationBuilder()
    .with_actions(
        counter=counter,
        done=done  # implementation left out above
    ).with_transitions(
        ("counter", "counter", expr("count < 10")),  # Keep counting if the count is less than 10
        ("counter", "done", default)  # Otherwise, we're done
    ).with_state(count=0)
    .with_entrypoint("counter")  # we have to start somewhere
    .build()
)
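A minimal sketch of running this counter application, assuming the `done` action referenced above is defined (e.g. it simply returns the state unchanged):

# Run until the "done" action has executed; run() returns the last action,
# its result, and the final state.
last_action, result, state = app.run(halt_after=["done"])
print(state["count"])  # expected to print 10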
burr_application = (
    ApplicationBuilder()
    .with_actions(  # define the actions
        AI_create_cypher_query.bind(client=openai_client),
        tool_call.bind(graph=graph),
        AI_generate_response.bind(client=openai_client),
        human_converse
    )
    .with_transitions(  # define the edges between the actions based on state conditions
        ("human_converse", "AI_create_cypher_query", default),
        ("AI_create_cypher_query", "tool_call", expr("len(tool_calls)>0")),
        ("AI_create_cypher_query", "human_converse", default),
        ("tool_call", "AI_generate_response", default),
        ("AI_generate_response", "human_converse", default)
    )
    .with_identifiers(app_id=application_run_id)
    .with_state(  # initial state
        **{"chat_history": base_messages, "tool_calls": []},
    )
    .with_entrypoint("human_converse")
    .with_tracker(tracker)
    .build()
)
def run_cypher_query(graph, query):
    try:
        results = graph.ro_query(query).result_set
    except Exception:
        results = {"error": "Query failed, please try a different variation of this query"}
    if len(results) == 0:
        results = {
            "error": "The query did not return any data, please make sure you're using the right edge "
                     "directions and you're following the correct graph schema"}
    return str(results)
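For context, a minimal sketch of how the `graph` handle passed to this helper might be created with the FalkorDB Python client; the host, port, and graph name here are assumptions, not values from the blog post:

import falkordb

db = falkordb.FalkorDB(host="localhost", port=6379)
graph = db.select_graph("UFC")  # hypothetical graph name
print(run_cypher_query(graph, "MATCH (f:Fighter) RETURN count(f)"))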
run_cypher_query_tool_description = {
    "type": "function",
    "function": {
        "name": "run_cypher_query",
        "description": "Runs a Cypher query against the knowledge graph",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Query to execute",
                },
            },
            "required": ["query"],
        },
    },
}
from hamilton import driver

import definitions  # contains node definitions, e.g. A, B, C from above

dr = driver.Builder().with_modules(definitions).build()
# request node named "C"; returns a dictionary of results
results = dr.execute(["C"], inputs={"external_input": 7})
# request node named "B"; returns a dictionary of results
results = dr.execute(["B"], inputs={"external_input": 7})
# request nodes "A", "B", and "C"; returns a dictionary of results
results = dr.execute(["A", "B", "C"], inputs={"external_input": 7})
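The `definitions` module above holds the nodes A, B, and C referenced in the post; a hypothetical minimal version (the real function bodies differ) could look like:

# definitions.py -- Hamilton wires the DAG from parameter names
def A(external_input: int) -> int:
    return external_input + 1

def B(A: int) -> int:
    return A * 2

def C(B: int) -> int:
    return B - 3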
while True:
    question = input("What can I help you with?\n")
    if question == "exit":
        break
    print(f"Human: {question}")
    action, _, state = burr_application.run(
        halt_before=["human_converse"],
        inputs={"user_question": question},
    )
    print(f"AI: {state['chat_history'][-1]['content']}\n")
def set_inital_chat_history(schema_prompt: str) -> list[dict]:
    SYSTEM_MESSAGE = "You are a Cypher expert with access to a directed knowledge graph\n"
    SYSTEM_MESSAGE += schema_prompt
    SYSTEM_MESSAGE += ("Query the knowledge graph to extract relevant information to help you answer the user's "
                       "questions. Base your answer only on the context retrieved from the knowledge graph; "
                       "do not use preexisting knowledge.")
    SYSTEM_MESSAGE += ("For example, to find out if two fighters have fought each other, e.g. did Conor McGregor "
                       "ever compete against Jose Aldo, issue the following query: "
                       "MATCH (a:Fighter)-[]->(f:Fight)<-[]-(b:Fighter) WHERE a.Name = 'Conor McGregor' AND "
                       "b.Name = 'Jose Aldo' RETURN a, b\n")
    messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
    return messages
import json

import falkordb


@action(
    reads=["tool_calls", "chat_history"],
    writes=["tool_calls", "chat_history"],
)
def tool_call(state: State, graph: falkordb.Graph) -> tuple[dict, State]:
    """Tool call step -- execute the tool call."""
    tool_calls = state.get("tool_calls", [])
    new_state = state
    result = {"tool_calls": []}
    for tool_call in tool_calls:
        function_name = tool_call.function.name
        assert function_name == "run_cypher_query"
        function_args = json.loads(tool_call.function.arguments)
        function_response = run_cypher_query(graph, function_args.get("query"))
        new_state = new_state.append(chat_history={
            "tool_call_id": tool_call.id,
            "role": "tool",
            "name": function_name,
            "content": function_response,
        })
        result["tool_calls"].append(
            {"tool_call_id": tool_call.id, "response": function_response})
    new_state = new_state.update(tool_calls=[])  # clear the processed tool calls
    return result, new_state
from hamilton.htypes import Collect


def write_to_graph(record: Collect[dict], graph: falkordb.Graph) -> int:
    """Take all records and then push to the DB"""
    records = list(record)
    # Load all fighters in one go.
    q = "UNWIND $fighters as fighter CREATE (f:Fighter) SET f = fighter"
    graph.query(q, {'fighters': records})
    return len(records)
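Since `record` is typed `Collect[dict]`, Hamilton expects a matching `Parallelizable[dict]` node upstream. A hypothetical sketch of such a node (the node name and source DataFrame are assumptions, not the post's actual code):

import pandas as pd
from hamilton.htypes import Parallelizable

def record(fighter_details: pd.DataFrame) -> Parallelizable[dict]:
    # Emit one fighter at a time so downstream nodes can run per record,
    # and write_to_graph can collect them all at the end.
    for _, row in fighter_details.iterrows():
        yield row.to_dict()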
q = "MERGE (:Referee {Name: $name})"
_graph.query(q,
             {'name': _row.Referee
              if isinstance(_row.Referee, str) else ""})

q = "MERGE (c:Card {Date: $date, Location: $location})"
_graph.query(q, {'date': _row.date, 'location': _row.location})

q = """MATCH (c:Card {Date: $date, Location: $location})
       MATCH (ref:Referee {Name: $referee})
       MATCH (r:Fighter {Name:$R_fighter})
       MATCH (b:Fighter {Name:$B_fighter})
       CREATE (f:Fight)-[:PART_OF]->(c)
       SET f = $fight
       CREATE (f)-[:RED]->(r)
       CREATE (f)-[:BLUE]->(b)
       CREATE (ref)-[:REFEREED]->(f)
       RETURN ID(f)
    """
f_id = _graph.query(q,
                    {'date': _row.date,
                     'location': _row.location,
                     'referee': _row.Referee
                     if isinstance(_row.Referee, str) else "",
                     'R_fighter': _row.R_fighter,
                     'B_fighter': _row.B_fighter,
                     'fight': {'Last_round': _row.last_round,
                               'Last_round_time': _row.last_round_time,
                               'Format': _row.Format,
                               'Fight_type': _row.Fight_type}
                     }
                    ).result_set[0][0]

q = """MATCH (f:Fight) WHERE ID(f) = $fight_id
       MATCH (l:Fighter {Name:$loser})
       MATCH (w:Fighter {Name:$winner})
       CREATE (w)-[:WON]->(f), (l)-[:LOST]->(f)
    """
_graph.query(q,
             {'fight_id': f_id,
              'loser': _row.Loser,
              'winner': _row.Winner
              if isinstance(_row.Winner, str) else ""}
             )