@virattt
Last active March 23, 2024 15:30
agent_with_custom_tool.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"N3x8LxDubC3B"
],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/e2d69f7d5c95aeee611cde14bf507463/agent_with_custom_tool.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Custom Agent and Custom Tool with LangChain\n",
"This notebook contains code for creating a custom:\n",
"1. Tool that \"reads\" annual reports\n",
"2. Agent that uses tools to answer queries\n",
"\n",
"In our example, the PDF is an annual report for Meta Platforms (formerly known as Facebook).\n",
"\n",
"To maximize learning for both of us, I implemented the custom agent from scratch.\n",
"\n",
"I hope you find this code useful. Please follow me on https://twitter.com/virattt for more tutorials like this.\n",
"\n",
"Happy learning! :)"
],
"metadata": {
"id": "gmYJu6M1eRMA"
}
},
{
"cell_type": "markdown",
"source": [
"# Step 0. Install dependencies"
],
"metadata": {
"id": "N3x8LxDubC3B"
}
},
{
"cell_type": "code",
"source": [
"pip install langchain"
],
"metadata": {
"id": "zMINRKTEZ7Gd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"pip install openai"
],
"metadata": {
"id": "hA-MpsaOP8DC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"pip install chromadb"
],
"metadata": {
"id": "1AytsWBIQ6Zh"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"pip install tiktoken"
],
"metadata": {
"id": "D82QKqEpUD-B"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"pip install pypdf"
],
"metadata": {
"id": "QdAGqGU0N0N0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Step 1. PDF Document Ingestion"
],
"metadata": {
"id": "WP4kma9pWDts"
}
},
{
"cell_type": "code",
"source": [
"from langchain.document_loaders import PyPDFLoader\n",
"\n",
"# Load $META's annual report. This may take 1-2 minutes since the PDF is 171 pages\n",
"meta_annual_report_pdf = \"https://d18rn0p25nwr6d.cloudfront.net/CIK-0001326801/e574646c-c642-42d9-9229-3892b13aabfb.pdf\"\n",
"# Create your PDF loader\n",
"loader = PyPDFLoader(meta_annual_report_pdf)\n",
"# Load the PDF document\n",
"documents = loader.load() "
],
"metadata": {
"id": "vfh_g5uaVx55"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"# Chunk the annual_report\n",
"docs = text_splitter.split_documents(documents)"
],
"metadata": {
"id": "MaIuMZx7V0pj"
},
"execution_count": null,
"outputs": []
},
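{
"cell_type": "markdown",
"source": [
"The splitter above cuts the report into chunks of at most 1,000 characters with no overlap. The sketch below only illustrates what `chunk_size` and `chunk_overlap` mean. It is not how `RecursiveCharacterTextSplitter` works internally (the real splitter prefers to break on separators such as blank lines, newlines, and spaces), and `naive_chunks` is a hypothetical helper written for this example."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"def naive_chunks(text, chunk_size, chunk_overlap=0):\n",
"    # Hypothetical helper for illustration: fixed-width character windows\n",
"    if chunk_overlap >= chunk_size:\n",
"        raise ValueError('chunk_overlap must be smaller than chunk_size')\n",
"    step = chunk_size - chunk_overlap\n",
"    return [text[i:i + chunk_size] for i in range(0, len(text), step)]\n",
"\n",
"sample_text = 'abcdefghij'\n",
"# No overlap: each character appears in exactly one chunk\n",
"print(naive_chunks(sample_text, chunk_size=4))  # ['abcd', 'efgh', 'ij']\n",
"# Overlap of 2: neighboring chunks share two characters\n",
"print(naive_chunks(sample_text, chunk_size=4, chunk_overlap=2))  # ['abcd', 'cdef', 'efgh', 'ghij', 'ij']"
],
"metadata": {},
"execution_count": null,
"outputs": []
},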
{
"cell_type": "markdown",
"source": [
"# Step 2. Save the annual report\n",
"Using ChromaDB, save the annual report to a vector database. \n",
"\n",
"This will allow your custom Agent and Tool to later retrieve (use) the annual report for question-answering."
],
"metadata": {
"id": "6Mfq-hk5Vni7"
}
},
{
"cell_type": "code",
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.vectorstores import Chroma\n",
"\n",
"OPENAI_API_KEY = \"YOUR_OPENAI_API_KEY\"\n",
"\n",
"embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)\n",
"vectorstore = Chroma.from_documents(docs, embeddings)"
],
"metadata": {
"id": "SLfnioDuPiUY"
},
"execution_count": null,
"outputs": []
},
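{
"cell_type": "markdown",
"source": [
"`Chroma.from_documents` embeds each chunk and stores the vectors so that related chunks can be looked up later. The sketch below illustrates the general idea of a nearest-neighbor lookup with cosine similarity. The three-dimensional vectors and the `top_k` helper are made up for illustration; they do not reflect Chroma's internals or real OpenAI embeddings."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import math\n",
"\n",
"def cosine_similarity(a, b):\n",
"    # Cosine of the angle between two vectors: 1.0 means same direction\n",
"    dot = sum(x * y for x, y in zip(a, b))\n",
"    norm_a = math.sqrt(sum(x * x for x in a))\n",
"    norm_b = math.sqrt(sum(y * y for y in b))\n",
"    return dot / (norm_a * norm_b)\n",
"\n",
"def top_k(query_vector, store, k=2):\n",
"    # Hypothetical helper: rank stored chunks by similarity to the query, highest first\n",
"    ranked = sorted(store, key=lambda item: cosine_similarity(query_vector, item[1]), reverse=True)\n",
"    return [text for text, _ in ranked[:k]]\n",
"\n",
"# Toy vectors standing in for real embeddings (assumed values, for illustration only)\n",
"toy_store = [\n",
"    ('net income', [0.9, 0.1, 0.0]),\n",
"    ('employee headcount', [0.1, 0.9, 0.1]),\n",
"    ('revenue', [0.8, 0.2, 0.1]),\n",
"]\n",
"print(top_k([1.0, 0.0, 0.0], toy_store))  # ['net income', 'revenue']"
],
"metadata": {},
"execution_count": null,
"outputs": []
},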
{
"cell_type": "markdown",
"source": [
"# Step 3. Create your custom Chain\n",
"This Chain will be used by your custom Tool (defined next) to answer questions\n",
"about the annual report that you previously loaded."
],
"metadata": {
"id": "tBPm7mj2VUKd"
}
},
{
"cell_type": "code",
"source": [
"from langchain.chains.base import Chain\n",
"from typing import Dict, List\n",
"\n",
"class AnnualReportChain(Chain):\n",
"    # The underlying question-answering chain (created with load_qa_chain)\n",
"    chain: Chain\n",
"\n",
"    @property\n",
"    def input_keys(self) -> List[str]:\n",
"        # The wrapper takes a single input, so a Tool can call .run(query)\n",
"        return ['question']\n",
"\n",
"    @property\n",
"    def output_keys(self) -> List[str]:\n",
"        return ['output']\n",
"\n",
"    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:\n",
"        # Query the vector database for the chunks most relevant to the question\n",
"        query = inputs['question']\n",
"        docs = vectorstore.similarity_search(query)\n",
"        # Use the wrapped chain (not the global `chain`) to answer from those chunks\n",
"        output = self.chain.run(input_documents=docs, question=query)\n",
"        return {'output': output}"
],
"metadata": {
"id": "7RkzoUBXVVJm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Step 4. Create your custom Tool\n",
"This tool will use the Chain that you just created, under the hood."
],
"metadata": {
"id": "gss7eALIWmPJ"
}
},
{
"cell_type": "code",
"source": [
"from langchain.agents import Tool\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"from langchain.llms import OpenAI\n",
"\n",
"# Initialize your custom Chain\n",
"llm = OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY, model_name=\"gpt-3.5-turbo\")\n",
"chain = load_qa_chain(llm)\n",
"annual_report_chain = AnnualReportChain(chain=chain)\n",
"\n",
"# Initialize your custom Tool\n",
"annual_report_tool = Tool(\n",
"    name=\"Annual Report\",\n",
"    func=annual_report_chain.run,\n",
"    description=\"\"\"\n",
" useful for when you need to answer questions about a company's income statement,\n",
" cash flow statement, or balance sheet. This tool can help you extract data points like\n",
" net income, revenue, free cash flow, and total debt, among other financial line items.\n",
" \"\"\"\n",
")"
],
"metadata": {
"id": "yh4iYLeJWEpr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Step 5. Create your custom Agent\n",
"This Agent uses your custom tool(s) to get things done.\n",
"\n",
"For our example, the Agent is given 1 tool (`annual_report_tool` from above) and answers questions about annual reports!\n",
"\n",
"The code here is heavily borrowed from [this wonderful GitHub repository](https://github.com/mpaepper/llm_agents), created by [Marc Päpper](https://twitter.com/mpaepper).\n",
"\n",
"Marc wrote an [excellent blog post](https://www.paepper.com/blog/posts/intelligent-agents-guided-by-llms/) that explains how Agents work.\n"
],
"metadata": {
"id": "F0PC0OPbVP7x"
}
},
{
"cell_type": "code",
"source": [
"import re\n",
"\n",
"from pydantic import BaseModel\n",
"from typing import Dict, List, Tuple\n",
"\n",
"class Agent(BaseModel):\n",
"    # The large language model that the Agent will use to decide the action to take\n",
"    llm: BaseModel\n",
"    # The prompt that the language model will use and append previous responses to\n",
"    prompt: str\n",
"    # The list of tools that the Agent can use\n",
"    tools: List[Tool]\n",
"    # Adjust this so that the Agent does not loop infinitely\n",
"    max_loops: int = 5\n",
"    # Stop sequences that cut generation off before the LLM invents its own Observation\n",
"    stop_pattern: List[str]\n",
"\n",
"    @property\n",
"    def tool_by_names(self) -> Dict[str, Tool]:\n",
"        return {tool.name: tool for tool in self.tools}\n",
"\n",
"    def run(self, question: str):\n",
"        name_to_tool_map = self.tool_by_names\n",
"        previous_responses = []\n",
"        num_loops = 0\n",
"        while num_loops < self.max_loops:\n",
"            num_loops += 1\n",
"            # Use the Agent's own prompt (not the global `prompt`) and fill in the history so far\n",
"            curr_prompt = self.prompt.format(previous_responses=('\\n'.join(previous_responses)))\n",
"            output, tool, tool_input = self._get_next_action(curr_prompt)\n",
"            if tool == 'Final Answer':\n",
"                return tool_input\n",
"            if tool not in name_to_tool_map:\n",
"                raise ValueError(f\"LLM asked for an unknown tool: `{tool}`\")\n",
"            tool_result = name_to_tool_map[tool].run(tool_input)\n",
"            output += f\"\\n{OBSERVATION_TOKEN} {tool_result}\\n{THOUGHT_TOKEN}\"\n",
"            print(output)\n",
"            previous_responses.append(output)\n",
"        # No final answer was reached within max_loops iterations\n",
"        return None\n",
"\n",
"    def _get_next_action(self, prompt: str) -> Tuple[str, str, str]:\n",
"        # Use the LLM to generate the Agent's next action\n",
"        result = self.llm.generate([prompt], stop=self.stop_pattern)\n",
"\n",
"        # List of the things generated. This is List[List[]] because each input could have multiple generations.\n",
"        generations = result.generations\n",
"\n",
"        # Grab the first text generation, as this will likely be the best result\n",
"        output = generations[0][0].text\n",
"\n",
"        # Parse the result\n",
"        tool, tool_input = self._get_tool_and_input(output)\n",
"        return output, tool, tool_input\n",
"\n",
"    def _get_tool_and_input(self, generated: str) -> Tuple[str, str]:\n",
"        # A final answer ends the loop; everything after the token is the answer\n",
"        if FINAL_ANSWER_TOKEN in generated:\n",
"            return \"Final Answer\", generated.split(FINAL_ANSWER_TOKEN)[-1].strip()\n",
"\n",
"        # Otherwise, extract the tool name and the tool input from the generated text\n",
"        regex = r\"Action: [\\[]?(.*?)[\\]]?[\\n]*Action Input:[\\s]*(.*)\"\n",
"        match = re.search(regex, generated, re.DOTALL)\n",
"        if not match:\n",
"            raise ValueError(f\"Output of LLM is not parsable for next tool use: `{generated}`\")\n",
"        tool = match.group(1).strip()\n",
"        tool_input = match.group(2)\n",
"        return tool, tool_input.strip(\" \").strip('\"')"
],
"metadata": {
"id": "UBqb6bZkagyx"
},
"execution_count": null,
"outputs": []
},
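{
"cell_type": "markdown",
"source": [
"The Agent relies on a regular expression to pull the tool name and the tool input out of the LLM's text. The cell below is a minimal standalone sketch of that parsing step, so you can try it on sample strings without calling an LLM. `parse_action`, `FINAL_MARKER`, `ACTION_REGEX`, and the sample strings are hypothetical names and inputs chosen for this example."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import re\n",
"\n",
"# Hypothetical standalone names, so the notebook's own variables are not overwritten\n",
"FINAL_MARKER = 'Final Answer:'\n",
"ACTION_REGEX = r'Action: [\\[]?(.*?)[\\]]?[\\n]*Action Input:[\\s]*(.*)'\n",
"\n",
"def parse_action(generated):\n",
"    # A final answer takes priority over any Action lines\n",
"    if FINAL_MARKER in generated:\n",
"        return 'Final Answer', generated.split(FINAL_MARKER)[-1].strip()\n",
"    match = re.search(ACTION_REGEX, generated, re.DOTALL)\n",
"    if not match:\n",
"        raise ValueError('Could not parse a tool from: ' + generated)\n",
"    return match.group(1).strip(), match.group(2).strip(' ').strip('\"')\n",
"\n",
"sample_output = 'I should look up net income.\\nAction: Annual Report\\nAction Input: \"Meta net income in 2022\"'\n",
"print(parse_action(sample_output))  # ('Annual Report', 'Meta net income in 2022')\n",
"print(parse_action('I now know the final answer\\nFinal Answer: done'))  # ('Final Answer', 'done')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},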
{
"cell_type": "markdown",
"source": [
"# Step 6. Create your Prompt template\n",
"This prompt will be fed into the Agent's large language model (LLM). \n",
"\n",
"As it \"reasons\" about your query, the Agent appends each previous response to this prompt so that the LLM keeps the context of its overall \"chain of thought\"."
],
"metadata": {
"id": "_um0nrAlYcw_"
}
},
{
"cell_type": "code",
"source": [
"FINAL_ANSWER_TOKEN = \"Final Answer:\"\n",
"OBSERVATION_TOKEN = \"Observation:\"\n",
"THOUGHT_TOKEN = \"Thought:\"\n",
"PROMPT_TEMPLATE = \"\"\"Answer the question as best as you can using the following tools: \n",
"\n",
"{tool_description}\n",
"\n",
"Use the following format:\n",
"\n",
"Question: the input question you must answer\n",
"Thought: comment on what you want to do next\n",
"Action: the action to take, exactly one element of [{tool_names}]\n",
"Action Input: the input to the action\n",
"Observation: the result of the action\n",
"... (this Thought/Action/Action Input/Observation repeats N times, use it until you are sure of the answer)\n",
"Thought: I now know the final answer\n",
"Final Answer: your final answer to the original input question\n",
"\n",
"Begin!\n",
"\n",
"Question: {question}\n",
"Thought: {previous_responses}\n",
"\"\"\""
],
"metadata": {
"id": "LXVPO69KZ46c"
},
"execution_count": null,
"outputs": []
},
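{
"cell_type": "markdown",
"source": [
"The template is formatted in two stages: once up front to fill in the tools and the question, and again on every loop to fill in `{previous_responses}`. The cell below is a minimal sketch of that two-stage formatting with a shortened, hypothetical template (`mini_template`) and a made-up question."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"mini_template = 'Question: {question}\\nThought: {previous_responses}'\n",
"\n",
"# Stage 1: fill in the question, but put the placeholder back so it survives for later\n",
"stage_one = mini_template.format(question='What is 2 + 2?', previous_responses='{previous_responses}')\n",
"print(stage_one)\n",
"\n",
"# Stage 2: on each loop, fill in everything that has been generated so far\n",
"history = ['I should add the numbers.']\n",
"print(stage_one.format(previous_responses='\\n'.join(history)))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"One caveat with this approach: if the question itself contains curly braces, the second `format` call will treat them as placeholders and fail."
],
"metadata": {}
},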
{
"cell_type": "markdown",
"source": [
"# Step 7. Run your custom Agent\n",
"You can update the `question` variable to ask your Agent to answer questions about the PDF that you previously loaded!\n"
],
"metadata": {
"id": "f-M4SKq7ZFC2"
}
},
{
"cell_type": "code",
"source": [
"# The tool(s) that your Agent will use\n",
"tools = [annual_report_tool]\n",
"\n",
"# The question that you will ask your Agent\n",
"question = \"What was Meta's net income in 2022? What was net income the year before that?\"\n",
"\n",
"# The prompt that your Agent will use and update as it is \"reasoning\"\n",
"prompt = PROMPT_TEMPLATE.format(\n",
"    tool_description=\"\\n\".join([f\"{tool.name}: {tool.description}\" for tool in tools]),\n",
"    tool_names=\", \".join([tool.name for tool in tools]),\n",
"    question=question,\n",
"    previous_responses='{previous_responses}',\n",
")\n",
"\n",
"# The LLM that your Agent will use\n",
"llm = OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY, model_name=\"gpt-3.5-turbo\")\n",
"\n",
"# Initialize your Agent\n",
"agent = Agent(\n",
"    llm=llm,\n",
"    tools=tools,\n",
"    prompt=prompt,\n",
"    stop_pattern=[f'\\n{OBSERVATION_TOKEN}', f'\\n\\t{OBSERVATION_TOKEN}'],\n",
")\n",
"\n",
"# Run the Agent!\n",
"result = agent.run(question)\n",
"\n",
"print(result)"
],
"metadata": {
"id": "sGj1TjyEapA9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"You can verify whether your Agent's answer about Meta's net income is correct [here](https://www.deepvalue.ai/explore/stocks/META) 😃"
],
"metadata": {
"id": "pWWYSKcNd7Ml"
}
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "SD2NNc7ZawYD"
},
"execution_count": null,
"outputs": []
}
]
}