Created
November 13, 2023 19:21
-
-
Save ohmeow/fdc6a2bac0e31f433bd4657b662d83e5 to your computer and use it in GitHub Desktop.
langchain_tracing_example.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"private_outputs": true, | |
"provenance": [], | |
"authorship_tag": "ABX9TyPV3goWAds22VAUYd1InUY0", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ohmeow/fdc6a2bac0e31f433bd4657b662d83e5/langchain_tracing_example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"! pip install cohere langchain openai tiktoken -qqq" | |
], | |
"metadata": { | |
"id": "Ub3pdAxwCvNf" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Tracing with LangChain" | |
], | |
"metadata": { | |
"id": "tVhnNcIYGc5V" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import os\n", | |
"from uuid import UUID\n", | |
"from getpass import getpass\n", | |
"\n", | |
"from langchain.callbacks.base import BaseCallbackHandler\n", | |
"from langchain.callbacks.tracers import ConsoleCallbackHandler\n", | |
"from langchain.chains.openai_functions import create_structured_output_chain\n", | |
"from langchain.chat_models import ChatOpenAI\n", | |
"from langchain.chat_models.base import BaseChatModel\n", | |
"from langchain.load.load import load as lc_load_serialized_d\n", | |
"from langchain.prompts import (\n", | |
" ChatPromptTemplate,\n", | |
" HumanMessagePromptTemplate,\n", | |
" PromptTemplate,\n", | |
" SystemMessagePromptTemplate,\n", | |
" AIMessagePromptTemplate,\n", | |
")\n", | |
"from langchain.prompts.chat import BaseStringMessagePromptTemplate\n", | |
"from langchain.schema import BaseMessage, AIMessage, HumanMessage, SystemMessage\n", | |
"from langchain.pydantic_v1 import BaseModel, Field" | |
], | |
"metadata": { | |
"id": "whueMlymCuKh" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"os.environ['OPENAI_API_KEY'] = getpass(\"Enter your OpenAI API Key: \")" | |
], | |
"metadata": { | |
"id": "tR1z5MDeGpgu" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Code" | |
], | |
"metadata": { | |
"id": "-M0OOJpWGqns" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "4M6HrjIiCdTP" | |
}, | |
"outputs": [], | |
"source": [ | |
"# |export\n", | |
"def get_message_info(\n", | |
" message_or_prompt: BaseMessage | BaseStringMessagePromptTemplate | PromptTemplate,\n", | |
" inputs: dict | None = None,\n", | |
"):\n", | |
" if isinstance(message_or_prompt, BaseMessage):\n", | |
" role = message_or_prompt.type\n", | |
" content = message_or_prompt.content\n", | |
" else:\n", | |
" if isinstance(message_or_prompt, SystemMessagePromptTemplate):\n", | |
" role = SystemMessage(content=\"_\").type\n", | |
" elif isinstance(message_or_prompt, HumanMessagePromptTemplate):\n", | |
" role = HumanMessage(content=\"_\").type\n", | |
" elif isinstance(message_or_prompt, AIMessagePromptTemplate):\n", | |
" role = AIMessage(content=\"_\").type\n", | |
" else:\n", | |
" role = \"prompt\"\n", | |
"\n", | |
" if inputs is None:\n", | |
" inputs = {}\n", | |
" content = (\n", | |
" message_or_prompt.prompt.template.format(**inputs)\n", | |
" if isinstance(message_or_prompt, BaseStringMessagePromptTemplate)\n", | |
" else message_or_prompt.template.format(**inputs)\n", | |
" )\n", | |
"\n", | |
" return {role: content}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#|export\n", | |
"class ChainTracer(BaseCallbackHandler):\n", | |
" def __init__(\n", | |
" self,\n", | |
" function_name: str | None = None,\n", | |
" ) -> None:\n", | |
" super().__init__()\n", | |
"\n", | |
" self.function_name = function_name\n", | |
" self.trace = {\n", | |
" \"model_name\": None,\n", | |
" \"messages\": [],\n", | |
" \"outputs\": [],\n", | |
" }\n", | |
"\n", | |
" def on_llm_end(self, response, **kwargs):\n", | |
" if hasattr(response, \"generations\"):\n", | |
" for generation in response.generations:\n", | |
" for item in generation:\n", | |
" if item.generation_info.get(\"finish_reason\") != \"stop\":\n", | |
" continue\n", | |
"\n", | |
" if hasattr(item, \"text\") and item.text != \"\":\n", | |
" self.trace[\"outputs\"].append(\n", | |
" {\"type\": \"text\", \"response\": item.text}\n", | |
" )\n", | |
" elif \"function_call\" in item.message.additional_kwargs:\n", | |
" self.trace[\"outputs\"].append(\n", | |
" {\n", | |
" \"function_call\": {\n", | |
" \"name\": self.function_name\n", | |
" or item.message.additional_kwargs[\"function_call\"][\n", | |
" \"name\"\n", | |
" ],\n", | |
" \"arguments\": item.message.additional_kwargs[\n", | |
" \"function_call\"\n", | |
" ][\"arguments\"].replace('{\"output\":', \"\")[:-1],\n", | |
" }\n", | |
" }\n", | |
" )\n", | |
"\n", | |
" if hasattr(response, \"llm_output\") and \"model_name\" in response.llm_output:\n", | |
" self.trace[\"model_name\"] = response.llm_output[\"model_name\"]\n", | |
"\n", | |
" def on_chain_start(\n", | |
" self,\n", | |
" serialized: dict,\n", | |
" inputs: dict,\n", | |
" *,\n", | |
" run_id: UUID,\n", | |
" parent_run_id: UUID | None = None,\n", | |
" tags: list[str] | None = None,\n", | |
" metadata: dict | None = None,\n", | |
" **kwargs\n", | |
" ):\n", | |
" prompt_template = lc_load_serialized_d(serialized[\"kwargs\"][\"prompt\"])\n", | |
" if isinstance(prompt_template, ChatPromptTemplate):\n", | |
" self.trace[\"messages\"].extend(\n", | |
" [get_message_info(m, inputs) for m in prompt_template.messages]\n", | |
" )\n", | |
" else:\n", | |
" self.trace[\"messages\"].extend(\n", | |
" [get_message_info(m, inputs) for m in [prompt_template]]\n", | |
" )" | |
], | |
"metadata": { | |
"id": "fRMawAfFCkHc" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Examples" | |
], | |
"metadata": { | |
"id": "kCB3o_LlGu-T" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### A simple `LLMChain`" | |
], | |
"metadata": { | |
"id": "mi18-VTCElGW" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from langchain.chains import LLMChain\n", | |
"\n", | |
"chain_tracer = ChainTracer()\n", | |
"cbs = [chain_tracer]#, ConsoleCallbackHandler()]\n", | |
"\n", | |
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0, )\n", | |
"chain = LLMChain(llm=llm, prompt=PromptTemplate.from_template(\"{input}\"))\n", | |
"rsp = chain.run({\"input\": \"What is 2+2?\"}, callbacks=cbs)\n", | |
"\n", | |
"print(rsp)\n", | |
"print(chain_tracer.trace)" | |
], | |
"metadata": { | |
"id": "Ot2uesceEEbP" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### A simple Function" | |
], | |
"metadata": { | |
"id": "4_kDDgsWEnug" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class Sentiment(BaseModel):\n", | |
" \"\"\"The sentiment found in a text\"\"\"\n", | |
" positivity:float = Field(..., description=\"How positive is this document on a scale of 0 (very negative) to 1 (very positive):\")\n", | |
" is_employee:bool = Field (..., description=\"Is this text from an employee or external client?\")\n", | |
" recommendations:list[str] = Field(..., description=\"A list of any specific recommendations made by the author.\")" | |
], | |
"metadata": { | |
"id": "LLx_2IZ-Eicr" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0)\n", | |
"chain_tracer = ChainTracer(function_name=\"get_sentiment\") # optional: overrides the function name recorded in the trace; if omitted, the name returned by the model is used\n", | |
"\n", | |
"prompt_msgs = [\n", | |
" HumanMessagePromptTemplate.from_template(\"{input}\"),\n", | |
"]\n", | |
"\n", | |
"prompt = ChatPromptTemplate(messages=prompt_msgs, callbacks=[chain_tracer])\n", | |
"chain = create_structured_output_chain(Sentiment, llm, prompt, verbose=True)" | |
], | |
"metadata": { | |
"id": "AKST79QVFtGF" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"rsp = chain.run(\"As a member of the staff I'm really unhappy with the lack of bonuses! I recommend we reinstate bonuses and buy every employee a new Tesla\", callbacks=[chain_tracer])\n", | |
"\n", | |
"print(rsp)\n", | |
"print(chain_tracer.trace)" | |
], | |
"metadata": { | |
"id": "luDW7saIF9ca" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"rsp = chain.run(\"My order took too long! I'll never use this company again. You need to get orders out faster and your customer services needs to be improved.\", callbacks=[chain_tracer])\n", | |
"\n", | |
"print(rsp)\n", | |
"print(chain_tracer.trace)" | |
], | |
"metadata": { | |
"id": "qzJvdun5GSCg" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment