Skip to content

Instantly share code, notes, and snippets.

@lalanikarim
Created June 2, 2024 06:28
Show Gist options
  • Save lalanikarim/99fd6b1db3346a744246d6a2cab72c04 to your computer and use it in GitHub Desktop.
Save lalanikarim/99fd6b1db3346a744246d6a2cab72c04 to your computer and use it in GitHub Desktop.
import os
import shutil
import sqlite3
import pandas as pd
import requests
db_url = "https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
local_file = "travel2.sqlite"
# The backup lets us restart for each tutorial section
backup_file = "travel2.backup.sqlite"
overwrite = False
# Seconds to wait on the download before failing instead of hanging forever.
download_timeout = 60


def _shift_dates_to_present(file: str) -> None:
    """Move the booking/flight timestamps in ``file`` forward to the present.

    The offset is chosen so the latest non-null ``flights.actual_departure``
    becomes "now"; ``bookings.book_date`` and the four flight schedule
    columns are moved by the same offset. Every table is rewritten with
    ``DataFrame.to_sql(if_exists="replace")``.
    """
    conn = sqlite3.connect(file)
    try:
        tables = pd.read_sql(
            "SELECT name FROM sqlite_master WHERE type='table';", conn
        ).name.tolist()
        tdf = {t: pd.read_sql(f"SELECT * from {t}", conn) for t in tables}

        # "\\N" is the dataset's null marker; treat it as a missing timestamp.
        example_time = pd.to_datetime(
            tdf["flights"]["actual_departure"].replace("\\N", pd.NaT)
        ).max()
        # Timestamp.now(tz=...) is the current instant expressed in that zone;
        # tz_localize() on a naive "now" only relabelled the wall clock.
        current_time = pd.Timestamp.now(tz=example_time.tz)
        time_diff = current_time - example_time

        tdf["bookings"]["book_date"] = (
            pd.to_datetime(
                tdf["bookings"]["book_date"].replace("\\N", pd.NaT), utc=True
            )
            + time_diff
        )
        for column in (
            "scheduled_departure",
            "scheduled_arrival",
            "actual_departure",
            "actual_arrival",
        ):
            tdf["flights"][column] = (
                pd.to_datetime(tdf["flights"][column].replace("\\N", pd.NaT))
                + time_diff
            )

        for table_name, df in tdf.items():
            df.to_sql(table_name, conn, if_exists="replace", index=False)
        conn.commit()
    finally:
        conn.close()


if overwrite or not os.path.exists(backup_file):
    response = requests.get(db_url, timeout=download_timeout)
    response.raise_for_status()  # Ensure the request was successful
    # Backup - we will use this to "reset" our DB in each section
    with open(backup_file, "wb") as f:
        f.write(response.content)

# Re-anchor the dates in the backup itself, then start the working copy from
# it. The tutorial resets the working DB by copying the backup over it, so an
# un-shifted backup would put every flight back in the past.
_shift_dates_to_present(backup_file)
shutil.copy(backup_file, local_file)

db = local_file  # We'll be using this local file as our DB in this tutorial
import re
import numpy as np
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.tools import tool
# Fetch the policy FAQ once; each "##" section becomes one retrievable document.
response = requests.get(
    "https://storage.googleapis.com/benchmarks-artifacts/travel-db/swiss_faq.md",
    timeout=60,
)
response.raise_for_status()
faq_text = response.text
# Zero-width lookahead splits *before* every "\n##" heading so the heading
# stays with its section; blank fragments (e.g. a leading "" when the text
# starts with a heading) are skipped rather than embedded.
docs = [
    {"page_content": txt}
    for txt in re.split(r"(?=\n##)", faq_text)
    if txt.strip()
]
class VectorStoreRetriever:
    """Minimal in-memory vector index over a list of documents.

    Stores one embedding per document and ranks documents by the dot product
    between the query embedding and each stored embedding.
    """

    def __init__(self, docs: list, vectors: list, embed_client):
        # One row per document, in the same order as ``docs``.
        self._arr = np.array(vectors)
        self._docs = docs
        # Must provide embed_query(); from_docs() also uses embed_documents().
        self._client = embed_client

    @classmethod
    def from_docs(cls, docs, embed_client):
        """Embed every document's ``page_content`` and build a retriever."""
        embeddings = embed_client.embed_documents(
            [doc["page_content"] for doc in docs]
        )
        return cls(docs, embeddings, embed_client)

    def query(self, query: str, k: int = 5) -> list[dict]:
        """Return up to ``k`` documents most similar to ``query``, best first.

        Each result is a copy of the stored document dict plus a
        ``"similarity"`` key holding the raw dot-product score (this equals
        cosine similarity only if the embeddings are unit-normalized).
        """
        # Clamp k: argpartition raises when k exceeds the number of documents,
        # and k <= 0 would select everything because ``[-0:]`` is the whole array.
        k = min(k, len(self._docs))
        if k <= 0:
            return []
        embed = self._client.embed_query(query)
        # "@" is matrix multiplication: one score per stored document.
        scores = np.array(embed) @ self._arr.T
        # Partial sort for the k best, then order just those k by score.
        top_k_idx = np.argpartition(scores, -k)[-k:]
        top_k_idx_sorted = top_k_idx[np.argsort(-scores[top_k_idx])]
        return [
            {**self._docs[idx], "similarity": scores[idx]} for idx in top_k_idx_sorted
        ]
# Embedding model served by Ollama; reused for both documents and queries.
embed = OllamaEmbeddings(model="nomic-embed-text")
# Index the FAQ sections once, up front, so each lookup only embeds the query.
retriever = VectorStoreRetriever.from_docs(docs=docs, embed_client=embed)
@tool
def lookup_policy(query: str) -> str:
    """Consult the company policies to check whether certain options are permitted.
    Use this before making any flight changes or performing other 'write' events."""
    # The docstring above is the tool description shown to the LLM.
    # Return the two best-matching FAQ sections, separated by a blank line.
    policy_docs = retriever.query(query, k=2)
    return "\n\n".join(doc["page_content"] for doc in policy_docs)
import sqlite3
from datetime import date, datetime
from typing import Optional
import pytz
from langchain_core.runnables import ensure_config
@tool
def fetch_user_flight_information() -> list[dict]:
    """Fetch all tickets for the user along with corresponding flight information and seat assignments.
    Returns:
    A list of dictionaries where each dictionary contains the ticket details,
    associated flight details, and the seat assignments for each ticket belonging to the user.
    """
    # The passenger is identified by the run config, not by a tool argument,
    # so the LLM cannot ask about someone else's tickets.
    passenger_id = ensure_config().get("configurable", {}).get("passenger_id", None)
    if not passenger_id:
        raise ValueError("No passenger ID configured.")

    sql = """
    SELECT
        t.ticket_no, t.book_ref,
        f.flight_id, f.flight_no, f.departure_airport, f.arrival_airport, f.scheduled_departure, f.scheduled_arrival,
        bp.seat_no, tf.fare_conditions
    FROM
        tickets t
        JOIN ticket_flights tf ON t.ticket_no = tf.ticket_no
        JOIN flights f ON tf.flight_id = f.flight_id
        JOIN boarding_passes bp ON bp.ticket_no = t.ticket_no AND bp.flight_id = f.flight_id
    WHERE
        t.passenger_id = ?
    """
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute(sql, (passenger_id,))
    records = cur.fetchall()
    header = [col[0] for col in cur.description]
    cur.close()
    connection.close()
    # One dict per (ticket, flight) row, keyed by the selected column names.
    return [dict(zip(header, record)) for record in records]
@tool
def search_flights(
    departure_airport: Optional[str] = None,
    arrival_airport: Optional[str] = None,
    start_time: Optional[date | datetime] = None,
    end_time: Optional[date | datetime] = None,
    limit: int = 20,
) -> list[dict]:
    """Search for flights based on departure airport, arrival airport, and departure time range."""
    # (condition, value) pairs; only filters with a truthy value are applied.
    candidate_filters = [
        ("departure_airport = ?", departure_airport),
        ("arrival_airport = ?", arrival_airport),
        ("scheduled_departure >= ?", start_time),
        ("scheduled_departure <= ?", end_time),
    ]
    active = [(cond, value) for cond, value in candidate_filters if value]
    clauses = ["1 = 1"] + [cond for cond, _ in active]
    params = [value for _, value in active] + [limit]
    sql = "SELECT * FROM flights WHERE " + " AND ".join(clauses) + " LIMIT ?"

    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute(sql, params)
    records = cur.fetchall()
    header = [col[0] for col in cur.description]
    cur.close()
    connection.close()
    return [dict(zip(header, record)) for record in records]
@tool
def update_ticket_to_new_flight(ticket_no: str, new_flight_id: int) -> str:
    """Update the user's ticket to a new valid flight."""
    # Function-scope import keeps this block self-contained.
    from contextlib import closing

    config = ensure_config()
    configuration = config.get("configurable", {})
    passenger_id = configuration.get("passenger_id", None)
    if not passenger_id:
        raise ValueError("No passenger ID configured.")
    # closing() releases the cursor and connection on every return path; the
    # "less than 3 hours" rejection previously returned without closing them.
    with closing(sqlite3.connect(db)) as conn, closing(conn.cursor()) as cursor:
        cursor.execute(
            "SELECT departure_airport, arrival_airport, scheduled_departure FROM flights WHERE flight_id = ?",
            (new_flight_id,),
        )
        new_flight = cursor.fetchone()
        if not new_flight:
            return "Invalid new flight ID provided."
        column_names = [column[0] for column in cursor.description]
        new_flight_dict = dict(zip(column_names, new_flight))
        # "Etc/GMT-3" is UTC+03:00 (the Etc/ sign convention is inverted).
        # Both datetimes are timezone-aware, so the subtraction below is exact.
        timezone = pytz.timezone("Etc/GMT-3")
        current_time = datetime.now(tz=timezone)
        raw_departure = new_flight_dict["scheduled_departure"]
        try:
            departure_time = datetime.strptime(raw_departure, "%Y-%m-%d %H:%M:%S.%f%z")
        except ValueError:
            # datetime.isoformat() omits ".ffffff" when microseconds are zero.
            departure_time = datetime.strptime(raw_departure, "%Y-%m-%d %H:%M:%S%z")
        time_until = (departure_time - current_time).total_seconds()
        if time_until < (3 * 3600):
            return f"Not permitted to reschedule to a flight that is less than 3 hours from the current time. Selected flight is at {departure_time}."
        cursor.execute(
            "SELECT flight_id FROM ticket_flights WHERE ticket_no = ?", (ticket_no,)
        )
        current_flight = cursor.fetchone()
        if not current_flight:
            return "No existing ticket found for the given ticket number."
        # Check the signed-in user actually has this ticket
        cursor.execute(
            "SELECT * FROM tickets WHERE ticket_no = ? AND passenger_id = ?",
            (ticket_no, passenger_id),
        )
        current_ticket = cursor.fetchone()
        if not current_ticket:
            return f"Current signed-in passenger with ID {passenger_id} not the owner of ticket {ticket_no}"
        # In a real application, you'd likely add additional checks here to enforce business logic,
        # like "does the new departure airport match the current ticket", etc.
        # While it's best to try to be *proactive* in 'type-hinting' policies to the LLM
        # it's inevitably going to get things wrong, so you **also** need to ensure your
        # API enforces valid behavior
        cursor.execute(
            "UPDATE ticket_flights SET flight_id = ? WHERE ticket_no = ?",
            (new_flight_id, ticket_no),
        )
        conn.commit()
    return "Ticket successfully updated to new flight."
@tool
def cancel_ticket(ticket_no: str) -> str:
    """Cancel the user's ticket and remove it from the database."""
    # Function-scope import keeps this block self-contained.
    from contextlib import closing

    config = ensure_config()
    configuration = config.get("configurable", {})
    passenger_id = configuration.get("passenger_id", None)
    if not passenger_id:
        raise ValueError("No passenger ID configured.")
    # closing() guarantees the cursor and connection are released on every path.
    with closing(sqlite3.connect(db)) as conn, closing(conn.cursor()) as cursor:
        cursor.execute(
            "SELECT flight_id FROM ticket_flights WHERE ticket_no = ?", (ticket_no,)
        )
        existing_ticket = cursor.fetchone()
        if not existing_ticket:
            return "No existing ticket found for the given ticket number."
        # Check the signed-in user actually has this ticket. flight_id is a
        # ticket_flights column; the ownership check only needs to know a row
        # exists, so select ticket_no, which the WHERE clause proves tickets has.
        cursor.execute(
            "SELECT ticket_no FROM tickets WHERE ticket_no = ? AND passenger_id = ?",
            (ticket_no, passenger_id),
        )
        current_ticket = cursor.fetchone()
        if not current_ticket:
            return f"Current signed-in passenger with ID {passenger_id} not the owner of ticket {ticket_no}"
        cursor.execute("DELETE FROM ticket_flights WHERE ticket_no = ?", (ticket_no,))
        conn.commit()
    return "Ticket successfully cancelled."
from datetime import date, datetime
from typing import Optional, Union
@tool
def search_car_rentals(
    location: Optional[str] = None,
    name: Optional[str] = None,
    price_tier: Optional[str] = None,
    start_date: Optional[Union[datetime, date]] = None,
    end_date: Optional[Union[datetime, date]] = None,
) -> list[dict]:
    """
    Search for car rentals based on location, name, price tier, start date, and end date.
    Args:
    location (Optional[str]): The location of the car rental. Defaults to None.
    name (Optional[str]): The name of the car rental company. Defaults to None.
    price_tier (Optional[str]): The price tier of the car rental. Defaults to None.
    start_date (Optional[Union[datetime, date]]): The start date of the car rental. Defaults to None.
    end_date (Optional[Union[datetime, date]]): The end date of the car rental. Defaults to None.
    Returns:
    list[dict]: A list of car rental dictionaries matching the search criteria.
    """
    # Only location and name are filtered (substring match). price_tier and the
    # dates are accepted but ignored because the toy dataset is tiny.
    sql = "SELECT * FROM car_rentals WHERE 1=1"
    args = []
    for column, term in (("location", location), ("name", name)):
        if term:
            sql += f" AND {column} LIKE ?"
            args.append(f"%{term}%")
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute(sql, args)
    rows = cur.fetchall()
    connection.close()
    header = [col[0] for col in cur.description]
    return [dict(zip(header, row)) for row in rows]
@tool
def book_car_rental(rental_id: int) -> str:
    """
    Book a car rental by its ID.
    Args:
    rental_id (int): The ID of the car rental to book.
    Returns:
    str: A message indicating whether the car rental was successfully booked or not.
    """
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute("UPDATE car_rentals SET booked = 1 WHERE id = ?", (rental_id,))
    connection.commit()
    # rowcount is 0 when no rental has this id.
    booked = cur.rowcount > 0
    connection.close()
    if not booked:
        return f"No car rental found with ID {rental_id}."
    return f"Car rental {rental_id} successfully booked."
@tool
def update_car_rental(
    rental_id: int,
    start_date: Optional[Union[datetime, date]] = None,
    end_date: Optional[Union[datetime, date]] = None,
) -> str:
    """
    Update a car rental's start and end dates by its ID.
    Args:
    rental_id (int): The ID of the car rental to update.
    start_date (Optional[Union[datetime, date]]): The new start date of the car rental. Defaults to None.
    end_date (Optional[Union[datetime, date]]): The new end date of the car rental. Defaults to None.
    Returns:
    str: A message indicating whether the car rental was successfully updated or not.
    """
    # Collect only the columns the caller supplied. Column names come from this
    # fixed list, never from input, so building the SET clause is safe.
    assignments = []
    params = []
    if start_date:
        assignments.append("start_date = ?")
        params.append(start_date)
    if end_date:
        assignments.append("end_date = ?")
        params.append(end_date)

    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        if not assignments:
            # Nothing to change. rowcount would be -1 here, so check existence
            # explicitly instead of misreporting an existing rental as missing.
            cursor.execute("SELECT 1 FROM car_rentals WHERE id = ?", (rental_id,))
            if cursor.fetchone() is None:
                return f"No car rental found with ID {rental_id}."
            return f"No new dates provided for car rental {rental_id}."
        # A single UPDATE makes rowcount reflect whether the rental exists,
        # regardless of which dates were supplied.
        cursor.execute(
            f"UPDATE car_rentals SET {', '.join(assignments)} WHERE id = ?",
            (*params, rental_id),
        )
        conn.commit()
        updated = cursor.rowcount > 0
    finally:
        conn.close()
    if updated:
        return f"Car rental {rental_id} successfully updated."
    return f"No car rental found with ID {rental_id}."
@tool
def cancel_car_rental(rental_id: int) -> str:
    """
    Cancel a car rental by its ID.
    Args:
    rental_id (int): The ID of the car rental to cancel.
    Returns:
    str: A message indicating whether the car rental was successfully cancelled or not.
    """
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    # Cancelling just clears the booked flag; the rental row itself stays.
    cur.execute("UPDATE car_rentals SET booked = 0 WHERE id = ?", (rental_id,))
    connection.commit()
    affected = cur.rowcount
    connection.close()
    return (
        f"Car rental {rental_id} successfully cancelled."
        if affected > 0
        else f"No car rental found with ID {rental_id}."
    )
@tool
def search_hotels(
    location: Optional[str] = None,
    name: Optional[str] = None,
    price_tier: Optional[str] = None,
    checkin_date: Optional[Union[datetime, date]] = None,
    checkout_date: Optional[Union[datetime, date]] = None,
) -> list[dict]:
    """
    Search for hotels based on location, name, price tier, check-in date, and check-out date.
    Args:
    location (Optional[str]): The location of the hotel. Defaults to None.
    name (Optional[str]): The name of the hotel. Defaults to None.
    price_tier (Optional[str]): The price tier of the hotel. Defaults to None. Examples: Midscale, Upper Midscale, Upscale, Luxury
    checkin_date (Optional[Union[datetime, date]]): The check-in date of the hotel. Defaults to None.
    checkout_date (Optional[Union[datetime, date]]): The check-out date of the hotel. Defaults to None.
    Returns:
    list[dict]: A list of hotel dictionaries matching the search criteria.
    """
    # Substring filters on location/name only; price tier and dates are
    # accepted but ignored for this tutorial.
    like_filters = {"location": location, "name": name}
    sql = "SELECT * FROM hotels WHERE 1=1"
    args = []
    for column, term in like_filters.items():
        if not term:
            continue
        sql += f" AND {column} LIKE ?"
        args.append(f"%{term}%")
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute(sql, args)
    rows = cur.fetchall()
    connection.close()
    header = [col[0] for col in cur.description]
    return [dict(zip(header, row)) for row in rows]
@tool
def book_hotel(hotel_id: int) -> str:
    """
    Book a hotel by its ID.
    Args:
    hotel_id (int): The ID of the hotel to book.
    Returns:
    str: A message indicating whether the hotel was successfully booked or not.
    """
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute("UPDATE hotels SET booked = 1 WHERE id = ?", (hotel_id,))
    connection.commit()
    # rowcount is 0 when no hotel has this id.
    matched = cur.rowcount
    connection.close()
    return (
        f"Hotel {hotel_id} successfully booked."
        if matched > 0
        else f"No hotel found with ID {hotel_id}."
    )
@tool
def update_hotel(
    hotel_id: int,
    checkin_date: Optional[Union[datetime, date]] = None,
    checkout_date: Optional[Union[datetime, date]] = None,
) -> str:
    """
    Update a hotel's check-in and check-out dates by its ID.
    Args:
    hotel_id (int): The ID of the hotel to update.
    checkin_date (Optional[Union[datetime, date]]): The new check-in date of the hotel. Defaults to None.
    checkout_date (Optional[Union[datetime, date]]): The new check-out date of the hotel. Defaults to None.
    Returns:
    str: A message indicating whether the hotel was successfully updated or not.
    """
    # Collect only the columns the caller supplied. Column names come from this
    # fixed list, never from input, so building the SET clause is safe.
    assignments = []
    params = []
    if checkin_date:
        assignments.append("checkin_date = ?")
        params.append(checkin_date)
    if checkout_date:
        assignments.append("checkout_date = ?")
        params.append(checkout_date)

    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        if not assignments:
            # Nothing to change. rowcount would be -1 here, so check existence
            # explicitly instead of misreporting an existing hotel as missing.
            cursor.execute("SELECT 1 FROM hotels WHERE id = ?", (hotel_id,))
            if cursor.fetchone() is None:
                return f"No hotel found with ID {hotel_id}."
            return f"No new dates provided for hotel {hotel_id}."
        # A single UPDATE makes rowcount reflect whether the hotel exists,
        # regardless of which dates were supplied.
        cursor.execute(
            f"UPDATE hotels SET {', '.join(assignments)} WHERE id = ?",
            (*params, hotel_id),
        )
        conn.commit()
        updated = cursor.rowcount > 0
    finally:
        conn.close()
    if updated:
        return f"Hotel {hotel_id} successfully updated."
    return f"No hotel found with ID {hotel_id}."
@tool
def cancel_hotel(hotel_id: int) -> str:
    """
    Cancel a hotel by its ID.
    Args:
    hotel_id (int): The ID of the hotel to cancel.
    Returns:
    str: A message indicating whether the hotel was successfully cancelled or not.
    """
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    # Cancelling clears the booked flag; the hotel row itself is kept.
    cur.execute("UPDATE hotels SET booked = 0 WHERE id = ?", (hotel_id,))
    connection.commit()
    cancelled = cur.rowcount > 0
    connection.close()
    if not cancelled:
        return f"No hotel found with ID {hotel_id}."
    return f"Hotel {hotel_id} successfully cancelled."
@tool
def search_trip_recommendations(
    location: Optional[str] = None,
    name: Optional[str] = None,
    keywords: Optional[str] = None,
) -> list[dict]:
    """
    Search for trip recommendations based on location, name, and keywords.
    Args:
    location (Optional[str]): The location of the trip recommendation. Defaults to None.
    name (Optional[str]): The name of the trip recommendation. Defaults to None.
    keywords (Optional[str]): The keywords associated with the trip recommendation. Defaults to None.
    Returns:
    list[dict]: A list of trip recommendation dictionaries matching the search criteria.
    """
    query = "SELECT * FROM trip_recommendations WHERE 1=1"
    params = []
    if location:
        query += " AND location LIKE ?"
        params.append(f"%{location}%")
    if name:
        query += " AND name LIKE ?"
        params.append(f"%{name}%")
    if keywords:
        # Comma-separated keywords, any of which may match. Blank entries
        # (e.g. "museum," or "a,,b") are dropped. Otherwise they would become a
        # "%%" pattern that matches every row and silently disables the filter.
        keyword_list = [kw.strip() for kw in keywords.split(",") if kw.strip()]
        if keyword_list:
            keyword_conditions = " OR ".join(["keywords LIKE ?"] * len(keyword_list))
            query += f" AND ({keyword_conditions})"
            params.extend(f"%{kw}%" for kw in keyword_list)
    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        cursor.execute(query, params)
        # Read the column names once rather than once per row.
        column_names = [column[0] for column in cursor.description]
        return [dict(zip(column_names, row)) for row in cursor.fetchall()]
    finally:
        conn.close()
@tool
def book_excursion(recommendation_id: int) -> str:
    """
    Book an excursion by its recommendation ID.
    Args:
    recommendation_id (int): The ID of the trip recommendation to book.
    Returns:
    str: A message indicating whether the trip recommendation was successfully booked or not.
    """
    # The docstring is the tool description the LLM sees; it previously read
    # "Book a excursion".
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    cursor.execute(
        "UPDATE trip_recommendations SET booked = 1 WHERE id = ?", (recommendation_id,)
    )
    conn.commit()
    # rowcount is 0 when no recommendation has this id.
    if cursor.rowcount > 0:
        conn.close()
        return f"Trip recommendation {recommendation_id} successfully booked."
    else:
        conn.close()
        return f"No trip recommendation found with ID {recommendation_id}."
@tool
def update_excursion(recommendation_id: int, details: str) -> str:
    """
    Update a trip recommendation's details by its ID.
    Args:
    recommendation_id (int): The ID of the trip recommendation to update.
    details (str): The new details of the trip recommendation.
    Returns:
    str: A message indicating whether the trip recommendation was successfully updated or not.
    """
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute(
        "UPDATE trip_recommendations SET details = ? WHERE id = ?",
        (details, recommendation_id),
    )
    connection.commit()
    changed = cur.rowcount
    connection.close()
    if changed <= 0:
        return f"No trip recommendation found with ID {recommendation_id}."
    return f"Trip recommendation {recommendation_id} successfully updated."
@tool
def cancel_excursion(recommendation_id: int) -> str:
    """
    Cancel a trip recommendation by its ID.
    Args:
    recommendation_id (int): The ID of the trip recommendation to cancel.
    Returns:
    str: A message indicating whether the trip recommendation was successfully cancelled or not.
    """
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    # Cancelling clears the booked flag; the recommendation row is kept.
    cur.execute(
        "UPDATE trip_recommendations SET booked = 0 WHERE id = ?", (recommendation_id,)
    )
    connection.commit()
    cancelled = cur.rowcount > 0
    connection.close()
    return (
        f"Trip recommendation {recommendation_id} successfully cancelled."
        if cancelled
        else f"No trip recommendation found with ID {recommendation_id}."
    )
