Created
May 11, 2024 07:44
-
-
Save Sanjayy-ux/ef47f6efac3874394b9e62e497932a66 to your computer and use it in GitHub Desktop.
SalesGPT
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
from typing import Any, Callable, Dict, List, Union | |
from langchain.agents import ( | |
AgentExecutor, | |
LLMSingleActionAgent, | |
create_openai_tools_agent, | |
) | |
from langchain.chains import LLMChain, RetrievalQA | |
from langchain.chains.base import Chain | |
from langchain_community.chat_models import ChatLiteLLM | |
from langchain_core.agents import ( | |
_convert_agent_action_to_messages, | |
_convert_agent_observation_to_messages, | |
) | |
from langchain_core.language_models.llms import create_base_retry_decorator | |
from litellm import acompletion | |
from pydantic import Field | |
from salesgpt.chains import SalesConversationChain, StageAnalyzerChain | |
from salesgpt.custom_invoke import CustomAgentExecutor | |
from salesgpt.logger import time_logger | |
from salesgpt.parsers import SalesConvoOutputParser | |
from salesgpt.prompts import SALES_AGENT_TOOLS_PROMPT | |
from salesgpt.stages import CONVERSATION_STAGES | |
from salesgpt.templates import CustomPromptTemplateForTools | |
from salesgpt.tools import get_tools, setup_knowledge_base | |
def _create_retry_decorator(llm: Any) -> Callable[[Any], Any]:
    """
    Create a retry decorator for transient OpenAI-compatible API errors.

    The returned decorator re-invokes the wrapped call when it raises one of the
    OpenAI SDK exception types listed below. The number of attempts is taken from
    ``llm.max_retries``.

    Args:
        llm (Any): An object exposing a ``max_retries`` attribute.

    Returns:
        Callable[[Any], Any]: A retry decorator built by ``create_base_retry_decorator``.
    """
    import openai

    errors = [
        # ``openai.Timeout`` in the v1 SDK is the timeout *configuration* type,
        # not an exception, so it never matched anything; APITimeoutError is
        # the exception actually raised when a request times out.
        openai.APITimeoutError,
        openai.APIError,
        openai.APIConnectionError,
        openai.RateLimitError,
        openai.APIStatusError,
    ]
    return create_base_retry_decorator(error_types=errors, max_retries=llm.max_retries)
