Created
January 8, 2024 22:01
-
-
Save virattt/ba0b660cdcaf4161ca1e6e5d8b5de4f8 to your computer and use it in GitHub Desktop.
LangGraph-financial-agent.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyModMkmR0AEpUcTyc4HIFN0", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/virattt/ba0b660cdcaf4161ca1e6e5d8b5de4f8/langgraph-financial-agent.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"This notebook includes my code for creating a financial agent using [LangGraph](https://github.com/langchain-ai/langgraph).\n", | |
"\n", | |
"The agent has two tools:\n", | |
"\n", | |
"1. Extract a ticker from a user query.\n", | |
"2. Given a ticker, get its latest quoted (ask) price using [Polygon](https://polygon.io/).\n", | |
"\n", | |
"You will need two things to run the code:\n", | |
"\n", | |
"1. OpenAI API key ([link](https://platform.openai.com/account/api-keys))\n", | |
"2. Polygon API key ([link](https://polygon.io/))\n", | |
"\n", | |
"I've tried to make the code as easy as possible to read and run. If you have any questions, please feel free to message me on [X](https://twitter.com/virattt)!" | |
], | |
"metadata": { | |
"id": "9gfXnsep-Wuk" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Step 0 - Install dependencies" | |
], | |
"metadata": { | |
"id": "UpH-efxS-TAS" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "GcvSoNAmVbXH" | |
}, | |
"outputs": [], | |
"source": [ | |
"# `%pip` installs into the running kernel's environment (unlike `!pip`).\n", | |
"# Versions are bounded to the LangChain 0.1-era APIs this notebook relies on\n", | |
"# (hub.pull, create_openai_functions_agent, langgraph.graph.Graph).\n", | |
"%pip install -q \"langgraph<0.1\" \"langchain>=0.1,<0.2\" \"langchain-openai<0.2\" langchainhub\n", | |
"%pip install -q \"polygon-api-client<2\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import getpass\n", | |
"import os\n", | |
"\n", | |
"# Set your OpenAI API key. Only prompt when it is not already set, so that\n", | |
"# re-running this cell (or a pre-configured environment) does not ask again.\n", | |
"if not os.environ.get(\"OPENAI_API_KEY\"):\n", | |
"    os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API key: \")" | |
], | |
"metadata": { | |
"id": "aw9453mxY2GZ" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Step 1 - Define the tools that our agent will use" | |
], | |
"metadata": { | |
"id": "kLb5xTf399gT" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import getpass\n", | |
"import os\n", | |
"\n", | |
"from langchain.tools import tool\n", | |
"from langchain_core.messages import HumanMessage, SystemMessage\n", | |
"from langchain_openai import ChatOpenAI\n", | |
"from polygon import RESTClient\n", | |
"\n", | |
"# You can get a free Polygon API key from: https://polygon.io/\n", | |
"# Read it from the environment (prompting if unset) rather than hardcoding it.\n", | |
"if not os.environ.get(\"POLYGON_API_KEY\"):\n", | |
"    os.environ[\"POLYGON_API_KEY\"] = getpass.getpass(\"Polygon API key: \")\n", | |
"client = RESTClient(api_key=os.environ[\"POLYGON_API_KEY\"])\n", | |
"\n", | |
"# LLM used only for ticker extraction; created once and reused across tool calls.\n", | |
"# temperature=0 reduces randomness in the extracted ticker.\n", | |
"ticker_llm = ChatOpenAI(model=\"gpt-4\", temperature=0)\n", | |
"\n", | |
"\n", | |
"# NOTE: a @tool's docstring becomes the tool description shown to the agent LLM,\n", | |
"# so implementation notes go in comments instead.\n", | |
"@tool\n", | |
"def extract_ticker(query: str) -> str:\n", | |
"    \"\"\"\n", | |
"    Given a user query, extracts the ticker from the query and returns it.\n", | |
"    \"\"\"\n", | |
"    messages = [\n", | |
"        SystemMessage(content=\"You extract the stock ticker from a given user query. Reply with only the ticker symbol and nothing else.\"),\n", | |
"        HumanMessage(content=query),\n", | |
"    ]\n", | |
"    # `invoke` is the Runnable entry point; calling the model directly (`llm(messages)`) is deprecated.\n", | |
"    result = ticker_llm.invoke(messages)\n", | |
"    # Strip stray whitespace so the ticker can be passed straight to `latest_stock_price`.\n", | |
"    ticker = result.content.strip()\n", | |
"    print(f\"extracted ticker: {ticker}\")\n", | |
"    return ticker\n", | |
"\n", | |
"\n", | |
"@tool\n", | |
"def latest_stock_price(ticker: str) -> str:\n", | |
"    \"\"\"\n", | |
"    Provides the latest stock price data for a given ticker.\n", | |
"    Uses the Polygon API to retrieve stock price data.\n", | |
"    \"\"\"\n", | |
"    # Stock tickers are upper-case on Polygon; normalize whatever the agent passes in.\n", | |
"    symbol = ticker.strip().upper()\n", | |
"    quote = client.get_last_quote(ticker=symbol)\n", | |
"    # The ask side of the last quote is reported as the \"latest price\".\n", | |
"    price = quote.ask_price\n", | |
"    if price is None:\n", | |
"        return f\"No ask price available for {symbol}.\"\n", | |
"    print(f\"price quote: {price}\")\n", | |
"    # Return a string to honour the declared `-> str` (ask_price is numeric).\n", | |
"    return str(price)" | |
], | |
"metadata": { | |
"id": "IZm-G5ZvbHS4" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Step 2 - Create the LLM and the agent runnable" | |
], | |
"metadata": { | |
"id": "pSqKl9AT-DGL" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from langchain import hub\n", | |
"from langchain.agents import create_openai_functions_agent\n", | |
"from langchain_openai.chat_models import ChatOpenAI\n", | |
"\n", | |
"# The model that decides which tool to call next and writes the final answer\n", | |
"llm = ChatOpenAI(model=\"gpt-4\")\n", | |
"\n", | |
"# Tools the agent is allowed to call (both defined in Step 1)\n", | |
"tools = [latest_stock_price, extract_ticker]\n", | |
"\n", | |
"# Prompt template pulled from LangChain Hub - swap in your own to customize the agent\n", | |
"prompt = hub.pull(\"hwchase17/openai-functions-agent\")\n", | |
"\n", | |
"# Wire model + tools + prompt together into an OpenAI Functions agent runnable\n", | |
"agent_runnable = create_openai_functions_agent(llm=llm, tools=tools, prompt=prompt)" | |
], | |
"metadata": { | |
"id": "mczr2sUEVdYv" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Step 3 - Define the agent and related functions" | |
], | |
"metadata": { | |
"id": "AlTt85lW-GIY" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from langchain_core.runnables import RunnablePassthrough\n", | |
"from langchain_core.agents import AgentFinish\n", | |
"\n", | |
"\n", | |
"# Define the agent: run `agent_runnable` on the current state and store its\n", | |
"# decision (an AgentAction or AgentFinish) under \"agent_outcome\".\n", | |
"agent = RunnablePassthrough.assign(\n", | |
"    agent_outcome = agent_runnable\n", | |
")\n", | |
"\n", | |
"# Define the function to execute tools\n", | |
"def execute_tools(data):\n", | |
"    \"\"\"Run the tool chosen by the agent and record the step.\n", | |
"\n", | |
"    Returns a new state without \"agent_outcome\" and with (action, observation)\n", | |
"    appended to \"intermediate_steps\". Neither `data` nor its step list is\n", | |
"    mutated: `RunnablePassthrough.assign` only shallow-copies the state, so the\n", | |
"    list is the caller's own input list and appending in place would leak steps\n", | |
"    into a reused input dict.\n", | |
"    \"\"\"\n", | |
"    agent_action = data['agent_outcome']\n", | |
"    tool_to_use = {t.name: t for t in tools}[agent_action.tool]\n", | |
"    observation = tool_to_use.invoke(agent_action.tool_input)\n", | |
"    new_state = {key: value for key, value in data.items() if key != 'agent_outcome'}\n", | |
"    new_state['intermediate_steps'] = [*data['intermediate_steps'], (agent_action, observation)]\n", | |
"    return new_state\n", | |
"\n", | |
"# Define logic that will be used to determine which conditional edge to go down\n", | |
"def should_continue(data):\n", | |
"    \"\"\"Return \"exit\" once the agent produced an AgentFinish, else \"continue\".\"\"\"\n", | |
"    if isinstance(data['agent_outcome'], AgentFinish):\n", | |
"        return \"exit\"\n", | |
"    else:\n", | |
"        return \"continue\"" | |
], | |
"metadata": { | |
"id": "nTW70qJEV6-A" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Step 4 - Define and compile the agent graph" | |
], | |
"metadata": { | |
"id": "RewrdOZf-J0S" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from langgraph.graph import END, Graph\n", | |
"\n", | |
"# Build the agent graph: `agent` decides, `tools` acts, and control loops back\n", | |
"# to `agent` until `should_continue` routes to END.\n", | |
"workflow = Graph()\n", | |
"workflow.add_node(\"agent\", agent)\n", | |
"workflow.add_node(\"tools\", execute_tools)\n", | |
"workflow.set_entry_point(\"agent\")\n", | |
"\n", | |
"# After `agent` runs, map the label returned by `should_continue` to the next hop\n", | |
"route_after_agent = {\"continue\": \"tools\", \"exit\": END}\n", | |
"workflow.add_conditional_edges(\"agent\", should_continue, route_after_agent)\n", | |
"\n", | |
"# Tool results always flow back to the agent for its next decision\n", | |
"workflow.add_edge(\"tools\", \"agent\")\n", | |
"\n", | |
"# Compile the graph into a LangChain Runnable\n", | |
"chain = workflow.compile()" | |
], | |
"metadata": { | |
"id": "xK5euZ6UWAhq" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Step 5 - Query our agent for the latest stock price!" | |
], | |
"metadata": { | |
"id": "MZRaj7rW-NDk" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"question = \"What is the latest stock price for AAPL?\"\n", | |
"initial_state = {\"input\": question, \"intermediate_steps\": []}\n", | |
"result = chain.invoke(initial_state)\n", | |
"\n", | |
"# The graph only exits on an AgentFinish; its answer lives in `return_values`\n", | |
"output = result[\"agent_outcome\"].return_values[\"output\"]\n", | |
"print(output)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "kBwr25yAWDVB", | |
"outputId": "a52b20a2-ab2b-4aca-88bb-a1d0a323332f" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"extracted ticker: AAPL\n", | |
"price quote: 185.36\n", | |
"The latest stock price for AAPL (Apple Inc.) is $185.36.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "7XeJKNsn9YWG" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment