Last active February 9, 2024 07:42
"<a href=\"\" target=\"_parent\"><img src=\"\" alt=\"Open In Colab\"/></a>"
"! pip install langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python pypdf"
import getpass
import os


def _set_env(var_name: str) -> None:
    """Prompt for the secret named *var_name* unless it is already set.

    Args:
        var_name: Name of the environment variable to populate.
    """
    # Skip the prompt when the variable already holds a value, so re-running
    # this cell does not force the user to paste the secret again.
    if not os.environ.get(var_name):
        # Label the prompt: two identical, unlabeled password prompts make it
        # easy to paste the keys in the wrong order.
        os.environ[var_name] = getpass.getpass(f"{var_name}: ")


# Set your OpenAI API key
_set_env("OPENAI_API_KEY")
# Set your Tavily API key
_set_env("TAVILY_API_KEY")
"# Download and prepare SEC filing"
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Location of $ABNB's financial report (left blank here; fill it in before
# running). Loading may take 1-2 minutes since the PDF is large.
sec_filing_pdf = ""

# Build the PDF loader and read the whole document.
loader = PyPDFLoader(sec_filing_pdf)
documents = loader.load()

# Split the financial report into chunks of size 1024 with no overlap.
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=0, chunk_size=1024)
docs = text_splitter.split_documents(documents)
"# Load the SEC filing into vector store"
from langchain_community.vectorstores import Chroma

# Import the embeddings from langchain_openai, the same package the rest of
# this notebook uses for its OpenAI models, rather than the legacy
# langchain.embeddings.openai path.
from langchain_openai import OpenAIEmbeddings

# Embed the chunks and load them into Chroma
embedding_function = OpenAIEmbeddings()
vectorstore = Chroma.from_documents(docs, embedding_function)

# Expose the vector store through the retriever interface used by the graph
retriever = vectorstore.as_retriever()
"# Define graph State"
"from typing import Dict, TypedDict\n",
"from langchain_core.messages import BaseMessage\n",
from typing import Any


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        keys: A dictionary where each key is a string.
    """

    # The values are arbitrary objects. `Any` is the typing construct; the
    # lowercase builtin `any` is a function and is not a valid type argument.
    keys: Dict[str, Any]
"# Define the graph's Nodes and Edges"
import json
import operator
from typing import Annotated, Sequence, TypedDict

from langchain import hub
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, FunctionMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
"### Nodes ###\n",
def retrieve(state):
    """
    Look up documents for the question stored in the graph state.

    Args:
        state (dict): The current graph state; the question is read from
            state["keys"]["question"].

    Returns:
        dict: A state whose "keys" hold the retrieved "documents" together
        with the unchanged "question".
    """
    print("---RETRIEVE---")
    query = state["keys"]["question"]
    hits = retriever.get_relevant_documents(query)
    return {"keys": {"documents": hits, "question": query}}
def generate(state):
    """
    Generate an answer to the question from the documents in the state.

    Args:
        state (dict): The current graph state; "question" and "documents"
            are read from state["keys"].

    Returns:
        dict: A state whose "keys" hold the "documents", the "question" and
        the new "generation" produced by the LLM chain.
    """
    print("---GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True)

    # Post-processing
    def format_docs(docs):
        """Join the page contents of *docs*, separated by blank lines."""
        return "\n\n".join(doc.page_content for doc in docs)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run. The context is the formatted text of the documents; passing the
    # document objects themselves would bypass the format_docs helper and
    # hand the prompt objects instead of their page text.
    generation = rag_chain.invoke(
        {"context": format_docs(documents), "question": question}
    )
    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }
def grade_documents(state):
    """
    Determine whether the retrieved documents are relevant to the question.

    Each document is graded individually by an LLM that is forced to call the
    `grade` tool. Documents graded "yes" are kept; any other grade drops the
    document and requests a web search.

    Args:
        state (dict): The current graph state; "question" and "documents"
            are read from state["keys"].

    Returns:
        dict: A state whose "keys" hold the filtered "documents", the
        "question", and "run_web_search" ("Yes" if at least one document was
        graded not relevant, otherwise "No").
    """
    print("---CHECK RELEVANCE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Data model
    class grade(BaseModel):
        """Binary score for relevance check."""

        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    # LLM
    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)

    # Tool: convert the schema once and reuse the result
    grade_tool_oai = convert_to_openai_tool(grade)

    # LLM with tool and enforce invocation
    llm_with_tool = model.bind(
        tools=[grade_tool_oai],
        tool_choice={"type": "function", "function": {"name": "grade"}},
    )

    # Parser
    parser_tool = PydanticToolsParser(tools=[grade])

    # Prompt
    prompt = PromptTemplate(
        template=(
            "You are a grader assessing relevance of a retrieved document to a user question. \n\n"
            " Here is the retrieved document: \n\n {context} \n\n\n"
            " Here is the user question: {question} \n\n"
            " If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n\n"
            " Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."
        ),
        input_variables=["context", "question"],
    )

    # Chain
    chain = prompt | llm_with_tool | parser_tool

    # Score
    filtered_docs = []
    search = "No"  # Default do not opt for web search to supplement retrieval
    for d in documents:
        score = chain.invoke({"question": question, "context": d.page_content})
        # Use a distinct name so the `grade` model class is not shadowed.
        verdict = score[0].binary_score
        if verdict == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            search = "Yes"  # Perform web search

    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search,
        }
    }
def transform_query(state):
    """
    Rewrite the question so that it is better suited for retrieval.

    Args:
        state (dict): The current graph state; "question" and "documents"
            are read from state["keys"].

    Returns:
        dict: A state whose "keys" hold the unchanged "documents" and the
        re-phrased "question".
    """
    print("---TRANSFORM QUERY---")
    keys = state["keys"]
    original_question = keys["question"]
    docs_so_far = keys["documents"]

    # Prompt asking the model to restate the question
    rewrite_prompt = PromptTemplate(
        template=(
            "You are generating questions that is well optimized for retrieval. \n\n"
            " Look at the input and try to reason about the underlying sematic intent / meaning. \n\n"
            " Here is the initial question:\n"
            " \n ------- \n\n"
            " {question}\n"
            " \n ------- \n\n"
            " Formulate an improved question: "
        ),
        input_variables=["question"],
    )

    # Model that performs the rewrite
    rewriter = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)

    # Prompt -> model -> plain string
    rewrite_chain = rewrite_prompt | rewriter | StrOutputParser()
    better_question = rewrite_chain.invoke({"question": original_question})

    return {"keys": {"documents": docs_so_far, "question": better_question}}
def web_search(state):
    """
    Run a web search for the question and add the results to the documents.

    The "content" fields of the search results are joined with newlines into
    a single Document, which is placed after the documents already in the
    state.

    Args:
        state (dict): The current graph state; "question" and "documents"
            are read from state["keys"].

    Returns:
        dict: A state whose "keys" hold the extended "documents" list and
        the unchanged "question".
    """
    print("---WEB SEARCH---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    tool = TavilySearchResults()
    results = tool.invoke({"query": question})
    web_content = "\n".join(result["content"] for result in results)
    web_document = Document(page_content=web_content)

    # Build a new list instead of appending in place, so the list held by
    # the incoming state is not modified behind the caller's back.
    return {"keys": {"documents": [*documents, web_document], "question": question}}
"### Edges\n",
def decide_to_generate(state):
    """
    Decide whether to generate an answer or re-phrase the question for web search.

    Only state["keys"]["run_web_search"] is consulted; no other key is
    required to be present.

    Args:
        state (dict): The current state of the agent.

    Returns:
        str: Name of the next node to call, "transform_query" or "generate".
    """
    print("---DECIDE TO GENERATE---")
    search = state["keys"]["run_web_search"]

    if search == "Yes":
        # The grading step asked for a web search, so the question is
        # re-phrased before searching.
        print("---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---")
        return "transform_query"

    # No web search was requested, so generate the answer
    print("---DECISION: GENERATE---")
    return "generate"
"# Build the graph"
import pprint

from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generate
workflow.add_node("transform_query", transform_query)  # transform_query
workflow.add_node("web_search", web_search)  # web search

# Build graph
# The graph starts at retrieval; without an entry point it cannot be run.
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
# After grading, decide_to_generate returns the name of the next node.
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()
"# Run the graph"
# Run
question = "What was Airbnb's revenue in Q3 2023?"
inputs = {"keys": {"question": question}}
print(f"Question: {question}\n")

# Remember the most recent node state explicitly, so an empty stream is
# reported clearly instead of surfacing as a NameError further down.
value = None
for output in app.stream(inputs):
    for key, value in output.items():
        # Print Node
        print(f"Output from node '{key}':")

if value is None:
    raise RuntimeError("The graph produced no output, so there is no answer to show")

# Final generation
answer = value["keys"]["generation"]
print(f"Answer: {answer}")