from langchain_core.runnables import RunnableLambda
from langchain_core.messages import ToolMessage
from langgraph.prebuilt import ToolNode
def handle_tool_error(state) -> dict:
    """Answer every pending tool call with the error that aborted the tool node.

    Reads the exception from ``state["error"]`` and the pending calls from the
    last message's ``tool_calls``. Returns one ToolMessage per call so that
    each tool_call_id gets a response the LLM can react to.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    notice = f"Error: {repr(failure)}\n please fix your mistakes."
    replies = []
    for call in pending_calls:
        replies.append(ToolMessage(content=notice, tool_call_id=call["id"]))
    return {"messages": replies}
def create_tool_node_with_fallback(tools: list) -> "Runnable":
    """Wrap ``tools`` in a ToolNode that reports failures back to the LLM.

    If the ToolNode raises, the exception is placed in the state under the
    ``"error"`` key and ``handle_tool_error`` runs instead. That produces an
    error ToolMessage for each pending tool call rather than crashing the graph.

    Returns:
        The runnable produced by ``with_fallbacks`` (not a dict).
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )
def _print_event(event: dict, _printed: set, max_length=1500):
current_state = event.get("dialog_state")
if current_state:
print(f"Currently in: ", current_state[-1])
message = event.get("messages")
if message:
if isinstance(message, list):
message = message[-1]
if message.id not in _printed:
msg_repr = message.pretty_repr(html=True)
if len(msg_repr) > max_length:
msg_repr = msg_repr[:max_length] + " ... (truncated)"
print(msg_repr)
_printed.add(message.id)
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage, add_messages
class State(TypedDict):
    """Graph state shared by every node of the assistant graph.

    ``messages`` is annotated with the ``add_messages`` reducer, so LangGraph
    uses that function to combine a node's returned ``messages`` with the
    existing list instead of plainly overwriting it (see LangGraph's
    add_messages documentation for exact merge semantics).
    """

    # Conversation history; merged via add_messages.
    messages: Annotated[list[AnyMessage], add_messages]
from langchain_experimental.llms.ollama_functions import OllamaFunctions
from langchain_community.tools import BaseTool, DuckDuckGoSearchRun
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.pydantic_v1 import BaseModel
from typing import Any, Dict
class Assistant:
    """Graph node that invokes the assistant runnable with simple retries.

    The model is re-prompted when it returns neither tool calls nor text. After
    more than five consecutive exceptions it gives up and returns a fallback
    AI message.
    """

    def __init__(self, runnable: Runnable):
        self.runnable = runnable

    def __call__(self, state: State, config: RunnableConfig):
        """Run the LLM on ``state`` and return ``{"messages": result}``."""
        # Function-scope import; langchain_core.messages is already used by this file.
        from langchain_core.messages import AIMessage

        # The run config nests passenger_id under "configurable", which is also
        # where the tools read it. The old top-level lookup always returned None.
        configuration = config.get("configurable", {})
        passenger_id = configuration.get("passenger_id", None)
        state = {**state, "user_info": passenger_id}
        cnt = 0  # consecutive exceptions
        while True:
            try:
                result = self.runnable.invoke(state)
                cnt = 0
                # If the LLM happens to return an empty response, we will re-prompt it
                # for an actual response.
                if not result.tool_calls and (
                    not result.content
                    or isinstance(result.content, list)
                    and not result.content[0].get("text")
                ):
                    messages = state["messages"] + [("user", "Respond with a real output.")]
                    state = {**state, "messages": messages}
                else:
                    break
            except Exception as ex:
                print("exception", ex)
                cnt += 1
                if cnt > 5:
                    print("too many exceptions, breaking")
                    # Return an AI message explicitly. LangChain's message
                    # coercion (used by add_messages) turns a bare str into a
                    # human message, which would attribute this text to the user.
                    result = AIMessage(content="Unable to provide a response")
                    break
        return {"messages": result}
# Llama 3 served by Ollama, wrapped in OllamaFunctions so bind_tools() can be
# used below. Set OLLAMA_BASE_URL to target another Ollama server; the default
# is the host this script was originally written against.
llm = OllamaFunctions(
    model="llama3",
    format="json",
    temperature=1,
    base_url=os.environ.get("OLLAMA_BASE_URL", "http://aurora:11434"),
)
# You could swap LLMs, though you will likely want to update the prompts when
# doing so! Each alternative needs its own integration package, e.g.:
# llm = ChatAnthropic(model="claude-3-haiku-20240307")  # Haiku: faster and cheaper, but less accurate
# from langchain_openai import ChatOpenAI
# llm = ChatOpenAI(model="gpt-4-turbo-preview")
# System prompt plus the running conversation. {user_info} and {messages} are
# filled from the graph state when the prompt is invoked. {time} is bound below
# to a callable: LangChain evaluates callable partial variables at format time,
# so the model sees the current time on every turn, not the import time.
primary_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful customer support assistant for Swiss Airlines. "
            " Use the provided tools to search for flights, company policies, and other information to assist the user's queries. "
            " When searching, be persistent. Expand your query bounds if the first search returns no results. "
            " If a search comes up empty, expand your search before giving up."
            "\n\nCurrent user:\n<User>\n{user_info}\n</User>"
            "\nCurrent time: {time}.",
        ),
        ("placeholder", "{messages}"),
    ]
).partial(time=datetime.now)
# Tools exposed in part 1 of the tutorial. The car-rental, hotel and excursion
# tools defined above are intentionally not bound yet.
part_1_tools = [
    # NOTE(review): confirm max_results is honored by DuckDuckGoSearchRun
    # itself; it may need to be set on its api_wrapper instead.
    DuckDuckGoSearchRun(max_results=1),
    fetch_user_flight_information,
    search_flights,
    lookup_policy,
    update_ticket_to_new_flight,
    cancel_ticket,
]
# Prompt piped into the LLM with the tool schemas bound, so each call sees the
# system instructions plus the conversation and can emit tool calls.
tool_enabled_llm = llm.bind_tools(tools=part_1_tools)
part_1_assistant_runnable = primary_assistant_prompt | tool_enabled_llm
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
# Two-node loop: the assistant either finishes or requests tools, and tool
# results are always fed back to the assistant.
assistant_node = Assistant(part_1_assistant_runnable)
tool_node = create_tool_node_with_fallback(part_1_tools)

builder = StateGraph(State)
builder.add_node("assistant", assistant_node)
builder.add_node("tools", tool_node)
builder.set_entry_point("assistant")
# tools_condition picks the next hop after the assistant (the tool node, named
# "tools", when tool calls are requested; otherwise the end of the graph).
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

# ":memory:" SQLite checkpointer: per-thread conversation state lives only for
# the lifetime of this process.
memory = SqliteSaver.from_conn_string(":memory:")
part_1_graph = builder.compile(checkpointer=memory)
import shutil
import uuid
# Scripted conversation a user might have with the assistant.
tutorial_questions = [
    "Hi there, what time is my flight?",
    "Am i allowed to update my flight to something sooner? I want to leave later today.",
    "Update my flight to sometime next week then",
    "The next available option is great",
    "what about lodging and transportation?",
    "Yeah i think i'd like an affordable hotel for my week-long stay (7 days). And I'll want to rent a car.",
    "OK could you place a reservation for your recommended hotel? It sounds nice.",
    "yes go ahead and book anything that's moderate expense and has availability.",
    "Now for a car, what are my options?",
    "Awesome let's just get the cheapest option. Go ahead and book for 7 days",
    "Cool so now what recommendations do you have on excursions?",
    "Are they available while I'm there?",
    "interesting - i like the museums, what options are there? ",
    "OK great pick one and book it for my second day there.",
]

# Reset the working DB from the backup so every run starts from the same data.
shutil.copy(backup_file, db)

# A fresh thread id gives this run its own conversation in the checkpointer.
thread_id = str(uuid.uuid4())
configurable = {
    # Read by the flight tools to look up the signed-in passenger's data.
    "passenger_id": "3442 587242",
    # Checkpoints are keyed by thread_id.
    "thread_id": thread_id,
}
config = {"configurable": configurable}

_printed = set()
for question in tutorial_questions:
    user_turn = {"messages": ("user", question)}
    for event in part_1_graph.stream(user_turn, config, stream_mode="values"):
        _print_event(event, _printed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment