@ohmeow · Created November 13, 2023 19:21

langchain_tracing_example.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"authorship_tag": "ABX9TyPV3goWAds22VAUYd1InUY0",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ohmeow/fdc6a2bac0e31f433bd4657b662d83e5/langchain_tracing_example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"! pip install cohere langchain openai tiktoken -qqq"
],
"metadata": {
"id": "Ub3pdAxwCvNf"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Tracing with LangChain"
],
"metadata": {
"id": "tVhnNcIYGc5V"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"from uuid import UUID\n",
"from getpass import getpass\n",
"\n",
"from langchain.callbacks.base import BaseCallbackHandler\n",
"from langchain.callbacks.tracers import ConsoleCallbackHandler\n",
"from langchain.chains.openai_functions import create_structured_output_chain\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.chat_models.base import BaseChatModel\n",
"from langchain.load.load import load as lc_load_serialized_d\n",
"from langchain.prompts import (\n",
" ChatPromptTemplate,\n",
" HumanMessagePromptTemplate,\n",
" PromptTemplate,\n",
" SystemMessagePromptTemplate,\n",
" AIMessagePromptTemplate,\n",
")\n",
"from langchain.prompts.chat import BaseStringMessagePromptTemplate\n",
"from langchain.schema import BaseMessage, AIMessage, HumanMessage, SystemMessage\n",
"from langchain.pydantic_v1 import BaseModel, Field"
],
"metadata": {
"id": "whueMlymCuKh"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"os.environ['OPENAI_API_KEY'] = getpass(\"Enter your OpenAI API Key: \")"
],
"metadata": {
"id": "tR1z5MDeGpgu"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Code"
],
"metadata": {
"id": "-M0OOJpWGqns"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4M6HrjIiCdTP"
},
"outputs": [],
"source": [
"# |export\n",
"def get_message_info(\n",
" message_or_prompt: BaseMessage | BaseStringMessagePromptTemplate | PromptTemplate,\n",
" inputs: dict | None = None,\n",
"):\n",
" if isinstance(message_or_prompt, BaseMessage):\n",
" role = message_or_prompt.type\n",
" content = message_or_prompt.content\n",
" else:\n",
" if isinstance(message_or_prompt, SystemMessagePromptTemplate):\n",
" role = SystemMessage(content=\"_\").type\n",
" elif isinstance(message_or_prompt, HumanMessagePromptTemplate):\n",
" role = HumanMessage(content=\"_\").type\n",
" elif isinstance(message_or_prompt, AIMessagePromptTemplate):\n",
" role = AIMessage(content=\"_\").type\n",
" else:\n",
" role = \"prompt\"\n",
"\n",
" if inputs is None:\n",
" inputs = {}\n",
" content = (\n",
" message_or_prompt.prompt.template.format(**inputs)\n",
" if isinstance(message_or_prompt, BaseStringMessagePromptTemplate)\n",
" else message_or_prompt.template.format(**inputs)\n",
" )\n",
"\n",
" return {role: content}"
]
},
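{
"cell_type": "markdown",
"source": [
"A quick sanity check of `get_message_info` (illustrative inputs, not part of the original gist): a concrete `BaseMessage` yields its role and content directly, while a message prompt template is formatted with the supplied inputs first."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative usage only -- the inputs here are made up for demonstration\n",
"print(get_message_info(HumanMessage(content=\"What is 2+2?\")))\n",
"# expected: {'human': 'What is 2+2?'}\n",
"\n",
"tmpl = HumanMessagePromptTemplate.from_template(\"Hello, {name}!\")\n",
"print(get_message_info(tmpl, {\"name\": \"Ada\"}))\n",
"# expected: {'human': 'Hello, Ada!'}"
],
"metadata": {},
"execution_count": null,
"outputs": []
},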
{
"cell_type": "code",
"source": [
"#|export\n",
"class ChainTracer(BaseCallbackHandler):\n",
" def __init__(\n",
" self,\n",
" function_name: str | None = None,\n",
" ) -> None:\n",
" super().__init__()\n",
"\n",
" self.function_name = function_name\n",
" self.trace = {\n",
" \"model_name\": None,\n",
" \"messages\": [],\n",
" \"outputs\": [],\n",
" }\n",
"\n",
" def on_llm_end(self, response, **kwargs):\n",
" if hasattr(response, \"generations\"):\n",
" for generation in response.generations:\n",
" for item in generation:\n",
" if item.generation_info.get(\"finish_reason\") != \"stop\":\n",
" continue\n",
"\n",
" if hasattr(item, \"text\") and item.text != \"\":\n",
" self.trace[\"outputs\"].append(\n",
" {\"type\": \"text\", \"response\": item.text}\n",
" )\n",
" elif \"function_call\" in item.message.additional_kwargs:\n",
" self.trace[\"outputs\"].append(\n",
" {\n",
" \"function_call\": {\n",
" \"name\": self.function_name\n",
" or item.message.additional_kwargs[\"function_call\"][\n",
" \"name\"\n",
" ],\n",
" \"arguments\": item.message.additional_kwargs[\n",
" \"function_call\"\n",
" ][\"arguments\"].replace('{\"output\":', \"\")[:-1],\n",
" }\n",
" }\n",
" )\n",
"\n",
" if hasattr(response, \"llm_output\") and \"model_name\" in response.llm_output:\n",
" self.trace[\"model_name\"] = response.llm_output[\"model_name\"]\n",
"\n",
" def on_chain_start(\n",
" self,\n",
" serialized: dict,\n",
" inputs: dict,\n",
" *,\n",
" run_id: UUID,\n",
" parent_run_id: UUID | None = None,\n",
" tags: list[str] | None = None,\n",
" metadata: dict | None = None,\n",
" **kwargs\n",
" ):\n",
" prompt_template = lc_load_serialized_d(serialized[\"kwargs\"][\"prompt\"])\n",
" if isinstance(prompt_template, ChatPromptTemplate):\n",
" self.trace[\"messages\"].extend(\n",
" [get_message_info(m, inputs) for m in prompt_template.messages]\n",
" )\n",
" else:\n",
" self.trace[\"messages\"].extend(\n",
" [get_message_info(m, inputs) for m in [prompt_template]]\n",
" )"
],
"metadata": {
"id": "fRMawAfFCkHc"
},
"execution_count": null,
"outputs": []
},
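{
"cell_type": "markdown",
"source": [
"For reference, this is roughly the shape of `ChainTracer.trace` after a run. The values below are illustrative, not captured from an actual run:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative only -- shows the structure the handler builds up\n",
"example_trace = {\n",
" \"model_name\": \"gpt-3.5-turbo-1106\",\n",
" \"messages\": [{\"human\": \"What is 2+2?\"}],\n",
" \"outputs\": [{\"type\": \"text\", \"response\": \"2 + 2 = 4\"}],\n",
"}"
],
"metadata": {},
"execution_count": null,
"outputs": []
},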
{
"cell_type": "markdown",
"source": [
"## Examples"
],
"metadata": {
"id": "kCB3o_LlGu-T"
}
},
{
"cell_type": "markdown",
"source": [
"### A simple `LLMChain`"
],
"metadata": {
"id": "mi18-VTCElGW"
}
},
{
"cell_type": "code",
"source": [
"from langchain.chains import LLMChain\n",
"\n",
"chain_tracer = ChainTracer()\n",
"cbs = [chain_tracer]#, ConsoleCallbackHandler()]\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0, )\n",
"chain = LLMChain(llm=llm, prompt=PromptTemplate.from_template(\"{input}\"))\n",
"rsp = chain.run({\"input\": \"What is 2+2?\"}, callbacks=cbs)\n",
"\n",
"print(rsp)\n",
"print(chain_tracer.trace)"
],
"metadata": {
"id": "Ot2uesceEEbP"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### A simple Function"
],
"metadata": {
"id": "4_kDDgsWEnug"
}
},
{
"cell_type": "code",
"source": [
"class Sentiment(BaseModel):\n",
" \"\"\"The sentiment found in a text\"\"\"\n",
" positivity:float = Field(..., description=\"How positive is this document on a scale of 0 (very negative) to 1 (ver positive):\")\n",
" is_employee:bool = Field (..., description=\"Is this text from an employee or external client?\")\n",
" recommendations:list[str] = Field(..., description=\"A list of any specific recommendations made by the author.\")"
],
"metadata": {
"id": "LLx_2IZ-Eicr"
},
"execution_count": null,
"outputs": []
},
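{
"cell_type": "markdown",
"source": [
"`create_structured_output_chain` converts this pydantic model into an OpenAI function schema under the hood. If you want to inspect something close to what the model receives, the pydantic v1 `.schema()` output is a reasonable proxy:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Peek at the JSON schema derived from the pydantic model\n",
"import json\n",
"\n",
"print(json.dumps(Sentiment.schema(), indent=2))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},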
{
"cell_type": "code",
"source": [
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0)\n",
"chain_tracer = ChainTracer(function_name=\"get_sentiment\") # optional: will use `function_name` in output if not specified\n",
"\n",
"prompt_msgs = [\n",
" HumanMessagePromptTemplate.from_template(\"{input}\"),\n",
"]\n",
"\n",
"prompt = ChatPromptTemplate(messages=prompt_msgs, callbacks=[chain_tracer])\n",
"chain = create_structured_output_chain(Sentiment, llm, prompt, verbose=True)"
],
"metadata": {
"id": "AKST79QVFtGF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"rsp = chain.run(\"As a member of the staff I'm really unhappy with the lack of bonuses! I recommend we reinstate bonuses and buy every employee a new Tesla\", callbacks=[tracer])\n",
"\n",
"print(rsp)\n",
"print(chain_tracer.trace)"
],
"metadata": {
"id": "luDW7saIF9ca"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"rsp = chain.run(\"My order took too long! I'll never use this company again. You need to get orders out faster and your customer services needs to be improved.\", callbacks=[chain_tracer])\n",
"\n",
"print(rsp)\n",
"print(chain_tracer.trace)"
],
"metadata": {
"id": "qzJvdun5GSCg"
},
"execution_count": null,
"outputs": []
}
]
}