Skip to content

Instantly share code, notes, and snippets.

@virattt
Last active February 9, 2024 07:42
Show Gist options
  • Save virattt/98fb3aa85603211bad87cc19b79fcfc5 to your computer and use it in GitHub Desktop.
Save virattt/98fb3aa85603211bad87cc19b79fcfc5 to your computer and use it in GitHub Desktop.
langchain-crag-financial-assistant.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOk31xq9//kMYrJkaoArNWL",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/98fb3aa85603211bad87cc19b79fcfc5/langchain-crag-financial-assistant.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aSdU1s56Pouh"
},
"outputs": [],
"source": [
"# %pip (rather than !pip) installs into the environment of the running kernel.\n",
"# NOTE(review): this notebook was written against early-2024 releases (it imports\n",
"# langchain_core.pydantic_v1 and uses the state[\"keys\"] langgraph pattern); pin versions\n",
"# if newer releases break it.\n",
"%pip install -q langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python pypdf"
]
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"# Set your OpenAI API key. The prompt is labelled because the next cell asks for a\n",
"# second secret, and two identical blank password prompts are easy to mix up.\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API key: \")"
],
"metadata": {
"id": "zSJ6772gP0sJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Set your Tavily API key (labelled prompt so it is not confused with the OpenAI one).\n",
"# Relies on `getpass` and `os` imported in the previous cell.\n",
"os.environ[\"TAVILY_API_KEY\"] = getpass.getpass(\"Tavily API key: \")"
],
"metadata": {
"id": "ppNjc7LXP3F9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Download and prepare the SEC filing"
],
"metadata": {
"id": "sz639zFf6JoK"
}
},
{
"cell_type": "code",
"source": [
"# langchain_community is the maintained home of the loaders; langchain.document_loaders\n",
"# is only a deprecated re-export.\n",
"from langchain_community.document_loaders import PyPDFLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"# Load $ABNB's financial report. This may take 1-2 minutes since the PDF is large\n",
"sec_filing_pdf = \"https://d18rn0p25nwr6d.cloudfront.net/CIK-0001559720/8a9ebed0-815a-469a-87eb-1767d21d8cec.pdf\"\n",
"\n",
"# Create your PDF loader\n",
"loader = PyPDFLoader(sec_filing_pdf)\n",
"\n",
"# Load the PDF document\n",
"documents = loader.load()\n",
"\n",
"# Chunk the financial report (chunk_size=1024, no overlap between chunks)\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
"\n",
"# Fail fast here if the download/parse produced no text to index\n",
"assert docs, \"The SEC filing produced no text chunks\""
],
"metadata": {
"id": "rIO5t-j7611h"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Load the SEC filing into the vector store"
],
"metadata": {
"id": "iaYSqxiMLUGb"
}
},
{
"cell_type": "code",
"source": [
"from langchain_community.vectorstores import Chroma\n",
"# Same OpenAIEmbeddings the graph cell imports below (langchain.embeddings.openai is a\n",
"# deprecated re-export).\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"# Embed every chunk and load it into a Chroma vector store\n",
"embedding_function = OpenAIEmbeddings()\n",
"vectorstore = Chroma.from_documents(docs, embedding_function)\n",
"\n",
"# Retriever used by the `retrieve` node of the graph\n",
"retriever = vectorstore.as_retriever()"
],
"metadata": {
"id": "QVZevdc-Md4N"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Define the graph state"
],
"metadata": {
"id": "4m0Vu0YhRgAc"
}
},
{
"cell_type": "code",
"source": [
"from typing import Any, Dict, TypedDict\n",
"\n",
"from langchain_core.messages import BaseMessage\n",
"\n",
"\n",
"class GraphState(TypedDict):\n",
"    \"\"\"\n",
"    Represents the state of our graph.\n",
"\n",
"    Attributes:\n",
"        keys: A dictionary mapping string keys to arbitrary values. The nodes use\n",
"            `question`, `documents`, `run_web_search` (\"Yes\"/\"No\") and\n",
"            `generation` (the final answer).\n",
"    \"\"\"\n",
"\n",
"    # typing.Any, not the builtin function `any`, is the \"any type\" annotation\n",
"    keys: Dict[str, Any]"
],
"metadata": {
"id": "ztIym-lyRhgc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Define the graph's Nodes and Edges"
],
"metadata": {
"id": "opHvs3lrRmF7"
}
},
{
"cell_type": "code",
"source": [
"import json\n",
"import operator\n",
"from typing import Annotated, Sequence, TypedDict\n",
"\n",
"from langchain import hub\n",
"from langchain.output_parsers.openai_tools import PydanticToolsParser\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.schema import Document\n",
"from langchain_community.tools.tavily_search import TavilySearchResults\n",
"from langchain_community.vectorstores import Chroma\n",
"from langchain_core.messages import BaseMessage, FunctionMessage\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"from langchain_core.utils.function_calling import convert_to_openai_tool\n",
"from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
"\n",
"### Nodes ###\n",
"\n",
"\n",
"def retrieve(state):\n",
"    \"\"\"\n",
"    Fetch candidate chunks for the current question from the vector store.\n",
"\n",
"    Args:\n",
"        state (dict): Graph state; state[\"keys\"][\"question\"] holds the query.\n",
"\n",
"    Returns:\n",
"        state (dict): New state whose keys hold the retrieved documents and the question.\n",
"    \"\"\"\n",
"    print(\"---RETRIEVE---\")\n",
"    query = state[\"keys\"][\"question\"]\n",
"    return {\n",
"        \"keys\": {\n",
"            \"documents\": retriever.get_relevant_documents(query),\n",
"            \"question\": query,\n",
"        }\n",
"    }\n",
"\n",
"\n",
"def generate(state):\n",
"    \"\"\"\n",
"    Generate an answer to the question grounded in the (graded) documents.\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state; reads `question` and `documents`.\n",
"\n",
"    Returns:\n",
"        state (dict): New state with `generation` (the LLM's answer string) added.\n",
"    \"\"\"\n",
"    print(\"---GENERATE---\")\n",
"    state_dict = state[\"keys\"]\n",
"    question = state_dict[\"question\"]\n",
"    documents = state_dict[\"documents\"]\n",
"\n",
"    # Prompt from LangChain Hub; filled with `context` and `question` below\n",
"    prompt = hub.pull(\"rlm/rag-prompt\")\n",
"\n",
"    # LLM\n",
"    llm = ChatOpenAI(model_name=\"gpt-3.5-turbo\", temperature=0, streaming=True)\n",
"\n",
"    # Post-processing: join the chunk texts so the prompt receives plain text rather\n",
"    # than the string form of a list of Document objects (which includes metadata).\n",
"    def format_docs(docs):\n",
"        return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
"\n",
"    # Chain\n",
"    rag_chain = prompt | llm | StrOutputParser()\n",
"\n",
"    # Run\n",
"    generation = rag_chain.invoke(\n",
"        {\"context\": format_docs(documents), \"question\": question}\n",
"    )\n",
"    return {\n",
"        \"keys\": {\"documents\": documents, \"question\": question, \"generation\": generation}\n",
"    }\n",
"\n",
"\n",
"def grade_documents(state):\n",
"    \"\"\"\n",
"    Determines whether the retrieved documents are relevant to the question.\n",
"\n",
"    Each document is graded independently by an LLM that is forced to call the\n",
"    `grade` tool. Irrelevant documents are dropped. Web search is requested when any\n",
"    document was judged irrelevant, or when no relevant document remains.\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state; reads `question` and `documents`.\n",
"\n",
"    Returns:\n",
"        state (dict): New state with only the relevant `documents` and a\n",
"            `run_web_search` flag of \"Yes\" or \"No\".\n",
"    \"\"\"\n",
"\n",
"    print(\"---CHECK RELEVANCE---\")\n",
"    state_dict = state[\"keys\"]\n",
"    question = state_dict[\"question\"]\n",
"    documents = state_dict[\"documents\"]\n",
"\n",
"    # Data model\n",
"    class grade(BaseModel):\n",
"        \"\"\"Binary score for relevance check.\"\"\"\n",
"\n",
"        binary_score: str = Field(description=\"Relevance score 'yes' or 'no'\")\n",
"\n",
"    # LLM\n",
"    model = ChatOpenAI(temperature=0, model=\"gpt-4-0125-preview\", streaming=True)\n",
"\n",
"    # Tool: already in OpenAI tool format, so it is passed to bind() directly\n",
"    grade_tool_oai = convert_to_openai_tool(grade)\n",
"\n",
"    # LLM with tool and enforce invocation\n",
"    llm_with_tool = model.bind(\n",
"        tools=[grade_tool_oai],\n",
"        tool_choice={\"type\": \"function\", \"function\": {\"name\": \"grade\"}},\n",
"    )\n",
"\n",
"    # Parser\n",
"    parser_tool = PydanticToolsParser(tools=[grade])\n",
"\n",
"    # Prompt\n",
"    prompt = PromptTemplate(\n",
"        template=\"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n\n",
" Here is the retrieved document: \\n\\n {context} \\n\\n\n",
" Here is the user question: {question} \\n\n",
" If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \\n\n",
" Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\"\"\",\n",
"        input_variables=[\"context\", \"question\"],\n",
"    )\n",
"\n",
"    # Chain\n",
"    chain = prompt | llm_with_tool | parser_tool\n",
"\n",
"    # Score each document. The verdict goes into `relevance` (not `grade`) so the\n",
"    # local `grade` model class is not shadowed by a string inside the loop.\n",
"    filtered_docs = []\n",
"    search = \"No\"  # Default do not opt for web search to supplement retrieval\n",
"    for d in documents:\n",
"        score = chain.invoke({\"question\": question, \"context\": d.page_content})\n",
"        # The parser returns a list of parsed tool calls; it may be empty if the model\n",
"        # made no call, which is treated as \"no\". Case/whitespace are normalised so\n",
"        # answers such as \"Yes\" still count as relevant.\n",
"        relevance = score[0].binary_score.strip().lower() if score else \"no\"\n",
"        if relevance == \"yes\":\n",
"            print(\"---GRADE: DOCUMENT RELEVANT---\")\n",
"            filtered_docs.append(d)\n",
"        else:\n",
"            print(\"---GRADE: DOCUMENT NOT RELEVANT---\")\n",
"            search = \"Yes\"  # Perform web search\n",
"\n",
"    # Nothing usable to answer from (e.g. the retriever returned no documents):\n",
"    # fall back to web search instead of generating from an empty context.\n",
"    if not filtered_docs:\n",
"        search = \"Yes\"\n",
"\n",
"    return {\n",
"        \"keys\": {\n",
"            \"documents\": filtered_docs,\n",
"            \"question\": question,\n",
"            \"run_web_search\": search,\n",
"        }\n",
"    }\n",
"\n",
"\n",
"def transform_query(state):\n",
"    \"\"\"\n",
"    Rewrite the question into a form that should retrieve better results.\n",
"\n",
"    Args:\n",
"        state (dict): Graph state; reads `question` and passes `documents` through.\n",
"\n",
"    Returns:\n",
"        state (dict): New state whose `question` is the LLM-rewritten question.\n",
"    \"\"\"\n",
"\n",
"    print(\"---TRANSFORM QUERY---\")\n",
"    current = state[\"keys\"]\n",
"    original_question = current[\"question\"]\n",
"    carried_docs = current[\"documents\"]\n",
"\n",
"    # Rewriting prompt: the LLM sees only the original question\n",
"    rewrite_prompt = PromptTemplate(\n",
"        template=\"\"\"You are generating questions that is well optimized for retrieval. \\n\n",
" Look at the input and try to reason about the underlying sematic intent / meaning. \\n\n",
" Here is the initial question:\n",
" \\n ------- \\n\n",
" {question}\n",
" \\n ------- \\n\n",
" Formulate an improved question: \"\"\",\n",
"        input_variables=[\"question\"],\n",
"    )\n",
"\n",
"    # Rewriter LLM (temperature 0)\n",
"    rewriter = ChatOpenAI(temperature=0, model=\"gpt-4-0125-preview\", streaming=True)\n",
"\n",
"    # prompt -> LLM -> plain string\n",
"    better_question = (rewrite_prompt | rewriter | StrOutputParser()).invoke(\n",
"        {\"question\": original_question}\n",
"    )\n",
"\n",
"    return {\"keys\": {\"documents\": carried_docs, \"question\": better_question}}\n",
"\n",
"\n",
"def web_search(state):\n",
"    \"\"\"\n",
"    Web search based on the re-phrased question using Tavily API.\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state; reads `question` and `documents`.\n",
"\n",
"    Returns:\n",
"        state (dict): New state whose `documents` are the incoming documents plus one\n",
"            Document holding the concatenated web results (when any were found).\n",
"    \"\"\"\n",
"\n",
"    print(\"---WEB SEARCH---\")\n",
"    state_dict = state[\"keys\"]\n",
"    question = state_dict[\"question\"]\n",
"    # Copy rather than append in place so the list held by the incoming state is untouched\n",
"    documents = list(state_dict[\"documents\"])\n",
"\n",
"    tool = TavilySearchResults()\n",
"    results = tool.invoke({\"query\": question})\n",
"\n",
"    # Expected: a list of result dicts with a \"content\" field. NOTE(review): on API\n",
"    # errors the tool may return an error string instead, which must not be indexed.\n",
"    if isinstance(results, list):\n",
"        web_text = \"\\n\".join(d[\"content\"] for d in results if d.get(\"content\"))\n",
"        if web_text:\n",
"            documents.append(Document(page_content=web_text))\n",
"    else:\n",
"        print(f\"---WEB SEARCH: UNEXPECTED RESULT {results!r}---\")\n",
"\n",
"    return {\"keys\": {\"documents\": documents, \"question\": question}}\n",
"\n",
"\n",
"### Edges\n",
"\n",
"\n",
"def decide_to_generate(state):\n",
"    \"\"\"\n",
"    Determines whether to generate an answer or re-generate a question for web search.\n",
"\n",
"    Args:\n",
"        state (dict): The current state of the agent, including all keys. Must contain\n",
"            the `run_web_search` flag written by `grade_documents`.\n",
"\n",
"    Returns:\n",
"        str: Next node to call, either \"transform_query\" or \"generate\".\n",
"    \"\"\"\n",
"\n",
"    print(\"---DECIDE TO GENERATE---\")\n",
"    search = state[\"keys\"][\"run_web_search\"]\n",
"\n",
"    if search == \"Yes\":\n",
"        # grade_documents requested web search (e.g. at least one retrieved document\n",
"        # was graded irrelevant), so rewrite the query and supplement with the web\n",
"        print(\"---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---\")\n",
"        return \"transform_query\"\n",
"    # No web search requested: answer from the retrieved context\n",
"    print(\"---DECISION: GENERATE---\")\n",
"    return \"generate\""
],
"metadata": {
"id": "mdVk88QkRoxl"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Build the graph"
],
"metadata": {
"id": "3vNh_m4bRq4r"
}
},
{
"cell_type": "code",
"source": [
"import pprint\n",
"\n",
"from langgraph.graph import END, StateGraph\n",
"\n",
"# Every node reads and returns a GraphState-shaped dict\n",
"workflow = StateGraph(GraphState)\n",
"\n",
"# Nodes: each registered under the name of the function that implements it\n",
"workflow.add_node(\"retrieve\", retrieve)\n",
"workflow.add_node(\"grade_documents\", grade_documents)\n",
"workflow.add_node(\"generate\", generate)\n",
"workflow.add_node(\"transform_query\", transform_query)\n",
"workflow.add_node(\"web_search\", web_search)\n",
"\n",
"# Edges: retrieve -> grade_documents -> (generate | transform_query -> web_search -> generate) -> END\n",
"workflow.set_entry_point(\"retrieve\")\n",
"workflow.add_edge(\"retrieve\", \"grade_documents\")\n",
"# decide_to_generate returns the next node's name; each name maps to itself\n",
"workflow.add_conditional_edges(\n",
"    \"grade_documents\",\n",
"    decide_to_generate,\n",
"    {\"transform_query\": \"transform_query\", \"generate\": \"generate\"},\n",
")\n",
"workflow.add_edge(\"transform_query\", \"web_search\")\n",
"workflow.add_edge(\"web_search\", \"generate\")\n",
"workflow.add_edge(\"generate\", END)\n",
"\n",
"# Compile into a runnable app\n",
"app = workflow.compile()"
],
"metadata": {
"id": "du5RScfyRr48"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Run the graph"
],
"metadata": {
"id": "KJC48i6sRvsV"
}
},
{
"cell_type": "code",
"source": [
"# Run\n",
"question = \"What was Airbnb's revenue in Q3 2023?\"\n",
"inputs = {\"keys\": {\"question\": question}}\n",
"print(f\"Question: {question}\\n\")\n",
"\n",
"# Stream node by node, printing each node's name as it finishes, and keep the most\n",
"# recent state so the answer can be read after the loop.\n",
"final_state = None\n",
"for output in app.stream(inputs):\n",
"    for node_name, node_state in output.items():\n",
"        print(f\"Finished node: {node_name}\\n\")\n",
"        final_state = node_state\n",
"\n",
"# Final generation (`generate` is the last node before END)\n",
"if final_state is None or \"generation\" not in final_state[\"keys\"]:\n",
"    raise RuntimeError(\"Graph finished without producing an answer\")\n",
"answer = final_state[\"keys\"][\"generation\"]\n",
"print(f\"Answer: {answer}\")"
],
"metadata": {
"id": "t0dCdGbURwcT"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment