Created June 2, 2024 06:28
Save lalanikarim/99fd6b1db3346a744246d6a2cab72c04 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import shutil | |
import sqlite3 | |
import pandas as pd | |
import requests | |
db_url = "https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
local_file = "travel2.sqlite"
# The backup lets us restart for each tutorial section
backup_file = "travel2.backup.sqlite"
# Set to True to force a fresh download even when local_file already exists.
overwrite = False
if overwrite or not os.path.exists(local_file):
    # requests has no default timeout; without one a stalled connection would
    # hang the script forever.
    response = requests.get(db_url, timeout=60)
    response.raise_for_status()  # Ensure the request was successful
    with open(local_file, "wb") as f:
        f.write(response.content)
    # Backup - we will use this to "reset" our DB in each section. It is taken
    # right after the download, before the dates get shifted to the present.
    shutil.copy(local_file, backup_file)
elif not os.path.exists(backup_file):
    # The DB is already present but its backup is missing; recreate it so a later
    # restore from backup_file does not fail with FileNotFoundError.
    shutil.copy(local_file, backup_file)
# Convert the flights to present time for our tutorial: every flight timestamp and
# bookings.book_date are shifted by (now - latest actual_departure), so the most
# recent departure in the data lands at the current time.
conn = sqlite3.connect(local_file)
try:
    tables = pd.read_sql(
        "SELECT name FROM sqlite_master WHERE type='table';", conn
    ).name.tolist()
    tdf = {}
    for t in tables:
        tdf[t] = pd.read_sql(f"SELECT * from {t}", conn)
    # The literal "\N" placeholder (presumably the dump's NULL marker) becomes NaT
    # before parsing.
    example_time = pd.to_datetime(
        tdf["flights"]["actual_departure"].replace("\\N", pd.NaT)
    ).max()
    current_time = pd.to_datetime("now").tz_localize(example_time.tz)
    time_diff = current_time - example_time
    tdf["bookings"]["book_date"] = (
        pd.to_datetime(tdf["bookings"]["book_date"].replace("\\N", pd.NaT), utc=True)
        + time_diff
    )
    datetime_columns = [
        "scheduled_departure",
        "scheduled_arrival",
        "actual_departure",
        "actual_arrival",
    ]
    for column in datetime_columns:
        tdf["flights"][column] = (
            pd.to_datetime(tdf["flights"][column].replace("\\N", pd.NaT)) + time_diff
        )
    # Write every table back, replacing the originals.
    for table_name, df in tdf.items():
        df.to_sql(table_name, conn, if_exists="replace", index=False)
        del df
    del tdf  # release the in-memory copies of all tables
    conn.commit()
finally:
    # Close even if reading/shifting/writing fails part-way.
    conn.close()
db = local_file  # We'll be using this local file as our DB in this tutorial
import re | |
import numpy as np | |
from langchain_community.embeddings import OllamaEmbeddings | |
from langchain_core.tools import tool | |
# Fetch the airline's policy FAQ (markdown); a timeout prevents an indefinite hang.
response = requests.get(
    "https://storage.googleapis.com/benchmarks-artifacts/travel-db/swiss_faq.md",
    timeout=60,
)
response.raise_for_status()
faq_text = response.text
# Split just before every "\n##" heading so each FAQ section becomes one document.
docs = [{"page_content": txt} for txt in re.split(r"(?=\n##)", faq_text)]
class VectorStoreRetriever:
    """Tiny in-memory vector store: ranks documents by dot product with the query embedding.

    Args (constructor):
        docs: documents as dicts with at least a "page_content" key.
        vectors: one embedding per document, in the same order as ``docs``.
        embed_client: object providing ``embed_query(str)`` (and ``embed_documents``
            when built via ``from_docs``).
    """

    def __init__(self, docs: list, vectors: list, embed_client):
        self._arr = np.array(vectors)
        self._docs = docs
        self._client = embed_client

    @classmethod
    def from_docs(cls, docs, embed_client):
        """Embed every document's ``page_content`` with ``embed_client`` and build a store."""
        embeddings = embed_client.embed_documents(
            [doc["page_content"] for doc in docs]
        )
        return cls(docs, embeddings, embed_client)

    def query(self, query: str, k: int = 5) -> list[dict]:
        """Return up to ``k`` documents most similar to ``query``, best first.

        Each result is a copy of the stored doc dict plus a "similarity" entry (the raw
        dot product; NOTE(review): only a cosine similarity if the embeddings are
        normalized). ``k`` larger than the corpus returns every document; ``k <= 0``
        or an empty corpus returns ``[]``.
        """
        # np.argpartition raises ValueError when -k is out of range, and k == 0 would
        # slice with [-0:] (i.e. everything), so clamp k to the corpus size first.
        k = min(k, len(self._docs))
        if k <= 0:
            return []
        query_vec = np.asarray(self._client.embed_query(query))
        # "@" is just a matrix multiplication in python
        scores = query_vec @ self._arr.T
        # Partial selection of the k best (unordered), then sort only those k.
        top_k_idx = np.argpartition(scores, -k)[-k:]
        top_k_idx_sorted = top_k_idx[np.argsort(-scores[top_k_idx])]
        return [
            {**self._docs[idx], "similarity": scores[idx]} for idx in top_k_idx_sorted
        ]
# Embed the FAQ sections with Ollama's nomic-embed-text model. The server can be
# overridden with OLLAMA_BASE_URL; the default matches OllamaEmbeddings' own default.
embed = OllamaEmbeddings(
    model="nomic-embed-text",
    base_url=os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434"),
)
retriever = VectorStoreRetriever.from_docs(docs, embed)
@tool
def lookup_policy(query: str) -> str:
    """Consult the company policies to check whether certain options are permitted.
    Use this before making any flight changes or performing other 'write' events."""
    # langchain's @tool uses the docstring above as the tool description the LLM sees,
    # so its wording matters.
    # Return the two FAQ sections most similar to the query, separated by a blank line.
    matches = retriever.query(query, k=2)
    return "\n\n".join(match["page_content"] for match in matches)
import sqlite3 | |
from datetime import date, datetime | |
from typing import Optional | |
import pytz | |
from langchain_core.runnables import ensure_config | |
@tool
def fetch_user_flight_information() -> list[dict]:
    """Fetch all tickets for the user along with corresponding flight information and seat assignments.
    Returns:
        A list of dictionaries where each dictionary contains the ticket details,
        associated flight details, and the seat assignments for each ticket belonging to the user.
    """
    # The passenger comes from the run config, not from a tool argument, so the LLM
    # cannot choose whose tickets are read.
    config = ensure_config()  # Fetch from the context
    configuration = config.get("configurable", {})
    passenger_id = configuration.get("passenger_id", None)
    if not passenger_id:
        raise ValueError("No passenger ID configured.")

    query = """
    SELECT
        t.ticket_no, t.book_ref,
        f.flight_id, f.flight_no, f.departure_airport, f.arrival_airport, f.scheduled_departure, f.scheduled_arrival,
        bp.seat_no, tf.fare_conditions
    FROM
        tickets t
        JOIN ticket_flights tf ON t.ticket_no = tf.ticket_no
        JOIN flights f ON tf.flight_id = f.flight_id
        JOIN boarding_passes bp ON bp.ticket_no = t.ticket_no AND bp.flight_id = f.flight_id
    WHERE
        t.passenger_id = ?
    """
    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        cursor.execute(query, (passenger_id,))
        rows = cursor.fetchall()
        column_names = [column[0] for column in cursor.description]
        cursor.close()
    finally:
        # Always release the connection, even if the query fails.
        conn.close()
    # One dict per (ticket, flight) row, keyed by the selected column names.
    return [dict(zip(column_names, row)) for row in rows]
@tool
def search_flights(
    departure_airport: Optional[str] = None,
    arrival_airport: Optional[str] = None,
    start_time: Optional[date | datetime] = None,
    end_time: Optional[date | datetime] = None,
    limit: int = 20,
) -> list[dict]:
    """Search for flights based on departure airport, arrival airport, and departure time range."""
    # Each provided (truthy) filter appends one "AND ..." clause with a bound
    # parameter; values are never interpolated into the SQL text.
    # NOTE(review): scheduled_departure is compared as stored (presumably ISO text),
    # against sqlite3's default date/datetime adaptation of start_time/end_time.
    query = "SELECT * FROM flights WHERE 1 = 1"
    params = []
    if departure_airport:
        query += " AND departure_airport = ?"
        params.append(departure_airport)
    if arrival_airport:
        query += " AND arrival_airport = ?"
        params.append(arrival_airport)
    if start_time:
        query += " AND scheduled_departure >= ?"
        params.append(start_time)
    if end_time:
        query += " AND scheduled_departure <= ?"
        params.append(end_time)
    query += " LIMIT ?"
    params.append(limit)

    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        cursor.execute(query, params)
        rows = cursor.fetchall()
        column_names = [column[0] for column in cursor.description]
        cursor.close()
    finally:
        # Always release the connection, even if binding/execution fails.
        conn.close()
    return [dict(zip(column_names, row)) for row in rows]
@tool
def update_ticket_to_new_flight(ticket_no: str, new_flight_id: int) -> str:
    """Update the user's ticket to a new valid flight."""
    # Validation order: target flight exists -> departs >= 3h from now -> ticket
    # exists -> ticket belongs to the configured passenger. Only then is it written.
    config = ensure_config()
    configuration = config.get("configurable", {})
    passenger_id = configuration.get("passenger_id", None)
    if not passenger_id:
        raise ValueError("No passenger ID configured.")

    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    try:
        cursor.execute(
            "SELECT departure_airport, arrival_airport, scheduled_departure FROM flights WHERE flight_id = ?",
            (new_flight_id,),
        )
        new_flight = cursor.fetchone()
        if not new_flight:
            return "Invalid new flight ID provided."
        column_names = [column[0] for column in cursor.description]
        new_flight_dict = dict(zip(column_names, new_flight))

        # Both sides are timezone-aware (provided the stored value carries an offset),
        # so the chosen zone does not change the computed difference.
        timezone = pytz.timezone("Etc/GMT-3")
        current_time = datetime.now(tz=timezone)
        raw_departure = new_flight_dict["scheduled_departure"]
        try:
            departure_time = datetime.strptime(raw_departure, "%Y-%m-%d %H:%M:%S.%f%z")
        except ValueError:
            # Timestamps written back via DataFrame.to_sql may omit the fractional
            # seconds when they are zero; fromisoformat accepts both forms.
            departure_time = datetime.fromisoformat(raw_departure)
        time_until = (departure_time - current_time).total_seconds()
        if time_until < (3 * 3600):
            # Previously this return leaked the open cursor/connection.
            return f"Not permitted to reschedule to a flight that is less than 3 hours from the current time. Selected flight is at {departure_time}."

        cursor.execute(
            "SELECT flight_id FROM ticket_flights WHERE ticket_no = ?", (ticket_no,)
        )
        if not cursor.fetchone():
            return "No existing ticket found for the given ticket number."

        # Check the signed-in user actually has this ticket
        cursor.execute(
            "SELECT * FROM tickets WHERE ticket_no = ? AND passenger_id = ?",
            (ticket_no, passenger_id),
        )
        if not cursor.fetchone():
            return f"Current signed-in passenger with ID {passenger_id} not the owner of ticket {ticket_no}"

        # In a real application, you'd likely add additional checks here to enforce business logic,
        # like "does the new departure airport match the current ticket", etc.
        # While it's best to try to be *proactive* in 'type-hinting' policies to the LLM
        # it's inevitably going to get things wrong, so you **also** need to ensure your
        # API enforces valid behavior
        cursor.execute(
            "UPDATE ticket_flights SET flight_id = ? WHERE ticket_no = ?",
            (new_flight_id, ticket_no),
        )
        conn.commit()
        return "Ticket successfully updated to new flight."
    finally:
        # Single cleanup point for every return path above.
        cursor.close()
        conn.close()
@tool
def cancel_ticket(ticket_no: str) -> str:
    """Cancel the user's ticket and remove it from the database."""
    # The acting passenger comes from config["configurable"]["passenger_id"].
    config = ensure_config()
    configuration = config.get("configurable", {})
    passenger_id = configuration.get("passenger_id", None)
    if not passenger_id:
        raise ValueError("No passenger ID configured.")

    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    try:
        cursor.execute(
            "SELECT flight_id FROM ticket_flights WHERE ticket_no = ?", (ticket_no,)
        )
        if not cursor.fetchone():
            return "No existing ticket found for the given ticket number."

        # Check the signed-in user actually has this ticket. Select ticket_no (a
        # column we already filter on) instead of flight_id: tickets does not appear
        # to have a flight_id column -- flights are linked via ticket_flights, cf. the
        # JOIN in fetch_user_flight_information -- so the old query would raise
        # sqlite3.OperationalError.
        cursor.execute(
            "SELECT ticket_no FROM tickets WHERE ticket_no = ? AND passenger_id = ?",
            (ticket_no, passenger_id),
        )
        if not cursor.fetchone():
            return f"Current signed-in passenger with ID {passenger_id} not the owner of ticket {ticket_no}"

        # Cancelling removes every flight segment attached to the ticket.
        cursor.execute("DELETE FROM ticket_flights WHERE ticket_no = ?", (ticket_no,))
        conn.commit()
        return "Ticket successfully cancelled."
    finally:
        cursor.close()
        conn.close()
from datetime import date, datetime | |
from typing import Optional, Union | |
@tool
def search_car_rentals(
    location: Optional[str] = None,
    name: Optional[str] = None,
    price_tier: Optional[str] = None,
    start_date: Optional[Union[datetime, date]] = None,
    end_date: Optional[Union[datetime, date]] = None,
) -> list[dict]:
    """
    Search for car rentals based on location, name, price tier, start date, and end date.
    Args:
        location (Optional[str]): The location of the car rental. Defaults to None.
        name (Optional[str]): The name of the car rental company. Defaults to None.
        price_tier (Optional[str]): The price tier of the car rental. Defaults to None.
        start_date (Optional[Union[datetime, date]]): The start date of the car rental. Defaults to None.
        end_date (Optional[Union[datetime, date]]): The end date of the car rental. Defaults to None.
    Returns:
        list[dict]: A list of car rental dictionaries matching the search criteria.
    """
    # Only location and name narrow the result (substring LIKE match, in that order).
    # price_tier/start_date/end_date are accepted but deliberately ignored: the toy
    # dataset is too small for them to be useful.
    text_filters = {"location": location, "name": name}
    clauses = []
    values = []
    for column, needle in text_filters.items():
        if needle:
            clauses.append(f" AND {column} LIKE ?")
            values.append(f"%{needle}%")
    sql = "SELECT * FROM car_rentals WHERE 1=1" + "".join(clauses)

    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute(sql, values)
    records = cur.fetchall()
    headers = [col[0] for col in cur.description]
    connection.close()
    return [dict(zip(headers, record)) for record in records]
@tool
def book_car_rental(rental_id: int) -> str:
    """
    Book a car rental by its ID.
    Args:
        rental_id (int): The ID of the car rental to book.
    Returns:
        str: A message indicating whether the car rental was successfully booked or not.
    """
    # Booking = setting the row's `booked` flag to 1; rowcount reports whether any
    # row matched the id.
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute("UPDATE car_rentals SET booked = 1 WHERE id = ?", (rental_id,))
    connection.commit()
    was_found = cur.rowcount > 0
    connection.close()
    if not was_found:
        return f"No car rental found with ID {rental_id}."
    return f"Car rental {rental_id} successfully booked."
@tool
def update_car_rental(
    rental_id: int,
    start_date: Optional[Union[datetime, date]] = None,
    end_date: Optional[Union[datetime, date]] = None,
) -> str:
    """
    Update a car rental's start and end dates by its ID.
    Args:
        rental_id (int): The ID of the car rental to update.
        start_date (Optional[Union[datetime, date]]): The new start date of the car rental. Defaults to None.
        end_date (Optional[Union[datetime, date]]): The new end date of the car rental. Defaults to None.
    Returns:
        str: A message indicating whether the car rental was successfully updated or not.
    """
    # With no dates nothing would be executed, leaving cursor.rowcount at -1, and the
    # tool used to report (wrongly) that the rental does not exist.
    if not start_date and not end_date:
        return f"No new dates provided; car rental {rental_id} was not changed."

    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        updated = False
        if start_date:
            cursor.execute(
                "UPDATE car_rentals SET start_date = ? WHERE id = ?",
                (start_date, rental_id),
            )
            updated = cursor.rowcount > 0
        if end_date:
            cursor.execute(
                "UPDATE car_rentals SET end_date = ? WHERE id = ?", (end_date, rental_id)
            )
            updated = updated or cursor.rowcount > 0
        conn.commit()
    finally:
        conn.close()

    if updated:
        return f"Car rental {rental_id} successfully updated."
    return f"No car rental found with ID {rental_id}."
@tool
def cancel_car_rental(rental_id: int) -> str:
    """
    Cancel a car rental by its ID.
    Args:
        rental_id (int): The ID of the car rental to cancel.
    Returns:
        str: A message indicating whether the car rental was successfully cancelled or not.
    """
    # Cancelling = clearing the `booked` flag; the row itself is kept.
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute("UPDATE car_rentals SET booked = 0 WHERE id = ?", (rental_id,))
    connection.commit()
    matched = cur.rowcount > 0
    connection.close()
    return (
        f"Car rental {rental_id} successfully cancelled."
        if matched
        else f"No car rental found with ID {rental_id}."
    )
@tool
def search_hotels(
    location: Optional[str] = None,
    name: Optional[str] = None,
    price_tier: Optional[str] = None,
    checkin_date: Optional[Union[datetime, date]] = None,
    checkout_date: Optional[Union[datetime, date]] = None,
) -> list[dict]:
    """
    Search for hotels based on location, name, price tier, check-in date, and check-out date.
    Args:
        location (Optional[str]): The location of the hotel. Defaults to None.
        name (Optional[str]): The name of the hotel. Defaults to None.
        price_tier (Optional[str]): The price tier of the hotel. Defaults to None. Examples: Midscale, Upper Midscale, Upscale, Luxury
        checkin_date (Optional[Union[datetime, date]]): The check-in date of the hotel. Defaults to None.
        checkout_date (Optional[Union[datetime, date]]): The check-out date of the hotel. Defaults to None.
    Returns:
        list[dict]: A list of hotel dictionaries matching the search criteria.
    """
    # Only location and name narrow the result (substring LIKE match, in that order);
    # price tier and dates are intentionally ignored for this tutorial.
    text_filters = {"location": location, "name": name}
    clauses = []
    values = []
    for column, needle in text_filters.items():
        if needle:
            clauses.append(f" AND {column} LIKE ?")
            values.append(f"%{needle}%")
    sql = "SELECT * FROM hotels WHERE 1=1" + "".join(clauses)

    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute(sql, values)
    records = cur.fetchall()
    headers = [col[0] for col in cur.description]
    connection.close()
    return [dict(zip(headers, record)) for record in records]
@tool
def book_hotel(hotel_id: int) -> str:
    """
    Book a hotel by its ID.
    Args:
        hotel_id (int): The ID of the hotel to book.
    Returns:
        str: A message indicating whether the hotel was successfully booked or not.
    """
    # Booking = setting the hotel row's `booked` flag to 1.
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute("UPDATE hotels SET booked = 1 WHERE id = ?", (hotel_id,))
    connection.commit()
    was_found = cur.rowcount > 0
    connection.close()
    if not was_found:
        return f"No hotel found with ID {hotel_id}."
    return f"Hotel {hotel_id} successfully booked."
@tool
def update_hotel(
    hotel_id: int,
    checkin_date: Optional[Union[datetime, date]] = None,
    checkout_date: Optional[Union[datetime, date]] = None,
) -> str:
    """
    Update a hotel's check-in and check-out dates by its ID.
    Args:
        hotel_id (int): The ID of the hotel to update.
        checkin_date (Optional[Union[datetime, date]]): The new check-in date of the hotel. Defaults to None.
        checkout_date (Optional[Union[datetime, date]]): The new check-out date of the hotel. Defaults to None.
    Returns:
        str: A message indicating whether the hotel was successfully updated or not.
    """
    # With no dates nothing would be executed, leaving cursor.rowcount at -1, and the
    # tool used to report (wrongly) that the hotel does not exist.
    if not checkin_date and not checkout_date:
        return f"No new dates provided; hotel {hotel_id} was not changed."

    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        updated = False
        if checkin_date:
            cursor.execute(
                "UPDATE hotels SET checkin_date = ? WHERE id = ?", (checkin_date, hotel_id)
            )
            updated = cursor.rowcount > 0
        if checkout_date:
            cursor.execute(
                "UPDATE hotels SET checkout_date = ? WHERE id = ?",
                (checkout_date, hotel_id),
            )
            updated = updated or cursor.rowcount > 0
        conn.commit()
    finally:
        conn.close()

    if updated:
        return f"Hotel {hotel_id} successfully updated."
    return f"No hotel found with ID {hotel_id}."
@tool
def cancel_hotel(hotel_id: int) -> str:
    """
    Cancel a hotel by its ID.
    Args:
        hotel_id (int): The ID of the hotel to cancel.
    Returns:
        str: A message indicating whether the hotel was successfully cancelled or not.
    """
    # Cancelling = clearing the `booked` flag; the row itself is kept.
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute("UPDATE hotels SET booked = 0 WHERE id = ?", (hotel_id,))
    connection.commit()
    matched = cur.rowcount > 0
    connection.close()
    return (
        f"Hotel {hotel_id} successfully cancelled."
        if matched
        else f"No hotel found with ID {hotel_id}."
    )
@tool
def search_trip_recommendations(
    location: Optional[str] = None,
    name: Optional[str] = None,
    keywords: Optional[str] = None,
) -> list[dict]:
    """
    Search for trip recommendations based on location, name, and keywords.
    Args:
        location (Optional[str]): The location of the trip recommendation. Defaults to None.
        name (Optional[str]): The name of the trip recommendation. Defaults to None.
        keywords (Optional[str]): The keywords associated with the trip recommendation. Defaults to None.
    Returns:
        list[dict]: A list of trip recommendation dictionaries matching the search criteria.
    """
    query = "SELECT * FROM trip_recommendations WHERE 1=1"
    params = []
    if location:
        query += " AND location LIKE ?"
        params.append(f"%{location}%")
    if name:
        query += " AND name LIKE ?"
        params.append(f"%{name}%")
    if keywords:
        # keywords is a comma-separated list; a row matches if ANY keyword occurs.
        # Blank entries (e.g. from "museum," or "a, ,b") are dropped: an empty keyword
        # becomes the pattern "%%", which matches every non-NULL value and would
        # void the whole OR filter.
        keyword_list = [kw.strip() for kw in keywords.split(",") if kw.strip()]
        if keyword_list:
            keyword_conditions = " OR ".join("keywords LIKE ?" for _ in keyword_list)
            query += f" AND ({keyword_conditions})"
            params.extend(f"%{keyword}%" for keyword in keyword_list)

    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        cursor.execute(query, params)
        results = cursor.fetchall()
        column_names = [column[0] for column in cursor.description]
    finally:
        conn.close()
    return [dict(zip(column_names, row)) for row in results]
@tool
def book_excursion(recommendation_id: int) -> str:
    """
    Book an excursion by its recommendation ID.
    Args:
        recommendation_id (int): The ID of the trip recommendation to book.
    Returns:
        str: A message indicating whether the trip recommendation was successfully booked or not.
    """
    # The docstring doubles as the LLM-facing tool description (grammar fixed:
    # "an excursion"). Booking = setting the recommendation's `booked` flag to 1.
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    cursor.execute(
        "UPDATE trip_recommendations SET booked = 1 WHERE id = ?", (recommendation_id,)
    )
    conn.commit()
    booked = cursor.rowcount > 0
    conn.close()
    if booked:
        return f"Trip recommendation {recommendation_id} successfully booked."
    return f"No trip recommendation found with ID {recommendation_id}."
@tool
def update_excursion(recommendation_id: int, details: str) -> str:
    """
    Update a trip recommendation's details by its ID.
    Args:
        recommendation_id (int): The ID of the trip recommendation to update.
        details (str): The new details of the trip recommendation.
    Returns:
        str: A message indicating whether the trip recommendation was successfully updated or not.
    """
    # Overwrites the `details` text of a single recommendation row.
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute(
        "UPDATE trip_recommendations SET details = ? WHERE id = ?",
        (details, recommendation_id),
    )
    connection.commit()
    matched = cur.rowcount > 0
    connection.close()
    if not matched:
        return f"No trip recommendation found with ID {recommendation_id}."
    return f"Trip recommendation {recommendation_id} successfully updated."
@tool
def cancel_excursion(recommendation_id: int) -> str:
    """
    Cancel a trip recommendation by its ID.
    Args:
        recommendation_id (int): The ID of the trip recommendation to cancel.
    Returns:
        str: A message indicating whether the trip recommendation was successfully cancelled or not.
    """
    # Cancelling = clearing the `booked` flag; the recommendation row is kept.
    connection = sqlite3.connect(db)
    cur = connection.cursor()
    cur.execute(
        "UPDATE trip_recommendations SET booked = 0 WHERE id = ?", (recommendation_id,)
    )
    connection.commit()
    matched = cur.rowcount > 0
    connection.close()
    return (
        f"Trip recommendation {recommendation_id} successfully cancelled."
        if matched
        else f"No trip recommendation found with ID {recommendation_id}."
    )
from langchain_core.runnables import RunnableLambda | |
from langchain_core.messages import ToolMessage | |
from langgraph.prebuilt import ToolNode | |
def handle_tool_error(state) -> dict:
    """Turn a failed tool run into one error ToolMessage per pending tool call.

    ``state["error"]`` holds the raised exception and ``state["messages"][-1]`` is the
    message whose ``tool_calls`` were being executed. Each reply echoes the error so
    the LLM can correct its call.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = []
    for call in pending_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(failure)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}
def create_tool_node_with_fallback(tools: list) -> "Runnable":
    """Wrap ``tools`` in a ToolNode that falls back to ``handle_tool_error`` on failure.

    With ``exception_key="error"``, an exception raised while running the tools is
    passed to the fallback inside its input under the "error" key, so the LLM gets an
    error ToolMessage instead of the graph crashing.

    Returns:
        A runnable (the ToolNode with fallbacks). The previous ``-> dict`` annotation
        was incorrect; the annotation is a string so it is not evaluated at import time.
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )
def _print_event(event: dict, _printed: set, max_length=1500): | |
current_state = event.get("dialog_state") | |
if current_state: | |
print(f"Currently in: ", current_state[-1]) | |
message = event.get("messages") | |
if message: | |
if isinstance(message, list): | |
message = message[-1] | |
if message.id not in _printed: | |
msg_repr = message.pretty_repr(html=True) | |
if len(msg_repr) > max_length: | |
msg_repr = msg_repr[:max_length] + " ... (truncated)" | |
print(msg_repr) | |
_printed.add(message.id) | |
from typing import Annotated | |
from typing_extensions import TypedDict | |
from langgraph.graph.message import AnyMessage, add_messages | |
class State(TypedDict):
    """Graph state for the part-1 assistant graph: the running message history."""

    # add_messages is the channel reducer: message updates returned by nodes are
    # merged into the existing list instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
from langchain_experimental.llms.ollama_functions import OllamaFunctions | |
from langchain_community.tools import BaseTool, DuckDuckGoSearchRun | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.runnables import Runnable, RunnableConfig | |
from langchain_core.pydantic_v1 import BaseModel | |
from typing import Any, Dict | |
class Assistant:
    """Graph node that calls the LLM runnable, retrying on errors and empty replies."""

    def __init__(self, runnable: Runnable):
        self.runnable = runnable

    def __call__(self, state: State, config: RunnableConfig):
        """Invoke the runnable until it yields tool calls or non-empty content.

        The passenger id is injected into the prompt input as ``user_info``. After
        more than 5 consecutive exceptions the node gives up with a fixed AI reply.
        """
        # The run config nests custom values under "configurable" (the tools read it
        # the same way); reading config["passenger_id"] always produced None.
        configuration = config.get("configurable", {})
        passenger_id = configuration.get("passenger_id", None)
        state = {**state, "user_info": passenger_id}
        cnt = 0  # consecutive exceptions
        while True:
            try:
                result = self.runnable.invoke(state)
                cnt = 0
                # If the LLM happens to return an empty response, we will re-prompt it
                # for an actual response.
                if not result.tool_calls and (
                    not result.content
                    or isinstance(result.content, list)
                    and not result.content[0].get("text")
                ):
                    messages = state["messages"] + [("user", "Respond with a real output.")]
                    state = {**state, "messages": messages}
                else:
                    break
            except Exception as ex:
                print("exception", ex)
                cnt += 1
                if cnt > 5:
                    print("too many exceptions, breaking")
                    # A bare string would be coerced into a *human* message by
                    # add_messages; an ("ai", ...) tuple becomes an AI message.
                    result = ("ai", "Unable to provide a response")
                    break
        return {"messages": result}
# Llama 3 served by Ollama, wrapped in OllamaFunctions for tool calling.
# The server can be overridden with OLLAMA_BASE_URL; the default is the host this
# script was written against.
llm = OllamaFunctions(
    model="llama3",
    format="json",
    temperature=1,
    base_url=os.environ.get("OLLAMA_BASE_URL", "http://aurora:11434"),
)
# You could swap LLMs (e.g. ChatAnthropic or ChatOpenAI), though you will likely
# want to update the prompts when doing so!
# System prompt + conversation history. {user_info} is supplied per call by the
# Assistant node; {time} is filled by a callable partial so it is evaluated each time
# the prompt is formatted. Passing datetime.now() froze the time at import.
primary_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful customer support assistant for Swiss Airlines. "
            " Use the provided tools to search for flights, company policies, and other information to assist the user's queries. "
            " When searching, be persistent. Expand your query bounds if the first search returns no results. "
            " If a search comes up empty, expand your search before giving up."
            "\n\nCurrent user:\n<User>\n{user_info}\n</User>"
            "\nCurrent time: {time}.",
        ),
        ("placeholder", "{messages}"),
    ]
).partial(time=datetime.now)
# Tools available to the part-1 assistant. The car-rental, hotel and excursion tools
# defined above are intentionally not enabled in this part.
# NOTE(review): confirm DuckDuckGoSearchRun honours `max_results` in the installed
# langchain_community version (DuckDuckGoSearchResults documents that field).
part_1_tools = [
    DuckDuckGoSearchRun(max_results=1),
    # Flight lookups/changes against the SQLite DB, plus the FAQ policy lookup.
    fetch_user_flight_information,
    search_flights,
    lookup_policy,
    update_ticket_to_new_flight,
    cancel_ticket,
]
# Prompt output feeds an LLM that has every tool's schema bound to it.
part_1_assistant_runnable = primary_assistant_prompt | llm.bind_tools(
    tools=part_1_tools
)
from langgraph.checkpoint.sqlite import SqliteSaver | |
from langgraph.graph import END, StateGraph | |
from langgraph.prebuilt import ToolNode, tools_condition | |
# Part-1 graph: "assistant" talks to the LLM, "tools" executes the tool calls it
# emits (with error fallback), and control returns to "assistant" after every run.
builder = StateGraph(State)
builder.add_node("assistant", Assistant(part_1_assistant_runnable))
builder.add_node("tools", create_tool_node_with_fallback(part_1_tools))

# Every run starts at the assistant; langgraph's prebuilt tools_condition routes to
# the "tools" node when the reply requests tool calls and ends the run otherwise.
builder.set_entry_point("assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

# Persist graph state per thread_id in an in-memory SQLite checkpointer.
# NOTE(review): newer langgraph-checkpoint-sqlite releases make from_conn_string a
# context manager — confirm against the installed version.
memory = SqliteSaver.from_conn_string(":memory:")
part_1_graph = builder.compile(checkpointer=memory)
import shutil | |
import uuid | |
# Let's create an example conversation a user might have with the assistant
tutorial_questions = [
    "Hi there, what time is my flight?",
    "Am i allowed to update my flight to something sooner? I want to leave later today.",
    "Update my flight to sometime next week then",
    "The next available option is great",
    "what about lodging and transportation?",
    "Yeah i think i'd like an affordable hotel for my week-long stay (7 days). And I'll want to rent a car.",
    "OK could you place a reservation for your recommended hotel? It sounds nice.",
    "yes go ahead and book anything that's moderate expense and has availability.",
    "Now for a car, what are my options?",
    "Awesome let's just get the cheapest option. Go ahead and book for 7 days",
    "Cool so now what recommendations do you have on excursions?",
    "Are they available while I'm there?",
    "interesting - i like the museums, what options are there? ",
    "OK great pick one and book it for my second day there.",
]


def _reset_db_to_present(backup_path: str, target_path: str) -> str:
    """Restore ``target_path`` from ``backup_path`` and shift its dates to the present.

    The backup is taken before the setup step shifts the dates, so copying it back
    alone would leave every flight in the past (and the 3-hour rebooking rule would
    reject everything). This re-applies the same shift: flight timestamps and
    bookings.book_date move by (now - latest actual_departure).

    Returns:
        ``target_path``.
    """
    shutil.copy(backup_path, target_path)
    conn = sqlite3.connect(target_path)
    try:
        tables = pd.read_sql(
            "SELECT name FROM sqlite_master WHERE type='table';", conn
        ).name.tolist()
        frames = {t: pd.read_sql(f"SELECT * from {t}", conn) for t in tables}
        # The literal "\N" placeholder (presumably a NULL marker) becomes NaT.
        example_time = pd.to_datetime(
            frames["flights"]["actual_departure"].replace("\\N", pd.NaT)
        ).max()
        current_time = pd.to_datetime("now").tz_localize(example_time.tz)
        time_diff = current_time - example_time
        frames["bookings"]["book_date"] = (
            pd.to_datetime(
                frames["bookings"]["book_date"].replace("\\N", pd.NaT), utc=True
            )
            + time_diff
        )
        for column in (
            "scheduled_departure",
            "scheduled_arrival",
            "actual_departure",
            "actual_arrival",
        ):
            frames["flights"][column] = (
                pd.to_datetime(frames["flights"][column].replace("\\N", pd.NaT))
                + time_diff
            )
        for table_name, frame in frames.items():
            frame.to_sql(table_name, conn, if_exists="replace", index=False)
        conn.commit()
    finally:
        conn.close()
    return target_path


# Restore from the backup so we can restart from the original place in each section,
# then re-shift the dates (the raw backup alone would undo the time shift).
db = _reset_db_to_present(backup_file, db)
thread_id = str(uuid.uuid4())
config = {
    "configurable": {
        # The passenger_id is used in our flight tools to
        # fetch the user's flight information
        "passenger_id": "3442 587242",
        # Checkpoints are accessed by thread_id
        "thread_id": thread_id,
    }
}
_printed = set()
for question in tutorial_questions:
    events = part_1_graph.stream(
        {"messages": ("user", question)}, config, stream_mode="values"
    )
    for event in events:
        _print_event(event, _printed)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment