Last active
January 18, 2024 22:19
-
-
Save virattt/f501b6f6a6509fb18844bc307e65d4d2 to your computer and use it in GitHub Desktop.
financial-agent-parallel-tools.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyPy2WkwVbyL9rGrOrbixUql", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/virattt/f501b6f6a6509fb18844bc307e65d4d2/financial-agent-parallel-tools.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GcvSoNAmVbXH"
},
"outputs": [],
"source": [
"# `%pip` installs into the running kernel's environment (unlike `!pip`).\n",
"# LangChain is capped at the 0.1.x line this notebook targets: the cells below rely on\n",
"# APIs such as `langchain.schema` and `format_tool_to_openai_tool` that later\n",
"# releases deprecated or removed.\n",
"%pip install -qU \"langchain>=0.1,<0.2\" \"langchain_openai<0.2\" langchainhub\n",
"%pip install -qU polygon-api-client"
]
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"# Prompt only for keys that are not already set, so re-running this cell (or running\n",
"# in an environment that already exports them) does not ask again.\n",
"# A free Polygon API key is available from https://polygon.io/\n",
"for key_name in (\"OPENAI_API_KEY\", \"POLYGON_API_KEY\"):\n",
"    if not os.environ.get(key_name):\n",
"        os.environ[key_name] = getpass.getpass(f\"{key_name}: \")"
],
"metadata": {
"id": "aw9453mxY2GZ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"from langchain.tools import tool\n",
"from polygon import RESTClient\n",
"from langchain.schema import HumanMessage, SystemMessage\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"# You can get a free Polygon API key from here: https://polygon.io/\n",
"# Read it from the environment (or prompt for it) rather than passing the literal\n",
"# placeholder string \"POLYGON_API_KEY\", which is not a valid key.\n",
"polygon_api_key = os.environ.get(\"POLYGON_API_KEY\") or getpass.getpass(\"POLYGON_API_KEY: \")\n",
"client = RESTClient(api_key=polygon_api_key)\n",
"\n",
"# Created once and reused by every `extract_ticker` call instead of once per call.\n",
"ticker_llm = ChatOpenAI(model=\"gpt-4\")\n",
"\n",
"\n",
"@tool\n",
"def extract_ticker(query: str) -> str:\n",
"    \"\"\"\n",
"    Given a user query, extracts the ticker from the query and returns it.\n",
"    \"\"\"\n",
"    messages = [\n",
"        SystemMessage(\n",
"            content=(\n",
"                \"You extract the stock ticker from a given user query. \"\n",
"                \"Respond with only the ticker symbol and nothing else.\"\n",
"            )\n",
"        ),\n",
"        HumanMessage(content=query),\n",
"    ]\n",
"    # `.invoke` replaces the deprecated direct call style `llm(messages)`.\n",
"    result = ticker_llm.invoke(messages)\n",
"    # Trim whitespace/newlines the model may add around the symbol.\n",
"    return result.content.strip()\n",
"\n",
"\n",
"@tool\n",
"def latest_stock_price(ticker: str) -> float:\n",
"    \"\"\"\n",
"    Provides the latest stock price data for a given ticker.\n",
"    Uses the Polygon API to retrieve stock price data.\n",
"    \"\"\"\n",
"    quote = client.get_last_quote(ticker=ticker)\n",
"    # The ask price of the most recent quote stands in for the \"latest price\".\n",
"    # It is a number, not a str, hence the `-> float` annotation (which also\n",
"    # appears in the tool description shown to the model).\n",
"    return quote.ask_price"
],
"metadata": {
"id": "IZm-G5ZvbHS4"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from operator import itemgetter\n",
"from typing import Union\n",
"\n",
"from langchain.output_parsers import JsonOutputToolsParser\n",
"from langchain_community.tools.convert_to_openai import (\n",
"    format_tool_to_openai_tool,\n",
")\n",
"from langchain_core.runnables import (\n",
"    Runnable,\n",
"    RunnableLambda,\n",
"    RunnableMap,\n",
"    RunnablePassthrough,\n",
")\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"model = ChatOpenAI(model=\"gpt-3.5-turbo-1106\")\n",
"tools = [extract_ticker, latest_stock_price]\n",
"# Describe each tool to the model in OpenAI's tool-calling format; the model may\n",
"# then request several tool calls (e.g. one per ticker) in a single response.\n",
"model_with_tools = model.bind(tools=[format_tool_to_openai_tool(t) for t in tools])\n",
"tool_map = {t.name: t for t in tools}\n",
"\n",
"\n",
"def call_tool(tool_invocation: dict) -> Runnable:\n",
"    \"\"\"Function for dynamically constructing the end of the chain based on the model-selected tool.\n",
"\n",
"    `tool_invocation` is one parsed tool call: its \"type\" key names a tool in\n",
"    `tool_map` and its \"args\" key holds the arguments for that tool. The returned\n",
"    runnable keeps the invocation's keys and adds an \"output\" key with the tool result.\n",
"    \"\"\"\n",
"    # Named `selected_tool` so it does not shadow the `@tool` decorator imported earlier.\n",
"    selected_tool = tool_map[tool_invocation[\"type\"]]\n",
"    return RunnablePassthrough.assign(output=itemgetter(\"args\") | selected_tool)\n",
"\n",
"\n",
"# .map() allows us to apply a function to a list of inputs.\n",
"call_tool_list = RunnableLambda(call_tool).map()\n",
"chain = model_with_tools | JsonOutputToolsParser() | call_tool_list"
],
"metadata": {
"id": "nTW70qJEV6-A"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# One question naming three tickers. Each element of the returned list is one\n",
"# model-requested tool call (its \"type\" and \"args\") plus that call's \"output\".\n",
"question = \"What is the latest stock price for AAPL, MSFT, and AMZN?\"\n",
"chain.invoke(question)"
],
"metadata": {
"id": "kBwr25yAWDVB"
},
"execution_count": null,
"outputs": []
}
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment