@virattt
Created January 8, 2024 22:01
LangGraph-financial-agent.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyModMkmR0AEpUcTyc4HIFN0",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/ba0b660cdcaf4161ca1e6e5d8b5de4f8/langgraph-financial-agent.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"This notebook includes my code for creating a financial agent using [LangGraph](https://github.com/langchain-ai/langgraph).\n",
"\n",
"The agent has two tools:\n",
"\n",
"1. Extract a ticker from a user query.\n",
"2. Given a ticker, get its latest price (the ask price from the most recent quote) using [Polygon](https://polygon.io/).\n",
"\n",
"You will need two things to run the code:\n",
"\n",
"1. OpenAI API key ([link](https://platform.openai.com/account/api-keys))\n",
"2. Polygon API key ([link](https://polygon.io/))\n",
"\n",
"I've tried to make the code as easy to read and run as possible. If you have any questions, feel free to message me on [X](https://twitter.com/virattt)!"
],
"metadata": {
"id": "9gfXnsep-Wuk"
}
},
{
"cell_type": "markdown",
"source": [
"## Step 0 - Install dependencies"
],
"metadata": {
"id": "UpH-efxS-TAS"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GcvSoNAmVbXH"
},
"outputs": [],
"source": [
"!pip install langgraph\n",
"!pip install -U langchain langchain_openai langchainhub\n",
"!pip install -U polygon-api-client"
]
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"# Set your OpenAI API key\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()"
],
"metadata": {
"id": "aw9453mxY2GZ"
},
"execution_count": null,
"outputs": []
},
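{
"cell_type": "markdown",
"source": [
"The cell above prompts for the OpenAI key on every run, while the Polygon key is hard-coded in Step 1. If you would rather keep both keys in environment variables, here is a minimal sketch of a hypothetical helper, `get_api_key`, that reads a key from the environment and only falls back to an interactive `getpass` prompt when the variable is missing or empty. `ChatOpenAI` reads `OPENAI_API_KEY`; the name `POLYGON_API_KEY` in the usage comments is a naming choice for this sketch."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"\n",
"def get_api_key(name: str) -> str:\n",
"    \"\"\"Return the API key stored in the environment variable `name`.\n",
"\n",
"    Hypothetical helper: prompts with getpass only when the variable is unset\n",
"    or empty, then stores the answer in os.environ for later cells.\n",
"    \"\"\"\n",
"    value = os.environ.get(name)\n",
"    if not value:\n",
"        value = getpass.getpass(f'{name}: ')\n",
"        os.environ[name] = value\n",
"    return value\n",
"\n",
"\n",
"# Usage (POLYGON_API_KEY is an assumed variable name):\n",
"# get_api_key('OPENAI_API_KEY')  # ChatOpenAI reads this variable\n",
"# client = RESTClient(api_key=get_api_key('POLYGON_API_KEY'))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},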
{
"cell_type": "markdown",
"source": [
"## Step 1 - Define the tools that our agent will use"
],
"metadata": {
"id": "kLb5xTf399gT"
}
},
{
"cell_type": "code",
"source": [
"from langchain.tools import tool\n",
"from langchain_core.messages import HumanMessage, SystemMessage\n",
"from langchain_openai import ChatOpenAI\n",
"from polygon import RESTClient\n",
"\n",
"# You can get a free Polygon API key from: https://polygon.io/\n",
"client = RESTClient(api_key=\"YOUR_POLYGON_API_KEY\")\n",
"\n",
"\n",
"@tool\n",
"def extract_ticker(query: str) -> str:\n",
"    \"\"\"\n",
"    Given a user query, extracts the stock ticker from the query and returns it.\n",
"    \"\"\"\n",
"    llm = ChatOpenAI(model=\"gpt-4\")\n",
"\n",
"    messages = [\n",
"        SystemMessage(content=\"You extract the stock ticker from a given user query. Reply with the ticker symbol only.\"),\n",
"        HumanMessage(content=query),\n",
"    ]\n",
"    # Use .invoke(); calling llm(messages) directly is deprecated in newer LangChain releases\n",
"    result = llm.invoke(messages)\n",
"    ticker = result.content.strip()\n",
"    print(f\"extracted ticker: {ticker}\")\n",
"    return ticker\n",
"\n",
"\n",
"@tool\n",
"def latest_stock_price(ticker: str) -> str:\n",
"    \"\"\"\n",
"    Provides the latest stock price for a given ticker.\n",
"    Uses the Polygon API's last-quote data and returns the ask price.\n",
"    \"\"\"\n",
"    quote = client.get_last_quote(ticker=ticker)\n",
"    price = quote.ask_price\n",
"    print(f\"price quote: {price}\")\n",
"    # ask_price is numeric; convert it to match the declared str return type\n",
"    return str(price)"
],
"metadata": {
"id": "IZm-G5ZvbHS4"
},
"execution_count": null,
"outputs": []
},
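{
"cell_type": "markdown",
"source": [
"`extract_ticker` passes the model's reply on as the ticker without validating it, so a chatty answer such as `The ticker is AAPL.` would be sent to Polygon as-is. Below is a minimal sketch of a hypothetical validation step, `normalize_ticker`, assuming US-style symbols of one to five letters with an optional single-letter share class after a dot (such as `BRK.B`). It trims and upper-cases the reply and raises `ValueError` when the result does not look like a ticker; you could apply it to `result.content` inside `extract_ticker`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"def normalize_ticker(raw: str) -> str:\n",
"    \"\"\"Hypothetical helper: clean up and sanity-check an LLM-extracted ticker.\n",
"\n",
"    Assumes US-style symbols: 1-5 letters, optionally followed by a dot and a\n",
"    single share-class letter (e.g. BRK.B).\n",
"    \"\"\"\n",
"    # Drop surrounding whitespace and stray periods, then upper-case\n",
"    ticker = raw.strip().strip('.').upper()\n",
"    base, _, share_class = ticker.partition('.')\n",
"    valid_base = base.isalpha() and 1 <= len(base) <= 5\n",
"    valid_class = share_class == '' or (share_class.isalpha() and len(share_class) == 1)\n",
"    if not (valid_base and valid_class):\n",
"        raise ValueError(f'Not a ticker symbol: {raw!r}')\n",
"    return ticker\n",
"\n",
"\n",
"print(normalize_ticker(' aapl '))  # -> AAPL\n",
"print(normalize_ticker('BRK.B'))  # -> BRK.B\n",
"try:\n",
"    normalize_ticker('The ticker is AAPL.')\n",
"except ValueError as err:\n",
"    print(err)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},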
{
"cell_type": "markdown",
"source": [
"## Step 2 - Create the LLM and the agent runnable"
],
"metadata": {
"id": "pSqKl9AT-DGL"
}
},
{
"cell_type": "code",
"source": [
"from langchain import hub\n",
"from langchain.agents import create_openai_functions_agent\n",
"from langchain_openai.chat_models import ChatOpenAI\n",
"\n",
"tools = [latest_stock_price, extract_ticker]\n",
"\n",
"# Get the prompt to use - you can modify this!\n",
"prompt = hub.pull(\"hwchase17/openai-functions-agent\")\n",
"\n",
"# Choose the LLM that will drive the agent\n",
"llm = ChatOpenAI(model=\"gpt-4\")\n",
"\n",
"# Construct the OpenAI Functions agent\n",
"agent_runnable = create_openai_functions_agent(llm, tools, prompt)"
],
"metadata": {
"id": "mczr2sUEVdYv"
},
"execution_count": null,
"outputs": []
},
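{
"cell_type": "markdown",
"source": [
"The `hwchase17/openai-functions-agent` prompt expects the user `input` plus an `agent_scratchpad` of messages. `create_openai_functions_agent` builds that scratchpad from `intermediate_steps`, turning each `(action, observation)` pair into an assistant message that calls the tool as a function, followed by a function message carrying the tool's result. Here is a simplified pure-Python sketch of that conversion, using plain dicts in the OpenAI chat format instead of LangChain message objects; `Action` and `to_scratchpad` are hypothetical names, not LangChain APIs."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import json\n",
"from typing import NamedTuple\n",
"\n",
"\n",
"class Action(NamedTuple):\n",
"    # Hypothetical stand-in for LangChain's AgentAction\n",
"    tool: str\n",
"    tool_input: dict\n",
"\n",
"\n",
"def to_scratchpad(intermediate_steps: list) -> list:\n",
"    \"\"\"Turn (action, observation) pairs into OpenAI-style chat messages.\"\"\"\n",
"    messages = []\n",
"    for action, observation in intermediate_steps:\n",
"        # The model's earlier decision to call the tool as a function\n",
"        messages.append({\n",
"            'role': 'assistant',\n",
"            'content': None,\n",
"            'function_call': {'name': action.tool, 'arguments': json.dumps(action.tool_input)},\n",
"        })\n",
"        # The tool's result, reported back under the same function name\n",
"        messages.append({'role': 'function', 'name': action.tool, 'content': str(observation)})\n",
"    return messages\n",
"\n",
"\n",
"steps = [(Action('extract_ticker', {'query': 'What is the latest stock price for AAPL?'}), 'AAPL')]\n",
"for message in to_scratchpad(steps):\n",
"    print(message)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},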
{
"cell_type": "markdown",
"source": [
"## Step 3 - Define the agent and related functions"
],
"metadata": {
"id": "AlTt85lW-GIY"
}
},
{
"cell_type": "code",
"source": [
"from langchain_core.runnables import RunnablePassthrough\n",
"from langchain_core.agents import AgentFinish\n",
"\n",
"\n",
"# Define the agent: run agent_runnable and store its AgentAction or AgentFinish under 'agent_outcome'\n",
"agent = RunnablePassthrough.assign(\n",
"    agent_outcome=agent_runnable\n",
")\n",
"\n",
"# Define the function to execute tools\n",
"def execute_tools(data):\n",
"    # The agent's last decision is an AgentAction naming a tool and its input\n",
"    agent_action = data.pop('agent_outcome')\n",
"    tool_to_use = {t.name: t for t in tools}[agent_action.tool]\n",
"    observation = tool_to_use.invoke(agent_action.tool_input)\n",
"    # Record the (action, observation) pair so the agent sees it on its next turn\n",
"    data['intermediate_steps'].append((agent_action, observation))\n",
"    return data\n",
"\n",
"# Decide which conditional edge to follow: stop once the agent returns an AgentFinish\n",
"def should_continue(data):\n",
"    if isinstance(data['agent_outcome'], AgentFinish):\n",
"        return \"exit\"\n",
"    else:\n",
"        return \"continue\""
],
"metadata": {
"id": "nTW70qJEV6-A"
},
"execution_count": null,
"outputs": []
},
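{
"cell_type": "markdown",
"source": [
"`execute_tools` and `should_continue` only read and write plain dict keys, so the loop can be exercised offline without calling OpenAI or Polygon. A minimal sketch under these assumptions: `FakeAction` and `FakeFinish` are hypothetical stand-ins that mimic the `tool`/`tool_input` and `return_values` attributes of LangChain's `AgentAction` and `AgentFinish`, the tool registry holds a plain function returning a fixed stub price, and both functions are restated against those stand-ins under new names (`offline_execute_tools`, `offline_should_continue`) so the real ones used by the graph are not overwritten."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from dataclasses import dataclass\n",
"\n",
"\n",
"@dataclass\n",
"class FakeAction:\n",
"    # Mimics AgentAction: which tool to call and with what input\n",
"    tool: str\n",
"    tool_input: str\n",
"\n",
"\n",
"@dataclass\n",
"class FakeFinish:\n",
"    # Mimics AgentFinish: the final answer lives in return_values\n",
"    return_values: dict\n",
"\n",
"\n",
"# Hypothetical tool registry: a fixed stub price replaces the Polygon call\n",
"fake_tools = {'latest_stock_price': lambda ticker: '185.36'}\n",
"\n",
"\n",
"def offline_execute_tools(data):\n",
"    agent_action = data.pop('agent_outcome')\n",
"    observation = fake_tools[agent_action.tool](agent_action.tool_input)\n",
"    data['intermediate_steps'].append((agent_action, observation))\n",
"    return data\n",
"\n",
"\n",
"def offline_should_continue(data):\n",
"    return 'exit' if isinstance(data['agent_outcome'], FakeFinish) else 'continue'\n",
"\n",
"\n",
"state = {'input': 'What is the latest stock price for AAPL?', 'intermediate_steps': []}\n",
"state['agent_outcome'] = FakeAction('latest_stock_price', 'AAPL')\n",
"print(offline_should_continue(state))  # -> continue\n",
"state = offline_execute_tools(state)\n",
"print(state['intermediate_steps'])\n",
"state['agent_outcome'] = FakeFinish({'output': 'stub final answer'})\n",
"print(offline_should_continue(state))  # -> exit"
],
"metadata": {},
"execution_count": null,
"outputs": []
},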
{
"cell_type": "markdown",
"source": [
"## Step 4 - Define and compile the agent graph"
],
"metadata": {
"id": "RewrdOZf-J0S"
}
},
{
"cell_type": "code",
"source": [
"from langgraph.graph import END, Graph\n",
"\n",
"workflow = Graph()\n",
"\n",
"workflow.add_node(\"agent\", agent)\n",
"workflow.add_node(\"tools\", execute_tools)\n",
"\n",
"# Set the entrypoint as `agent`\n",
"workflow.set_entry_point(\"agent\")\n",
"\n",
"# Add the edges of our agent graph\n",
"workflow.add_conditional_edges(\n",
" \"agent\",\n",
" should_continue,\n",
" {\n",
" \"continue\": \"tools\",\n",
" \"exit\": END\n",
" }\n",
")\n",
"workflow.add_edge('tools', 'agent')\n",
"\n",
"\n",
"# Compile the graph into a LangChain Runnable\n",
"chain = workflow.compile()"
],
"metadata": {
"id": "xK5euZ6UWAhq"
},
"execution_count": null,
"outputs": []
},
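{
"cell_type": "markdown",
"source": [
"Conceptually, the compiled graph starts at the entry point, runs a node, and then either follows a fixed edge (`tools` back to `agent`) or calls the conditional function and looks up its return value in the edge mapping, stopping once that lookup leads to `END`. Here is a simplified pure-Python sketch of that control flow; it is not LangGraph's actual implementation, and `run_graph`, the `SKETCH_END` marker, the `max_steps` guard, and the toy nodes are all hypothetical (renamed so they do not overwrite `agent`, `tools`, or `END` from earlier cells)."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"SKETCH_END = '__end__'  # terminal marker for this sketch only\n",
"\n",
"\n",
"def run_graph(nodes, entry, edges, conditional_edges, state, max_steps=25):\n",
"    \"\"\"Run nodes from `entry` until an edge leads to SKETCH_END.\n",
"\n",
"    nodes: name -> function(state) -> new state\n",
"    edges: name -> next node name (unconditional edges)\n",
"    conditional_edges: name -> (condition function, {condition result: next node name})\n",
"    \"\"\"\n",
"    current = entry\n",
"    for _ in range(max_steps):\n",
"        state = nodes[current](state)\n",
"        if current in conditional_edges:\n",
"            condition, mapping = conditional_edges[current]\n",
"            current = mapping[condition(state)]\n",
"        else:\n",
"            current = edges[current]\n",
"        if current == SKETCH_END:\n",
"            return state\n",
"    raise RuntimeError('Graph did not reach the end within max_steps')\n",
"\n",
"\n",
"def toy_agent(state):\n",
"    # Toy policy: request one tool call, then finish\n",
"    done = len(state['intermediate_steps']) > 0\n",
"    return {**state, 'agent_outcome': 'finish' if done else 'call_tool'}\n",
"\n",
"\n",
"def toy_tools(state):\n",
"    # Stub observation instead of a real Polygon lookup\n",
"    steps = state['intermediate_steps'] + [('latest_stock_price', '185.36')]\n",
"    return {**state, 'intermediate_steps': steps}\n",
"\n",
"\n",
"def route(state):\n",
"    # Same rule as should_continue, applied to the toy string outcomes\n",
"    return 'exit' if state['agent_outcome'] == 'finish' else 'continue'\n",
"\n",
"\n",
"final = run_graph(\n",
"    nodes={'agent': toy_agent, 'tools': toy_tools},\n",
"    entry='agent',\n",
"    edges={'tools': 'agent'},\n",
"    conditional_edges={'agent': (route, {'continue': 'tools', 'exit': SKETCH_END})},\n",
"    state={'input': 'What is the latest stock price for AAPL?', 'intermediate_steps': []},\n",
")\n",
"print(final['agent_outcome'], final['intermediate_steps'])  # -> finish [('latest_stock_price', '185.36')]"
],
"metadata": {},
"execution_count": null,
"outputs": []
},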
{
"cell_type": "markdown",
"source": [
"## Step 5 - Query our agent for the latest stock price!"
],
"metadata": {
"id": "MZRaj7rW-NDk"
}
},
{
"cell_type": "code",
"source": [
"result = chain.invoke({\"input\": \"What is the latest stock price for AAPL?\", \"intermediate_steps\": []})\n",
"output = result['agent_outcome'].return_values[\"output\"]\n",
"print(output)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kBwr25yAWDVB",
"outputId": "a52b20a2-ab2b-4aca-88bb-a1d0a323332f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"extracted ticker: AAPL\n",
"price quote: 185.36\n",
"The latest stock price for AAPL (Apple Inc.) is $185.36.\n"
]
}
]
}
]
}