class SalesGPT(Chain):
    """Controller model for the Sales Agent.

    Keeps the conversation state (``conversation_history``, the current stage id
    and stage description) and drives two chains: ``stage_analyzer_chain`` picks
    the next conversation stage, and either ``sales_agent_executor`` (when
    ``use_tools`` is True) or ``sales_conversation_utterance_chain`` produces the
    salesperson's next utterance.
    """

    conversation_history: List[str] = []
    conversation_stage_id: str = "1"
    current_conversation_stage: str = CONVERSATION_STAGES.get("1")
    stage_analyzer_chain: StageAnalyzerChain = Field(...)
    sales_agent_executor: Union[CustomAgentExecutor, None] = Field(...)
    knowledge_base: Union[RetrievalQA, None] = Field(...)
    sales_conversation_utterance_chain: SalesConversationChain = Field(...)
    conversation_stage_dict: Dict = CONVERSATION_STAGES
    model_name: str = "gpt-3.5-turbo-0613"  # TODO - make this an env variable
    use_tools: bool = False
    salesperson_name: str = "Ted Lasso"
    salesperson_role: str = "Business Development Representative"
    company_name: str = "Sleep Haven"
    company_business: str = "Sleep Haven is a premium mattress company that provides customers with the most comfortable and supportive sleeping experience possible. We offer a range of high-quality mattresses, pillows, and bedding accessories that are designed to meet the unique needs of our customers."
    company_values: str = "Our mission at Sleep Haven is to help people achieve a better night's sleep by providing them with the best possible sleep solutions. We believe that quality sleep is essential to overall health and well-being, and we are committed to helping our customers achieve optimal sleep by offering exceptional products and customer service."
    conversation_purpose: str = "find out whether they are looking to achieve better sleep via buying a premier mattress."
    conversation_type: str = "call"

    def retrieve_conversation_stage(self, key):
        """
        Look up the description of a conversation stage.

        The key is tried as given first. If it is a string, it is then tried again
        with surrounding whitespace stripped, because raw LLM output such as "2\\n"
        would otherwise miss the lookup. Unknown keys fall back to the stage "1"
        description. If ``conversation_stage_dict`` has no stage "1", the fallback
        is the literal "1".

        Args:
            key: The stage key to look up (for example "1", or the stage analyzer's output).

        Returns:
            str: The matching conversation stage description.
        """
        stages = self.conversation_stage_dict
        default_stage = stages.get("1", "1")
        if key in stages:
            return stages[key]
        if isinstance(key, str) and key.strip() in stages:
            return stages[key.strip()]
        return default_stage

    @property
    def input_keys(self) -> List[str]:
        """
        Input keys required by this chain.

        Returns:
            List[str]: Always an empty list; all inputs come from the instance state.
        """
        return []

    @property
    def output_keys(self) -> List[str]:
        """
        Output keys produced by this chain.

        Returns:
            List[str]: Always an empty list.
        """
        return []

    @time_logger
    def seed_agent(self):
        """
        Reset the conversation.

        Sets the current stage to the description of stage "1" and clears the
        conversation history.

        Returns:
            None
        """
        self.current_conversation_stage = self.retrieve_conversation_stage("1")
        self.conversation_history = []

    def _stage_analyzer_inputs(self) -> Dict[str, Any]:
        """Build the input mapping passed to the stage analyzer chain."""
        return {
            "conversation_history": "\n".join(self.conversation_history).rstrip("\n"),
            "conversation_stage_id": self.conversation_stage_id,
            # Use the instance's stage dictionary so a custom one is what the
            # analyzer sees (and what retrieve_conversation_stage looks up).
            "conversation_stages": "\n".join(
                str(key) + ": " + str(value)
                for key, value in self.conversation_stage_dict.items()
            ),
        }

    def _print_stage_analysis_header(self) -> None:
        """Print the current stage id and history before stage analysis."""
        print(f"Conversation Stage ID before analysis: {self.conversation_stage_id}")
        print("Conversation history:")
        print(self.conversation_history)

    def _apply_stage_analyzer_output(self, stage_analyzer_output: Dict[str, Any]) -> None:
        """Store the stage id returned by the analyzer and resolve its description."""
        print("Stage analyzer output")
        print(stage_analyzer_output)
        stage_id = stage_analyzer_output.get("text")
        if isinstance(stage_id, str):
            # LLM output commonly carries trailing newlines or spaces.
            stage_id = stage_id.strip()
        self.conversation_stage_id = stage_id
        self.current_conversation_stage = self.retrieve_conversation_stage(stage_id)
        print(f"Conversation Stage: {self.current_conversation_stage}")

    @time_logger
    def determine_conversation_stage(self):
        """
        Ask the stage analyzer chain which stage the conversation is in now.

        The analyzer receives three inputs:

        - the newline-joined history;
        - the current stage id;
        - a "key: description" listing of ``conversation_stage_dict``.

        The returned ``text`` becomes the new ``conversation_stage_id``, and its
        description becomes ``current_conversation_stage``. Progress is printed
        to stdout.

        Returns:
            None
        """
        self._print_stage_analysis_header()
        stage_analyzer_output = self.stage_analyzer_chain.invoke(
            input=self._stage_analyzer_inputs(),
            return_only_outputs=False,
        )
        self._apply_stage_analyzer_output(stage_analyzer_output)

    @time_logger
    async def adetermine_conversation_stage(self):
        """
        Async version of ``determine_conversation_stage``.

        It uses ``stage_analyzer_chain.ainvoke`` and performs the same state
        updates and printing.

        Returns:
            None
        """
        self._print_stage_analysis_header()
        stage_analyzer_output = await self.stage_analyzer_chain.ainvoke(
            input=self._stage_analyzer_inputs(),
            return_only_outputs=False,
        )
        self._apply_stage_analyzer_output(stage_analyzer_output)

    def human_step(self, human_input):
        """
        Append a user turn to the conversation history.

        The input is stored as ``"User: <human_input> <END_OF_TURN>"``.

        Args:
            human_input (str): The text typed or spoken by the user.

        Returns:
            None
        """
        human_input = "User: " + human_input + " <END_OF_TURN>"
        self.conversation_history.append(human_input)

    @time_logger
    def step(self, stream: bool = False):
        """
        Execute one agent turn.

        Args:
            stream (bool, optional): When True, return a streaming response instead
                of running the full turn. Defaults to False.

        Returns:
            The AI message dict from ``_call`` when ``stream`` is False, otherwise the
            streaming response returned by ``_streaming_generator``.
        """
        if stream:
            return self._streaming_generator()
        return self._call(inputs={})

    @time_logger
    async def astep(self, stream: bool = False):
        """
        Execute one agent turn asynchronously.

        Args:
            stream (bool, optional): When True, return the async streaming response
                instead of running the full turn. Defaults to False.

        Returns:
            The AI message dict from ``acall`` when ``stream`` is False, otherwise the
            result of ``_astreaming_generator``.
        """
        if stream:
            return await self._astreaming_generator()
        return await self.acall(inputs={})

    def _conversation_inputs(self) -> Dict[str, Any]:
        """Build the prompt inputs that describe the current conversation state."""
        return {
            "input": "",
            "conversation_stage": self.current_conversation_stage,
            "conversation_history": "\n".join(self.conversation_history),
            "salesperson_name": self.salesperson_name,
            "salesperson_role": self.salesperson_role,
            "company_name": self.company_name,
            "company_business": self.company_business,
            "company_values": self.company_values,
            "conversation_purpose": self.conversation_purpose,
            "conversation_type": self.conversation_type,
        }

    def _record_agent_output(self, ai_message: Dict[str, Any], output: str) -> None:
        """Prefix the agent name, ensure an end-of-turn marker, and store the turn."""
        output = self.salesperson_name + ": " + output
        if "<END_OF_TURN>" not in output:
            output += " <END_OF_TURN>"
        self.conversation_history.append(output)
        if self.verbose:
            tool_status = "USE TOOLS INVOKE:" if self.use_tools else "WITHOUT TOOLS:"
            print(f"{tool_status}\n#\n#\n#\n#\n------------------")
            print(f"AI Message: {ai_message}")
            print()
            print(f"Output: {output.replace('<END_OF_TURN>', '')}")

    @time_logger
    async def acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute one step of the sales agent asynchronously.

        The ``inputs`` argument is ignored. The prompt inputs are rebuilt from the
        current conversation state. The utterance comes from ``sales_agent_executor``
        (its ``output`` key) when ``use_tools`` is set, and otherwise from
        ``sales_conversation_utterance_chain`` (its ``text`` key). The utterance is
        then appended to the conversation history.

        Parameters
        ----------
        inputs : Dict[str, Any]
            Ignored; kept for interface compatibility.

        Returns
        -------
        Dict[str, Any]
            The raw result returned by the executor or chain.
        """
        inputs = self._conversation_inputs()
        if self.use_tools:
            ai_message = await self.sales_agent_executor.ainvoke(inputs)
            output = ai_message["output"]
        else:
            ai_message = await self.sales_conversation_utterance_chain.ainvoke(
                inputs, return_intermediate_steps=True
            )
            output = ai_message["text"]
        self._record_agent_output(ai_message, output)
        return ai_message

    @time_logger
    def _prep_messages(self):
        """
        Build the message list for the streaming generators.

        The utterance chain's prompt is rendered with the current conversation state
        via ``prep_prompts``. The first rendered message is returned as a single
        system message.

        Returns:
            list: ``[{"role": "system", "content": <rendered prompt>}]``.
        """
        prompt_inputs = self._conversation_inputs()
        # The utterance prompt never received an "input" value here.
        prompt_inputs.pop("input")
        prompt = self.sales_conversation_utterance_chain.prep_prompts([prompt_inputs])
        inception_messages = prompt[0][0].to_messages()
        return [{"role": "system", "content": inception_messages[0].content}]

    @time_logger
    def _streaming_generator(self):
        """
        Start a streaming completion for the next agent utterance.

        This is useful when the caller wants to act on partial output before the
        full response is available, for example to run text to speech. It calls
        ``completion_with_retry`` on the utterance chain's LLM with ``stream=True``
        and stops generation at ``<END_OF_TURN>``.

        Returns
        -------
        generator
            The streaming response returned by ``completion_with_retry``.

        Examples
        --------
        ::

            for chunk in self._streaming_generator():
                print(chunk)

        See Also
        --------
        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb
        """
        messages = self._prep_messages()
        return self.sales_conversation_utterance_chain.llm.completion_with_retry(
            messages=messages,
            stop="<END_OF_TURN>",
            stream=True,
            model=self.model_name,
        )

    async def acompletion_with_retry(self, llm: Any, **kwargs: Any) -> Any:
        """
        Call litellm's ``acompletion`` with retries.

        The retry policy comes from ``_create_retry_decorator(llm)``.

        Parameters
        ----------
        llm : Any
            Object providing ``max_retries`` for the retry policy.
        **kwargs : Any
            Keyword arguments forwarded to ``acompletion``.

        Returns
        -------
        Any
            The result of ``acompletion``.

        Raises
        ------
        Exception
            The last error if every retry attempt fails.
        """
        retry_decorator = _create_retry_decorator(llm)

        @retry_decorator
        async def _completion_with_retry(**kwargs: Any) -> Any:
            return await acompletion(**kwargs)

        return await _completion_with_retry(**kwargs)

    async def _astreaming_generator(self):
        """
        Start an async streaming completion for the next agent utterance.

        This is the async counterpart of ``_streaming_generator``. It goes through
        ``acompletion_with_retry`` with ``stream=True`` and stops generation at
        ``<END_OF_TURN>``.

        Returns
        -------
        AsyncGenerator
            The streaming response returned by ``acompletion``.

        Examples
        --------
        ::

            async for chunk in await self._astreaming_generator():
                ...

        See Also
        --------
        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb
        """
        messages = self._prep_messages()
        return await self.acompletion_with_retry(
            llm=self.sales_conversation_utterance_chain.llm,
            messages=messages,
            stop="<END_OF_TURN>",
            stream=True,
            model=self.model_name,
        )

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute one step of the sales agent.

        This is the synchronous twin of ``acall``. The ``inputs`` argument is
        ignored, and the prompt inputs are rebuilt from the current conversation
        state.

        Parameters
        ----------
        inputs : Dict[str, Any]
            Ignored; kept for interface compatibility.

        Returns
        -------
        Dict[str, Any]
            The raw result returned by the executor or chain.
        """
        inputs = self._conversation_inputs()
        if self.use_tools:
            ai_message = self.sales_agent_executor.invoke(inputs)
            output = ai_message["output"]
        else:
            ai_message = self.sales_conversation_utterance_chain.invoke(
                inputs, return_intermediate_steps=True
            )
            output = ai_message["text"]
        self._record_agent_output(ai_message, output)
        return ai_message

    @classmethod
    @time_logger
    def from_llm(cls, llm: ChatLiteLLM, verbose: bool = False, **kwargs) -> "SalesGPT":
        """
        Build a SalesGPT controller from a ChatLiteLLM instance.

        The method performs these steps:

        - Create the stage analyzer chain.
        - Create the utterance chain, optionally with a custom prompt
          (``use_custom_prompt`` / ``custom_prompt`` kwargs).
        - When ``use_tools`` is true, build a tool-using agent executor from
          ``get_tools(product_catalog)``.

        ``knowledge_base`` is always set to None here.

        Parameters
        ----------
        llm : ChatLiteLLM
            The language model used by all chains.
        verbose : bool, optional
            Enables verbose output. Default is False.
        **kwargs : dict
            ``use_custom_prompt``, ``custom_prompt``, ``use_tools`` (bool, or the
            strings "true"/"false" in any case) and ``product_catalog`` are consumed
            here. The remaining kwargs are passed to the constructor.

        Returns
        -------
        SalesGPT
            The initialized controller.

        Raises
        ------
        ValueError
            If ``use_tools`` is not a bool or a "true"/"false" string.
        """
        stage_analyzer_chain = StageAnalyzerChain.from_llm(llm, verbose=verbose)

        # Build the utterance chain once. Only forward custom_prompt when it was
        # supplied, so the chain's own default applies otherwise.
        use_custom_prompt = kwargs.pop("use_custom_prompt", False)
        custom_prompt = kwargs.pop("custom_prompt", None)
        utterance_chain_kwargs: Dict[str, Any] = {"use_custom_prompt": use_custom_prompt}
        if custom_prompt is not None:
            utterance_chain_kwargs["custom_prompt"] = custom_prompt
        sales_conversation_utterance_chain = SalesConversationChain.from_llm(
            llm, verbose=verbose, **utterance_chain_kwargs
        )

        use_tools_value = kwargs.pop("use_tools", False)
        if isinstance(use_tools_value, str):
            if use_tools_value.lower() not in ["true", "false"]:
                raise ValueError("use_tools must be 'True', 'False', True, or False")
            use_tools = use_tools_value.lower() == "true"
        elif isinstance(use_tools_value, bool):
            use_tools = use_tools_value
        else:
            raise ValueError(
                "use_tools must be a boolean or a string ('True' or 'False')"
            )

        # Always consume product_catalog so it is never forwarded to the constructor.
        product_catalog = kwargs.pop("product_catalog", None)

        sales_agent_executor = None
        knowledge_base = None
        if use_tools:
            tools = get_tools(product_catalog)
            prompt = CustomPromptTemplateForTools(
                template=SALES_AGENT_TOOLS_PROMPT,
                tools_getter=lambda x: tools,
                input_variables=[
                    "input",
                    "intermediate_steps",
                    "salesperson_name",
                    "salesperson_role",
                    "company_name",
                    "company_business",
                    "company_values",
                    "conversation_purpose",
                    "conversation_type",
                    "conversation_history",
                ],
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
            tool_names = [tool.name for tool in tools]
            output_parser = SalesConvoOutputParser(
                ai_prefix=kwargs.get("salesperson_name", ""), verbose=verbose
            )
            sales_agent_with_tools = LLMSingleActionAgent(
                llm_chain=llm_chain,
                output_parser=output_parser,
                stop=["\nObservation:"],
                allowed_tools=tool_names,
            )
            sales_agent_executor = CustomAgentExecutor.from_agent_and_tools(
                agent=sales_agent_with_tools,
                tools=tools,
                verbose=verbose,
                return_intermediate_steps=True,
            )

        return cls(
            stage_analyzer_chain=stage_analyzer_chain,
            sales_conversation_utterance_chain=sales_conversation_utterance_chain,
            sales_agent_executor=sales_agent_executor,
            knowledge_base=knowledge_base,
            model_name=llm.model,
            verbose=verbose,
            use_tools=use_tools,
            **kwargs,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from langchain.chains import LLMChain | |
from langchain.prompts import PromptTemplate | |
from langchain_community.chat_models import ChatLiteLLM | |
from salesgpt.logger import time_logger | |
from salesgpt.prompts import ( | |
SALES_AGENT_INCEPTION_PROMPT, | |
STAGE_ANALYZER_INCEPTION_PROMPT, | |
) | |
class StageAnalyzerChain(LLMChain):
    """Chain that analyzes which stage the conversation should move into."""

    @classmethod
    @time_logger
    def from_llm(cls, llm: ChatLiteLLM, verbose: bool = True) -> LLMChain:
        """
        Build a stage analyzer chain around ``llm``.

        The prompt is ``STAGE_ANALYZER_INCEPTION_PROMPT``. Its inputs are
        ``conversation_history``, ``conversation_stage_id`` and
        ``conversation_stages``.

        Args:
            llm (ChatLiteLLM): The language model to run the prompt with.
            verbose (bool): Enables verbose chain output and prints the prompt.

        Returns:
            LLMChain: The configured stage analyzer chain.
        """
        prompt = PromptTemplate(
            template=STAGE_ANALYZER_INCEPTION_PROMPT,
            input_variables=[
                "conversation_history",
                "conversation_stage_id",
                "conversation_stages",
            ],
        )
        # Printing the whole prompt is debug output; honour the verbose flag.
        if verbose:
            print(f"STAGE ANALYZER PROMPT {prompt}")
        return cls(prompt=prompt, llm=llm, verbose=verbose)
class SalesConversationChain(LLMChain):
    """Chain to generate the next utterance for the conversation."""

    @classmethod
    @time_logger
    def from_llm(
        cls,
        llm: ChatLiteLLM,
        verbose: bool = True,
        use_custom_prompt: bool = False,
        custom_prompt: str = "You are an AI Sales agent, sell me this pencil",
    ) -> LLMChain:
        """
        Build the utterance chain around ``llm``.

        Args:
            llm (ChatLiteLLM): The language model to run the prompt with.
            verbose (bool): Enables verbose chain output.
            use_custom_prompt (bool): Use ``custom_prompt`` instead of
                ``SALES_AGENT_INCEPTION_PROMPT`` as the template.
            custom_prompt (str): Template used when ``use_custom_prompt`` is True.

        Returns:
            LLMChain: The configured utterance chain.

        Raises:
            ValueError: If ``use_custom_prompt`` is True but ``custom_prompt`` is None.
        """
        if use_custom_prompt:
            if custom_prompt is None:
                raise ValueError(
                    "custom_prompt must be provided when use_custom_prompt is True"
                )
            template = custom_prompt
        else:
            template = SALES_AGENT_INCEPTION_PROMPT

        # Both templates are declared with the same input variables.
        prompt = PromptTemplate(
            template=template,
            input_variables=[
                "salesperson_name",
                "salesperson_role",
                "company_name",
                "company_business",
                "company_values",
                "conversation_purpose",
                "conversation_type",
                "conversation_history",
            ],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Corrected import statements | |
import inspect | |
from typing import Any, Dict, Optional | |
# Corrected import path for RunnableConfig | |
from langchain.agents import AgentExecutor | |
from langchain.callbacks.manager import CallbackManager | |
from langchain.chains.base import Chain | |
from langchain_core.load.dump import dumpd | |
from langchain_core.outputs import RunInfo | |
from langchain_core.runnables import RunnableConfig, ensure_config | |
class CustomAgentExecutor(AgentExecutor):
    """AgentExecutor whose ``invoke`` also returns a coarse event log.

    The event log is stored under the ``intermediate_steps`` key of the returned
    outputs. It replaces whatever value ``prep_outputs`` put there.
    """

    def invoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Run the agent once and return its outputs plus an event log.

        Args:
            input: The chain inputs, prepared with ``prep_inputs``.
            config: Optional runnable config. Its callbacks, tags, metadata and
                run_name are used.
            **kwargs: ``include_run_info`` and ``return_only_outputs`` are read.

        Returns:
            The prepared outputs. ``intermediate_steps`` holds the event log, and
            ``run_info`` is present when requested.

        Raises:
            Re-raises any exception from ``_call`` after reporting it to the
            callbacks via ``on_chain_error``.
        """
        intermediate_steps = []

        config = ensure_config(config)
        callbacks = config.get("callbacks")
        tags = config.get("tags")
        metadata = config.get("metadata")
        run_name = config.get("run_name")
        include_run_info = kwargs.get("include_run_info", False)
        return_only_outputs = kwargs.get("return_only_outputs", False)

        inputs = self.prep_inputs(input)
        callback_manager = CallbackManager.configure(
            callbacks,
            self.callbacks,
            self.verbose,
            tags,
            self.tags,
            metadata,
            self.metadata,
        )
        # Older _call implementations do not accept run_manager.
        new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            inputs,
            name=run_name,
        )
        intermediate_steps.append(
            {"event": "Chain Started", "details": "Inputs prepared"}
        )

        try:
            outputs = (
                self._call(inputs, run_manager=run_manager)
                if new_arg_supported
                else self._call(inputs)
            )
        except BaseException as e:
            # Report the failure only. The previous `finally` also called
            # on_chain_end(outputs) with `outputs` unbound, which raised
            # UnboundLocalError and masked the original exception.
            run_manager.on_chain_error(e)
            raise

        intermediate_steps.append({"event": "Call Successful", "outputs": outputs})
        run_manager.on_chain_end(outputs)

        final_outputs: Dict[str, Any] = self.prep_outputs(
            inputs, outputs, return_only_outputs
        )
        if include_run_info:
            final_outputs["run_info"] = RunInfo(run_id=run_manager.run_id)
        final_outputs["intermediate_steps"] = intermediate_steps
        return final_outputs
if __name__ == "__main__":
    # NOTE(review): AgentExecutor is normally constructed with `agent` and `tools`;
    # building it with no arguments presumably fails validation, so this looks like
    # a leftover smoke check. Confirm before relying on it.
    agent = CustomAgentExecutor()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import time | |
from functools import wraps | |
logger = logging.getLogger(__name__)


class TimeFilter(logging.Filter):
    """Let through only records whose rendered message contains "Running"."""

    def filter(self, record):
        message = record.getMessage()
        return "Running" in message


logger.addFilter(TimeFilter())

# Log both to the console and to a file in the current working directory.
log_filename = "output.log"
stream_handler = logging.StreamHandler()
file_handler = logging.FileHandler(filename=log_filename)
handlers = [stream_handler, file_handler]

# Configure the root logger at import time with both handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(name)s %(asctime)s - %(levelname)s - %(message)s",
    handlers=handlers,
)
def time_logger(func):
    """
    Decorator that logs how long the decorated callable takes to run.

    Both plain functions and ``async def`` coroutine functions are supported. For
    coroutine functions the timer covers the awaited execution, not just the
    creation of the coroutine object, so async methods report their real duration.

    Elapsed time is measured with ``time.perf_counter``. It is logged at INFO level
    as ``Running <name>: --- <seconds> seconds ---`` after the call returns.
    Nothing is logged if the call raises.

    Args:
        func (Callable): The function to be decorated.

    Returns:
        Callable: The decorated function (metadata copied with ``functools.wraps``).
    """
    import inspect

    def _log_elapsed(start_time):
        # Emit the elapsed time for one completed call of `func`.
        execution_time = time.perf_counter() - start_time
        logger.info(f"Running {func.__name__}: --- {execution_time} seconds ---")

    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            start_time = time.perf_counter()
            result = await func(*args, **kwargs)
            _log_elapsed(start_time)
            return result

        return async_wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        _log_elapsed(start_time)
        return result

    return wrapper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional | |
from langchain_core.callbacks import ( | |
AsyncCallbackManagerForLLMRun, | |
CallbackManagerForLLMRun, | |
) | |
from langchain_core.language_models import BaseChatModel, SimpleChatModel | |
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage | |
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult | |
from langchain_core.runnables import run_in_executor | |
from langchain_openai import ChatOpenAI | |
from salesgpt.tools import completion_bedrock | |
class BedrockCustomModel(ChatOpenAI):
    """Chat model that sends requests to an Anthropic model on AWS Bedrock.

    It subclasses ``ChatOpenAI`` so it can be used wherever an OpenAI chat model
    is expected, but overrides ``_generate`` / ``_agenerate`` to call Bedrock's
    ``invoke_model`` (through ``completion_bedrock`` and ``acompletion_bedrock``)
    using the Anthropic Messages request format.

    Limitations:
        * Only the content of the *last* message is sent, as a single ``user``
          turn, together with ``system_prompt``; earlier messages are dropped.
          NOTE(review): this presumably relies on prompts that inline the
          conversation history into the final message — confirm for new callers.
        * ``stop`` sequences are accepted but not forwarded to Bedrock.
        * ``_agenerate`` does not support streaming.
        * ``max_tokens`` is fixed at 1000 per call.

    Example:
        .. code-block:: python

            model = BedrockCustomModel(
                model="<bedrock-anthropic-model-id>",
                system_prompt="You are a helpful assistant.",
            )
            result = model.invoke([HumanMessage(content="hello")])
    """

    # Bedrock model id, passed as ``modelId`` to ``invoke_model``.
    model: str
    # System prompt sent in the request's ``system`` field on every call.
    system_prompt: str

    @staticmethod
    def _to_bedrock_messages(messages: List[BaseMessage]) -> List[Dict[str, Any]]:
        """Build the single-turn Bedrock ``messages`` payload from the last message."""
        return [{"content": messages[-1].content, "role": "user"}]

    @staticmethod
    def _to_chat_result(response: Dict[str, Any]) -> ChatResult:
        """Wrap the first text block of a Bedrock response in a ``ChatResult``."""
        content = response["content"][0]["text"]
        message = AIMessage(content=content)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronously generate a reply via Bedrock.

        Args:
            messages: The prompt as a list of messages; only the last one is sent.
            stop: Accepted for interface compatibility; not forwarded to Bedrock.
            run_manager: Callback manager for the run (unused).

        Returns:
            A ``ChatResult`` holding one ``AIMessage`` with the model's text.
        """
        # Debug output only when the model is explicitly verbose, so prompts and
        # responses are not dumped to stdout on every call.
        if self.verbose:
            print(messages)
        response = completion_bedrock(
            model_id=self.model,
            system_prompt=self.system_prompt,
            messages=self._to_bedrock_messages(messages),
            max_tokens=1000,
        )
        if self.verbose:
            print("output", response)
        return self._to_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate a reply via Bedrock.

        Args:
            messages: The prompt as a list of messages; only the last one is sent.
            stop: Accepted for interface compatibility; not forwarded to Bedrock.
            run_manager: Async callback manager for the run (unused).
            stream: Overrides ``self.streaming``; streaming is not supported.

        Returns:
            A ``ChatResult`` holding one ``AIMessage`` with the model's text.

        Raises:
            NotImplementedError: if streaming is requested.
        """
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            raise NotImplementedError("Streaming not implemented")
        if self.verbose:
            print(messages)
        response = await acompletion_bedrock(
            model_id=self.model,
            system_prompt=self.system_prompt,
            messages=self._to_bedrock_messages(messages),
            max_tokens=1000,
        )
        if self.verbose:
            print("output", response)
        return self._to_chat_result(response)
import aioboto3 | |
import os | |
import json | |
async def acompletion_bedrock(model_id, system_prompt, messages, max_tokens=1000):
    """Asynchronously invoke an Anthropic Claude model on AWS Bedrock.

    Builds an Anthropic Messages request body, sends it with ``invoke_model``
    through an ``aioboto3`` ``bedrock-runtime`` client (region taken from the
    ``AWS_REGION_NAME`` environment variable) and decodes the JSON reply.

    Args:
        model_id: Bedrock model identifier, sent as ``modelId``.
        system_prompt: Text for the request's ``system`` field.
        messages: Message dicts sent verbatim in the ``messages`` field.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The parsed JSON response body.
    """
    region = os.environ.get("AWS_REGION_NAME")
    client_ctx = aioboto3.Session().client(
        service_name="bedrock-runtime", region_name=region
    )
    async with client_ctx as runtime:
        request_json = json.dumps(
            {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": max_tokens,
                "system": system_prompt,
                "messages": messages,
            }
        )
        reply = await runtime.invoke_model(body=request_json, modelId=model_id)
        # The body is an async stream; read it fully before decoding.
        raw_bytes = await reply['body'].read()
        return json.loads(raw_bytes.decode("utf-8"))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
from typing import Union | |
from langchain.agents.agent import AgentOutputParser | |
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS | |
from langchain.schema import AgentAction, AgentFinish # OutputParserException | |
class SalesConvoOutputParser(AgentOutputParser):
    """Turn raw sales-agent LLM text into an agent action or a final reply.

    Text with an ``Action: ...`` line followed by ``Action Input: ...`` becomes
    an ``AgentAction``; anything else becomes an ``AgentFinish`` whose output is
    the text after the last ``"<ai_prefix>:"`` marker.
    """

    ai_prefix: str = "AI"  # change for salesperson_name
    verbose: bool = False

    def get_format_instructions(self) -> str:
        return FORMAT_INSTRUCTIONS

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        if self.verbose:
            print("TEXT")
            print(text)
            print("-------")
        found = re.search(r"Action: (.*?)[\n]*Action Input: (.*)", text)
        if found is None:
            # No tool call: the reply is whatever follows the last speaker tag.
            spoken = text.split(f"{self.ai_prefix}:")[-1].strip()
            return AgentFinish({"output": spoken}, text)
        tool_name, tool_arg = found.groups()
        cleaned_arg = tool_arg.strip(" ").strip('"')
        return AgentAction(tool_name.strip(), cleaned_arg, text)

    @property
    def _type(self) -> str:
        return "sales-agent"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Prompt for the tool-using sales agent. `{tools}`, `{tool_names}` and
# `{agent_scratchpad}` are expected from the tools prompt template; the other
# placeholders describe the salesperson, the company and the conversation.
SALES_AGENT_TOOLS_PROMPT = """
Never forget your name is {salesperson_name}. You work as a {salesperson_role}.
You work at company named {company_name}. {company_name}'s business is the following: {company_business}.
Company values are the following. {company_values}
You are contacting a potential prospect in order to {conversation_purpose}
Your means of contacting the prospect is {conversation_type}
If you're asked about where you got the user's contact information, say that you got it from public records.
Keep your responses in short length to retain the user's attention. Never produce lists, just answers.
Start the conversation by just a greeting and how is the prospect doing without pitching in your first turn.
When the conversation is over, output <END_OF_CALL>
Always think about at which conversation stage you are at before answering:
1: Introduction: Start the conversation by introducing yourself and your company. Be polite and respectful while keeping the tone of the conversation professional. Your greeting should be welcoming. Always clarify in your greeting the reason why you are calling.
2: Qualification: Qualify the prospect by confirming if they are the right person to talk to regarding your product/service. Ensure that they have the authority to make purchasing decisions.
3: Value proposition: Briefly explain how your product/service can benefit the prospect. Focus on the unique selling points and value proposition of your product/service that sets it apart from competitors.
4: Needs analysis: Ask open-ended questions to uncover the prospect's needs and pain points. Listen carefully to their responses and take notes.
5: Solution presentation: Based on the prospect's needs, present your product/service as the solution that can address their pain points.
6: Objection handling: Address any objections that the prospect may have regarding your product/service. Be prepared to provide evidence or testimonials to support your claims.
7: Close: Ask for the sale by proposing a next step. This could be a demo, a trial or a meeting with decision-makers. Ensure to summarize what has been discussed and reiterate the benefits.
8: End conversation: The prospect has to leave the call, the prospect is not interested, or next steps were already determined by the sales agent.
TOOLS:
------
{salesperson_name} has access to the following tools:
{tools}
To use a tool, please use the following format:
```
Thought: Do I need to use a tool? Yes
Action: the action to take, should be one of {tool_names}
Action Input: the input to the action, always a simple string input
Observation: the result of the action
```
If the result of the action is "I don't know." or "Sorry I don't know", then you have to say that to the user as described in the next sentence.
When you have a response to say to the Human, or if you do not need to use a tool, or if tool did not help, you MUST use the format:
```
Thought: Do I need to use a tool? No
{salesperson_name}: [your response here, if previously used a tool, rephrase latest observation, if unable to find the answer, say it]
```
You must respond according to the previous conversation history and the stage of the conversation you are at.
Only generate one response at a time and act as {salesperson_name} only!
Begin!
Previous conversation history:
{conversation_history}
Thought:
{agent_scratchpad}
"""
# Prompt for the tool-less sales agent. Every turn in the worked example ends
# with <END_OF_TURN>, matching the instruction the model is given below it.
SALES_AGENT_INCEPTION_PROMPT = """Never forget your name is {salesperson_name}. You work as a {salesperson_role}.
You work at company named {company_name}. {company_name}'s business is the following: {company_business}.
Company values are the following. {company_values}
You are contacting a potential prospect in order to {conversation_purpose}
Your means of contacting the prospect is {conversation_type}
If you're asked about where you got the user's contact information, say that you got it from public records.
Keep your responses in short length to retain the user's attention. Never produce lists, just answers.
Start the conversation by just a greeting and how is the prospect doing without pitching in your first turn.
When the conversation is over, output <END_OF_CALL>
Always think about at which conversation stage you are at before answering:
1: Introduction: Start the conversation by introducing yourself and your company. Be polite and respectful while keeping the tone of the conversation professional. Your greeting should be welcoming. Always clarify in your greeting the reason why you are calling.
2: Qualification: Qualify the prospect by confirming if they are the right person to talk to regarding your product/service. Ensure that they have the authority to make purchasing decisions.
3: Value proposition: Briefly explain how your product/service can benefit the prospect. Focus on the unique selling points and value proposition of your product/service that sets it apart from competitors.
4: Needs analysis: Ask open-ended questions to uncover the prospect's needs and pain points. Listen carefully to their responses and take notes.
5: Solution presentation: Based on the prospect's needs, present your product/service as the solution that can address their pain points.
6: Objection handling: Address any objections that the prospect may have regarding your product/service. Be prepared to provide evidence or testimonials to support your claims.
7: Close: Ask for the sale by proposing a next step. This could be a demo, a trial or a meeting with decision-makers. Ensure to summarize what has been discussed and reiterate the benefits.
8: End conversation: The prospect has to leave the call, the prospect is not interested, or next steps were already determined by the sales agent.
Example 1:
Conversation history:
{salesperson_name}: Hey, good morning! <END_OF_TURN>
User: Hello, who is this? <END_OF_TURN>
{salesperson_name}: This is {salesperson_name} calling from {company_name}. How are you? <END_OF_TURN>
User: I am well, why are you calling? <END_OF_TURN>
{salesperson_name}: I am calling to talk about options for your home insurance. <END_OF_TURN>
User: I am not interested, thanks. <END_OF_TURN>
{salesperson_name}: Alright, no worries, have a good day! <END_OF_TURN> <END_OF_CALL>
End of example 1.
You must respond according to the previous conversation history and the stage of the conversation you are at.
Only generate one response at a time and act as {salesperson_name} only! When you are done generating, end with '<END_OF_TURN>' to give the user a chance to respond.
Conversation history:
{conversation_history}
{salesperson_name}:"""
# Prompt asking the model to pick the next conversation stage id; it must
# answer with a single stage number only.
STAGE_ANALYZER_INCEPTION_PROMPT = """
You are a sales assistant helping your sales agent to determine which stage of a sales conversation should the agent stay at or move to when talking to a user.
Start of conversation history:
===
{conversation_history}
===
End of conversation history.
Current Conversation stage is: {conversation_stage_id}
Now determine what should be the next immediate conversation stage for the agent in the sales conversation by selecting only from the following options:
{conversation_stages}
The answer needs to be one number only from the conversation stages, no words.
Only use the current conversation stage and conversation history to determine your answer!
If the conversation history is empty, always start with Introduction!
If you think you should stay in the same conversation stage until user gives more input, just output the current conversation stage.
Do not answer anything else nor add anything to your answer."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import json | |
import logging | |
import os | |
import warnings | |
from dotenv import load_dotenv | |
from langchain_community.chat_models import ChatLiteLLM | |
from salesgpt.agents import SalesGPT | |
# Pull variables from a local .env file into the process environment.
load_dotenv()

# Keep the interactive console clean: hide warnings and every log record
# below CRITICAL.
warnings.filterwarnings("ignore")
logging.getLogger().setLevel(logging.CRITICAL)

# LangSmith tracing settings. Change LANGCHAIN_TRACING_V2 to "true" to enable
# tracing (more info in README), or leave it if you don't need tracing. The API
# key is read from LANGCHAIN_SMITH_API_KEY. These assignments override any
# values of the same variables loaded from .env above.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "false",
        "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
        "LANGCHAIN_API_KEY": os.getenv("LANGCHAIN_SMITH_API_KEY", ""),
        "LANGCHAIN_PROJECT": "",  # insert your project name here
    }
)
if __name__ == "__main__":
    import sys

    # Command-line interface.
    parser = argparse.ArgumentParser(
        description="Run an interactive SalesGPT sales conversation in the terminal."
    )
    parser.add_argument(
        "--config", type=str, help="Path to agent config file", default=""
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Verbosity", default=False
    )
    parser.add_argument(
        "--max_num_turns",
        type=int,
        help="Maximum number of turns in the sales conversation",
        default=10,
    )
    args = parser.parse_args()

    config_path = args.config
    verbose = args.verbose
    max_num_turns = args.max_num_turns

    llm = ChatLiteLLM(temperature=0.2, model_name="gpt-3.5-turbo")

    if config_path == "":
        print("No agent config specified, using a standard config")
        # The built-in default config runs the agent with its tools enabled.
        USE_TOOLS = True
        sales_agent_kwargs = {
            "verbose": verbose,
            "use_tools": USE_TOOLS,
        }
        if USE_TOOLS:
            sales_agent_kwargs.update(
                {
                    # "product_catalog": "examples/sample_product_catalog.txt",
                    "salesperson_name": "Jenna"
                }
            )
        sales_agent = SalesGPT.from_llm(llm, **sales_agent_kwargs)
    else:
        try:
            with open(config_path, "r", encoding="UTF-8") as f:
                config = json.load(f)
        except FileNotFoundError:
            print(f"Config file {config_path} not found.")
            sys.exit(1)
        except json.JSONDecodeError:
            print(f"Error decoding JSON from the config file {config_path}.")
            sys.exit(1)
        print(f"Agent config {config}")
        sales_agent = SalesGPT.from_llm(llm, verbose=verbose, **config)

    sales_agent.seed_agent()
    print("=" * 10)

    # Each iteration is one agent reply, followed by a user reply unless the
    # conversation ends. range() keeps the loop finite even for a zero or
    # negative limit, and allows exactly max_num_turns agent replies.
    for turn in range(1, max_num_turns + 1):
        sales_agent.step()

        # end conversation
        history = sales_agent.conversation_history
        if history and "<END_OF_CALL>" in history[-1]:
            print("Sales Agent determined it is time to end the conversation.")
            break

        if turn == max_num_turns:
            print("Maximum number of turns reached - ending the conversation.")
            break

        human_input = input("Your response: ")
        sales_agent.human_step(human_input)
        print("=" * 10)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import json | |
import re | |
from langchain_community.chat_models import BedrockChat, ChatLiteLLM | |
from langchain_openai import ChatOpenAI | |
from salesgpt.agents import SalesGPT | |
from salesgpt.models import BedrockCustomModel | |
class SalesGPTAPI:
    """API-facing wrapper around a ``SalesGPT`` agent.

    Builds the LLM (a ``BedrockCustomModel`` for model names containing
    ``"anthropic"``, otherwise ``ChatLiteLLM``), creates and seeds the agent,
    and exposes ``do`` (one non-streaming turn returning a JSON-ready payload)
    and ``do_stream`` (an async generator of streamed text chunks).
    """

    # Message sent to the client once the turn limit or end of call is reached.
    _GOODBYE = "In case you'll have any questions - just text me one more time!"
    # Turn terminator the prompts ask the model to emit after each reply.
    _END_OF_TURN = "<END_OF_TURN>"

    def __init__(
        self,
        config_path: str,
        verbose: bool = True,
        max_num_turns: int = 20,
        model_name: str = "gpt-3.5-turbo",
        product_catalog: str = "examples/sample_product_catalog.txt",
        use_tools=True,
    ):
        """Create the LLM and a seeded ``SalesGPT`` agent.

        Args:
            config_path: Optional path to a JSON agent config ("" for defaults).
            verbose: Print extra debugging information.
            max_num_turns: Turn limit after which the goodbye message is sent.
            model_name: LLM identifier; names containing "anthropic" use Bedrock.
            product_catalog: Catalog path passed to the agent when tools are on.
            use_tools: Whether the agent may call tools.
        """
        self.config_path = config_path
        self.verbose = verbose
        self.max_num_turns = max_num_turns
        self.model_name = model_name
        if "anthropic" in model_name:
            self.llm = BedrockCustomModel(
                type="bedrock-model",
                model=model_name,
                system_prompt="You are a helpful assistant.",
            )
        else:
            self.llm = ChatLiteLLM(temperature=0.2, model=model_name)
        self.product_catalog = product_catalog
        self.conversation_history = []
        self.use_tools = use_tools
        self.sales_agent = self.initialize_agent()
        self.current_turn = 0

    def initialize_agent(self):
        """Build the agent config (file + tool settings) and return a seeded agent."""
        config = {"verbose": self.verbose}
        if self.config_path:
            with open(self.config_path, "r", encoding="utf-8") as f:
                config.update(json.load(f))
            if self.verbose:
                print(f"Loaded agent config: {config}")
        else:
            print("Default agent config in use")
        if self.use_tools:
            print("USING TOOLS")
            config.update(
                {
                    "use_tools": True,
                    # NOTE: overrides any product_catalog given in the config file.
                    "product_catalog": self.product_catalog,
                    # Without a config file the key is absent, so the default applies.
                    "salesperson_name": config.get("salesperson_name", "Ted Lasso"),
                }
            )
        sales_agent = SalesGPT.from_llm(self.llm, **config)
        print(f"SalesGPT use_tools: {sales_agent.use_tools}")
        sales_agent.seed_agent()
        return sales_agent

    @classmethod
    def _strip_end_of_turn(cls, text):
        """Remove one trailing ``<END_OF_TURN>`` marker and surrounding whitespace.

        ``str.rstrip("<END_OF_TURN>")`` would instead strip any trailing run of
        those *characters*, eating real words such as "FORT" or "NOT".
        """
        text = text.rstrip()
        if text.endswith(cls._END_OF_TURN):
            text = text[: -len(cls._END_OF_TURN)]
        return text.rstrip()

    def _extract_tool_call(self, ai_log):
        """Return ``(tool, tool_input, action_input, action_output)`` for the first tool call.

        All four are "" when tools are disabled, no tool ran, or the first
        intermediate step cannot be unpacked.
        """
        empty = ("", "", "", "")
        if not (
            self.use_tools
            and "intermediate_steps" in ai_log
            and len(ai_log["intermediate_steps"]) > 0
        ):
            return empty
        try:
            res_str = ai_log["intermediate_steps"][0]
            print("RES STR: ", res_str)
            agent_action, action_output = res_str[0], res_str[1]
            tool = agent_action.tool
            tool_input = agent_action.tool_input
            log = agent_action.log
        except Exception as e:
            print("ERROR: ", e)
            return empty
        # Prefer the text after "Action Input:" from the text-format agent log;
        # agents that don't write that line fall back to the parsed tool input.
        actions = re.search(
            r"Action: (.*?)[\n]*Action Input: (.*)", log if isinstance(log, str) else ""
        )
        action_input = actions.group(2) if actions else tool_input
        return tool, tool_input, action_input, action_output

    async def do(self, human_input=None):
        """Run one non-streaming agent turn and return a JSON-ready payload.

        Args:
            human_input: Optional user message recorded before the agent replies.

        Returns:
            A dict with ``bot_name``, ``response`` (reply text without the
            ``<END_OF_TURN>`` marker), ``conversational_stage``, the tool fields
            ``tool``, ``tool_input``, ``action_input`` and ``action_output``
            (all "" when no tool ran), and ``model_name``.
            NOTE(review): once ``max_num_turns`` is reached a two-item list
            ``["BOT", <goodbye text>]`` is returned instead of a dict; callers
            must handle both shapes.
        """
        self.current_turn += 1
        if self.current_turn >= self.max_num_turns:
            print("Maximum number of turns reached - ending the conversation.")
            return ["BOT", self._GOODBYE]

        if human_input is not None:
            self.sales_agent.human_step(human_input)

        ai_log = await self.sales_agent.astep(stream=False)
        await self.sales_agent.adetermine_conversation_stage()
        # TODO - handle end of conversation in the API - send a special token to the client?
        if self.verbose:
            print("=" * 10)
            print(f"AI LOG {ai_log}")

        if (
            self.sales_agent.conversation_history
            and "<END_OF_CALL>" in self.sales_agent.conversation_history[-1]
        ):
            print("Sales Agent determined it is time to end the conversation.")
            # strip end of call for now
            self.sales_agent.conversation_history[
                -1
            ] = self.sales_agent.conversation_history[-1].replace("<END_OF_CALL>", "")

        reply = (
            self.sales_agent.conversation_history[-1]
            if self.sales_agent.conversation_history
            else ""
        )
        tool, tool_input, action_input, action_output = self._extract_tool_call(ai_log)

        print(reply)
        # History entries look like "<speaker>: <text>"; partition splits on the
        # first ": " only, so colons inside the reply text are preserved.
        bot_name, _, response = reply.partition(": ")
        return {
            "bot_name": bot_name,
            "response": self._strip_end_of_turn(response),
            "conversational_stage": self.sales_agent.current_conversation_stage,
            "tool": tool,
            "tool_input": tool_input,
            "action_output": action_output,
            "action_input": action_input,
            "model_name": self.model_name,
        }

    @staticmethod
    async def _aiter_stream(stream):
        """Yield chunks from *stream*, whether it is an async or a plain iterable."""
        # NOTE(review): the object returned by astep(stream=True) is not visible
        # from here, so both iteration protocols are accepted.
        if hasattr(stream, "__aiter__"):
            async for chunk in stream:
                yield chunk
        else:
            for chunk in stream:
                yield chunk

    async def do_stream(self, conversation_history: list, human_input=None):
        """Stream the agent's next reply chunk by chunk.

        Args:
            conversation_history: Prior turns; replaces the agent's history.
            human_input: Optional user message recorded before the agent replies.

        Yields:
            Text chunks of the reply. ``["BOT", <goodbye text>]`` is yielded when
            the turn limit is reached (and nothing else), or just before a chunk
            containing ``<END_OF_CALL>``.
        """
        # TODO
        current_turns = len(conversation_history) + 1
        if current_turns >= self.max_num_turns:
            print("Maximum number of turns reached - ending the conversation.")
            yield ["BOT", self._GOODBYE]
            # A plain return ends an async generator; raising StopAsyncIteration
            # here would surface as RuntimeError (PEP 525).
            return

        self.sales_agent.seed_agent()
        self.sales_agent.conversation_history = conversation_history
        if human_input is not None:
            self.sales_agent.human_step(human_input)

        # astep is a coroutine function (it is awaited in ``do``), so await it to
        # obtain the stream rather than trying to iterate the coroutine itself.
        stream_gen = await self.sales_agent.astep(stream=True)
        async for model_response in self._aiter_stream(stream_gen):
            for choice in model_response.choices:
                message = choice["delta"]["content"]
                if message is None:
                    continue
                if "<END_OF_CALL>" in message:
                    print("Sales Agent determined it is time to end the conversation.")
                    yield ["BOT", self._GOODBYE]
                yield message
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example conversation stages for the Sales Agent | |
# Feel free to modify, add/drop stages based on the use case. | |
# Stage descriptions keyed by their 1-based position as a string ("1".."8").
# The stage analyzer is asked to answer with one of these keys.
CONVERSATION_STAGES = {
    str(stage_number): description
    for stage_number, description in enumerate(
        (
            "Introduction: Start the conversation by introducing yourself and your company. Be polite and respectful while keeping the tone of the conversation professional. Your greeting should be welcoming. Always clarify in your greeting the reason why you are calling.",
            "Qualification: Qualify the prospect by confirming if they are the right person to talk to regarding your product/service. Ensure that they have the authority to make purchasing decisions.",
            "Value proposition: Briefly explain how your product/service can benefit the prospect. Focus on the unique selling points and value proposition of your product/service that sets it apart from competitors.",
            "Needs analysis: Ask open-ended questions to uncover the prospect's needs and pain points. Listen carefully to their responses and take notes.",
            "Solution presentation: Based on the prospect's needs, present your product/service as the solution that can address their pain points.",
            "Objection handling: Address any objections that the prospect may have regarding your product/service. Be prepared to provide evidence or testimonials to support your claims.",
            "Close: Ask for the sale by proposing a next step. This could be a demo, a trial or a meeting with decision-makers. Ensure to summarize what has been discussed and reiterate the benefits.",
            "End conversation: It's time to end the call as there is nothing else to be said.",
        ),
        start=1,
    )
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Callable | |
from langchain.prompts.base import StringPromptTemplate | |
class CustomPromptTemplateForTools(StringPromptTemplate):
    """String prompt template that injects the scratchpad and tool listing.

    Each ``format`` call renders prior (action, observation) steps into
    ``agent_scratchpad`` and asks ``tools_getter`` which tools to list for the
    current ``input``, filling ``tools`` and ``tool_names`` before formatting.
    """

    # The template to use
    template: str
    # Called with the current ``input`` value; must return objects exposing
    # ``name`` and ``description``.
    tools_getter: Callable

    def format(self, **kwargs) -> str:
        # Replay each earlier step as "<action log>\nObservation: ...\nThought: "
        # so the model continues from where it left off.
        past_steps = kwargs.pop("intermediate_steps")
        scratchpad_parts = []
        for step_action, step_observation in past_steps:
            scratchpad_parts.append(step_action.log)
            scratchpad_parts.append(f"\nObservation: {step_observation}\nThought: ")
        kwargs["agent_scratchpad"] = "".join(scratchpad_parts)

        available_tools = self.tools_getter(kwargs["input"])
        kwargs["tools"] = "\n".join(
            [f"{t.name}: {t.description}" for t in available_tools]
        )
        kwargs["tool_names"] = ", ".join([t.name for t in available_tools])
        return self.template.format(**kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import os | |
import boto3 | |
import requests | |
from langchain.agents import Tool | |
from langchain.chains import RetrievalQA | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain_community.chat_models import BedrockChat | |
from langchain_community.vectorstores import Chroma | |
from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
from litellm import completion | |
def setup_knowledge_base(
    product_catalog: str = None, model_name: str = "gpt-3.5-turbo"
):
    """Build a retrieval-QA chain over a plain-text product catalog.

    The catalog file is split with ``CharacterTextSplitter(chunk_size=10,
    chunk_overlap=0)``, embedded with ``OpenAIEmbeddings`` into a Chroma
    collection named ``product-knowledge-base``, and wrapped in a ``RetrievalQA``
    "stuff" chain answered by ``gpt-4-0125-preview`` at temperature 0.

    Args:
        product_catalog: Path to the catalog text file (required in practice;
            the ``None`` default cannot be opened).
        model_name: Currently unused — the QA model is hard-coded above.
            NOTE(review): confirm whether this was meant to select the QA LLM.

    Returns:
        The ``RetrievalQA`` chain.
    """
    with open(product_catalog, "r") as catalog_file:
        catalog_text = catalog_file.read()

    splitter = CharacterTextSplitter(chunk_size=10, chunk_overlap=0)
    chunks = splitter.split_text(catalog_text)

    qa_llm = ChatOpenAI(model_name="gpt-4-0125-preview", temperature=0)
    embedder = OpenAIEmbeddings()
    vector_store = Chroma.from_texts(
        chunks, embedder, collection_name="product-knowledge-base"
    )
    return RetrievalQA.from_chain_type(
        llm=qa_llm, chain_type="stuff", retriever=vector_store.as_retriever()
    )
def completion_bedrock(model_id, system_prompt, messages, max_tokens=1000):
    """Synchronously invoke an Anthropic Claude model on AWS Bedrock.

    Args:
        model_id: Bedrock model identifier, sent as ``modelId``.
        system_prompt: Text for the request's ``system`` field.
        messages: Message dicts sent verbatim in the ``messages`` field.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The parsed JSON response body.
    """
    runtime = boto3.client(
        service_name="bedrock-runtime",
        region_name=os.environ.get("AWS_REGION_NAME"),
    )
    request = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        "system": system_prompt,
        "messages": messages,
    }
    raw_response = runtime.invoke_model(body=json.dumps(request), modelId=model_id)
    body_stream = raw_response.get("body")
    return json.loads(body_stream.read())
def get_product_id_from_query(query, product_price_id_mapping_path):
    """Ask an LLM which price id in a product/price-id mapping best matches *query*.

    The mapping is loaded from JSON and embedded in the prompt together with a
    JSON schema whose ``price_id`` enum lists every mapping value plus
    ``"No relevant product id found"``. The model comes from the ``GPT_MODEL``
    environment variable: names containing ``"anthropic"`` go to Bedrock via
    ``completion_bedrock``, everything else to LiteLLM's ``completion``.

    Args:
        query: Free-text description of what the customer wants.
        product_price_id_mapping_path: Path to a JSON object whose values are
            price ids (presumably keyed by product name — TODO confirm).

    Returns:
        The model's reply as a string, normally a JSON object such as
        ``{"price_id": "..."}``. It is not validated here.
    """
    # Load product_price_id_mapping from a JSON file (JSON text is UTF-8).
    with open(product_price_id_mapping_path, "r", encoding="utf-8") as f:
        product_price_id_mapping = json.load(f)

    # Serialize the mapping for inclusion in the prompt.
    product_price_id_mapping_json_str = json.dumps(product_price_id_mapping)

    # Allowed answers: every price id in the mapping plus an explicit "no match"
    # value, which the instructions below must name verbatim.
    no_match = "No relevant product id found"
    enum_list = list(product_price_id_mapping.values()) + [no_match]
    enum_list_str = json.dumps(enum_list)

    prompt = f"""
    You are an expert data scientist and you are working on a project to recommend products to customers based on their needs.
    Given the following query:
    {query}
    and the following product price id mapping:
    {product_price_id_mapping_json_str}
    return the price id that is most relevant to the query.
    ONLY return the price id, no other text. If no relevant price id is found, return '{no_match}'.
    Your output will follow this schema:
    {{
    "$schema": "http://json-schema.org/draft-07/schema#",
    "title": "Price ID Response",
    "type": "object",
    "properties": {{
    "price_id": {{
    "type": "string",
    "enum": {enum_list_str}
    }}
    }},
    "required": ["price_id"]
    }}
    Return a valid directly parsable json, don't return it within a code snippet or add any kind of explanation!!
    """
    # Cue the model to start its answer with the JSON object's opening brace.
    prompt += "{"

    model_name = os.getenv("GPT_MODEL", "gpt-3.5-turbo-1106")
    if "anthropic" in model_name:
        response = completion_bedrock(
            model_id=model_name,
            system_prompt="You are a helpful assistant.",
            messages=[{"content": prompt, "role": "user"}],
            max_tokens=1000,
        )
        product_id = response["content"][0]["text"].strip()
    else:
        response = completion(
            model=model_name,
            messages=[{"content": prompt, "role": "user"}],
            max_tokens=1000,
            temperature=0,
        )
        product_id = response.choices[0].message.content.strip()

    # If the model continued from the "{" cue instead of repeating it, restore
    # the brace so the reply can be parsed as a JSON object downstream.
    if not product_id.startswith("{"):
        product_id = "{" + product_id
    return product_id
def generate_stripe_payment_link(query: str) -> str:
    """Generate a stripe payment link for a customer based on a single query string.

    An LLM maps *query* to a price id, then the query, the price-id fields and
    the Stripe key are POSTed as JSON to the payment gateway.

    Environment variables:
        PAYMENT_GATEWAY_URL: Gateway endpoint (defaults to the example gateway).
        PRODUCT_PRICE_MAPPING: Path to the product/price-id JSON mapping.
        STRIPE_API_KEY: Forwarded to the gateway in the request body.

    Returns:
        The gateway's response body. If the model's price-id reply is not a
        JSON object, a short explanatory message is returned instead, so the
        agent receives it as tool output rather than the step crashing.

    Raises:
        requests.exceptions.RequestException: On network failures, including
            the gateway not answering within ``REQUEST_TIMEOUT_SECONDS``.
    """
    # example testing payment gateway url
    PAYMENT_GATEWAY_URL = os.getenv(
        "PAYMENT_GATEWAY_URL", "https://agent-payments-gateway.vercel.app/payment"
    )
    PRODUCT_PRICE_MAPPING = os.getenv(
        "PRODUCT_PRICE_MAPPING", "example_product_price_id_mapping.json"
    )
    # Without a timeout, requests waits forever on an unresponsive gateway.
    REQUEST_TIMEOUT_SECONDS = 30

    # use LLM to get the price_id from query
    price_id_reply = get_product_id_from_query(query, PRODUCT_PRICE_MAPPING)
    try:
        price_id = json.loads(price_id_reply)
    except json.JSONDecodeError:
        price_id = None
    # The fields are spliced into the payload with **, so they must be a dict.
    if not isinstance(price_id, dict):
        return (
            "Could not determine a product price id for this request. "
            f"Model reply: {price_id_reply}"
        )

    # NOTE(review): the Stripe secret key is sent to the payment gateway in the
    # request body; make sure PAYMENT_GATEWAY_URL points at a trusted service.
    payload = json.dumps(
        {"prompt": query, **price_id, "stripe_key": os.getenv("STRIPE_API_KEY")}
    )
    headers = {
        "Content-Type": "application/json",
    }
    response = requests.request(
        "POST",
        PAYMENT_GATEWAY_URL,
        headers=headers,
        data=payload,
        timeout=REQUEST_TIMEOUT_SECONDS,
    )
    return response.text
def get_tools(product_catalog):
    """Return the sales agent's tools: catalog search and payment-link generation.

    Args:
        product_catalog: Path to the catalog text file used to build the
            retrieval knowledge base behind ``ProductSearch``.

    Returns:
        A list of two ``Tool`` objects, ``ProductSearch`` and ``GeneratePaymentLink``.
    """
    # The query passed to get_tools could be embedded to retrieve only the most
    # relevant tools, see:
    # https://langchain-langchain.vercel.app/docs/use_cases/agents/custom_agent_with_plugin_retrieval#tool-retriever
    # Only two tools are used for now, but this is highly extensible!
    catalog_qa = setup_knowledge_base(product_catalog)
    product_search = Tool(
        name="ProductSearch",
        func=catalog_qa.run,
        description="useful for when you need to answer questions about product information or services offered, availability and their costs.",
    )
    payment_link = Tool(
        name="GeneratePaymentLink",
        func=generate_stripe_payment_link,
        description="useful to close a transaction with a customer. You need to include product name and quantity and customer name in the query input.",
    )
    return [product_search, payment_link]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment