Skip to content

Instantly share code, notes, and snippets.

@mrchief
Last active June 20, 2024 18:48
Show Gist options
  • Save mrchief/ebb2cb8104800df3e06005104474e8d7 to your computer and use it in GitHub Desktop.
Save mrchief/ebb2cb8104800df3e06005104474e8d7 to your computer and use it in GitHub Desktop.
""" Chains and agents """
import os
import re
from typing import Optional
import pandas as pd
from langchain.agents import AgentExecutor, ZeroShotAgent
from langchain.agents.agent_types import AgentType
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.chains.llm import LLMChain
from langchain.tools import StructuredTool, Tool
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.prompts import (
ChatPromptTemplate,
FewShotPromptTemplate,
MessagesPlaceholder,
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_experimental.agents.agent_toolkits import create_csv_agent
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from openai import OpenAI as Openai
from pydantic import BaseModel, Field
from src.core.cloud_storage import RAW, pull_project_folder
from src.core.rag import IndexLLM
from src.utils.agents import (
AGENT_SYSTEM_PREFIX,
CUSTOM_AGENT_FORMAT,
CUSTOM_PREFIX,
CUSTOM_SUFFIX,
apologize,
greets,
)
from src.utils.examples_creation import (
build_examples_from_dataframe,
get_key_from_file_path,
)
from src.utils.settings import OPENAI_API_KEY, RAW_PATH, configure_logger
# Module-level logger, created through the project's `configure_logger`
# helper and tagged "Chain agent"; used by every method of AgentSynthesizer.
logger = configure_logger("Chain agent")
class StructuredFileAgentInput(BaseModel):
    """Pydantic input schema carrying a single ``question`` string.

    NOTE(review): nothing in this module references the schema; the tools
    built by ``AgentSynthesizer`` pass plain ``BaseModel`` as ``args_schema``.
    Presumably it is consumed elsewhere in the project - confirm before
    removing it.
    """

    # The user's question; ``Field()`` is given no default value.
    question: str = Field()
class AgentSynthesizer:
    """Assemble and run a conversational agent on top of an ``IndexLLM``.

    The agent is a ``ZeroShotAgent`` wrapped in an ``AgentExecutor`` with a
    ``ConversationBufferMemory``. It is given four tools:

    * ``Greets`` / ``Apologize`` - canned replies (``return_direct=True``).
    * ``Read`` - the tool produced by ``IndexLLM.create_agent_tool``.
    * ``Query table`` - a CSV agent (``create_csv_agent``) over the CSV file
      judged most relevant to the question.

    Typical call order: ``load_index_if_not_loaded`` -> ``query_agent``.
    """

    # Chat model and temperature shared by the file router and the CSV agent.
    _ROUTER_MODEL = "gpt-3.5-turbo"
    _TEMPERATURE = 0.01

    def __init__(self, index: IndexLLM):
        """Store the index and register the default (static) tools.

        Args:
            index (IndexLLM): An IndexLLM object that
                contains llamatool.
        """
        self.index = index
        self.llama_tool = None
        self.agent_chain = None
        self.memory = None
        # Initialised here so the attributes always exist, even before
        # `load_index_if_not_loaded` / `_config_memory_conversation` run.
        self.tools = []
        self.structured_filenames = []
        self.structured_user_contexts = []
        self.set_default_tools()

    def set_default_tools(self):
        """Create the ``Greets`` and ``Apologize`` tools.

        Both wrap project helpers (``greets`` / ``apologize``), take no
        structured arguments (``args_schema=BaseModel``) and return their
        output directly to the user.
        """
        self.greetings_tool = StructuredTool.from_function(
            name="Greets",
            func=greets,
            description="use this if there is no question to answer, invite the user to ask a question",
            args_schema=BaseModel,
            return_direct=True,
        )
        self.apologize_tool = StructuredTool.from_function(
            name="Apologize",
            func=apologize,
            description="use this if the other tools are not suitable to respond a question",
            args_schema=BaseModel,
            return_direct=True,
        )

    def load_index_if_not_loaded(
        self, path: Optional[str], index: Optional[IndexLLM]
    ) -> None:
        """Make sure an index is loaded, then refresh everything derived from it.

        After the index is available this also downloads the project's raw
        files (if not already on disk), reads the structured-file metadata
        from the index and rebuilds the ``Read`` tool.

        Args:
            path (Optional[str]): The path of the project index. Its last
                component is used as the project name.
            index (Optional[IndexLLM]): Index to adopt. When ``None`` the
                index given at construction time is (re)loaded from ``path``.

        Raises:
            Exception: Whatever ``IndexLLM.load_index`` raised; it is logged
                and re-raised unchanged.
        """
        # normpath drops a trailing separator, which would otherwise make
        # basename() return "" for paths such as "projects/foo/".
        project_name = os.path.basename(os.path.normpath(path)) if path else ""

        if index is not None:
            if hasattr(index, "index") and index.index_id == project_name:
                logger.info(f"Index {index.index_id} is already set")
            else:
                try:
                    index.load_index(path)
                    logger.info(f"Loaded index: {index.index_id}")
                except Exception as err:
                    logger.error(
                        f"Failed to load index from path: {path}", exc_info=err
                    )
                    raise
            # Adopt the index in both branches (and only once it is usable),
            # so the metadata read below never comes from a stale object.
            self.index = index
        else:
            try:
                self.index.load_index(path=path)
                logger.info(f"Loaded index: {self.index.index_id}")
            except Exception as err:
                logger.error(f"Failed to load index from path: {path}", exc_info=err)
                raise

        self._download_raw_from_project_name(project_name)

        files_dict = self.index.get_docs_filename_and_context_from_metadata()
        self.structured_filenames = list(files_dict.keys())
        self.structured_user_contexts = list(files_dict.values())

        # NOTE(review): `return_redirect` looks like a typo for
        # `return_direct`, but it is the keyword of a project API
        # (IndexLLM.create_agent_tool) that is not visible here - confirm.
        self.llama_tool = self.index.create_agent_tool(
            name="Read",
            description="use this to search for information using private documents and retrieve augmented information from there",
            return_redirect=True,
        )

    def _download_raw_from_project_name(self, name: str):
        """Download the raw data of a project unless it is already on disk.

        Args:
            name (str): The name of the project (index) to be downloaded
        """
        # Same path construction as `create_structured_agent`, so the files
        # are read from exactly the directory they were downloaded to.
        raw_path = os.path.join(RAW_PATH, name)
        if os.path.exists(raw_path):
            logger.info(f"Already loaded {RAW} content into {raw_path}")
            return

        # makedirs also creates RAW_PATH itself when it does not exist yet.
        os.makedirs(raw_path, exist_ok=True)
        logger.info(f"Pulling raw data for {name}")
        pull_project_folder(RAW, name, raw_path)

    def get_metadata_by_filetype(self, extension: str = "csv") -> tuple:
        """Select the structured files whose name ends with ``extension``.

        Args:
            extension (str, optional): The file extension. Defaults to "csv".

        Returns:
            tuple: ``(filenames, user_contexts)`` - two lists of equal
                length, in the order the files appear in the index metadata.

        Raises:
            ValueError: If the stored filename and context lists differ in
                length.
        """
        if len(self.structured_filenames) != len(self.structured_user_contexts):
            raise ValueError("Files and user contexts are not the same length")

        matches = [
            (filename, context)
            for filename, context in zip(
                self.structured_filenames, self.structured_user_contexts
            )
            if filename.endswith(extension)
        ]
        filenames = [filename for filename, _ in matches]
        u_contexts = [context for _, context in matches]
        return filenames, u_contexts

    @staticmethod
    def _parse_file_number(response_message: Optional[str], n_files: int) -> int:
        """Return the first number in the model reply that is a valid file index."""
        for token in re.findall(r"\d+", response_message or ""):
            number = int(token)
            if number < n_files:
                return number
        raise ValueError(
            f"Could not find a file number below {n_files} in the model "
            f"response: {response_message!r}"
        )

    def resolve_which_is_the_relevant_file(self, question: str) -> int:
        """Ask the chat model which CSV file is most relevant to a question.

        It uses two lists in the agent object (filtered to CSV files):
            - structured_filenames
            - structured_user_contexts

        Args:
            question (str): The question to be answered

        Returns:
            int: Position of the most relevant file within the lists returned
                by ``get_metadata_by_filetype()``.

        Raises:
            ValueError: If there is no CSV file to choose from, or the model
                reply contains no usable file number.
        """
        filenames, contexts = self.get_metadata_by_filetype()
        logger.info(
            f"Resolving the most relevant file using contexts from {len(filenames)} files"
        )
        if not filenames:
            raise ValueError("No csv files are available to answer the question")
        if len(filenames) == 1:
            # Nothing to choose between: skip the model round-trip.
            return 0

        str_contexts = "".join(
            f"""'{context}' is the context of the file number {i}, """
            for i, context in enumerate(contexts)
        )
        client = Openai(api_key=OPENAI_API_KEY)
        response = client.chat.completions.create(
            model=self._ROUTER_MODEL,
            temperature=self._TEMPERATURE,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {
                    "role": "user",
                    "content": f"Taking into consideration files tagged with an ordinal number and with the following contexts provided by a human: '{str_contexts}'",
                },
                {
                    "role": "user",
                    "content": f"Which file is the most relevant to answer to the question '{question}'? To respond, provide only the number of the most relevant file",
                },
            ],
        )
        response_message = response.choices[0].message.content
        logger.info(f"original message: {response_message}")
        return self._parse_file_number(response_message, len(filenames))

    def _define_few_shot_examples(self, file_path: str, question: str) -> str:
        """Render a few-shot prompt prefix for a question about one CSV file.

        Examples are built from the file's dataframe by the project helper
        ``build_examples_from_dataframe``; the two most similar to the
        question are selected via a FAISS-backed semantic selector.

        Args:
            file_path (str): File path to the csv file
            question (str): Question from the user

        Returns:
            str: A prompt prefix with the few shot examples
        """
        table_df = pd.read_csv(file_path)
        file_key = get_key_from_file_path(file_path)
        examples = build_examples_from_dataframe(df=table_df, key=file_key)

        example_selector = SemanticSimilarityExampleSelector.from_examples(
            examples=examples,
            # Same credentials as the other OpenAI clients in this class.
            embeddings=OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY),
            vectorstore_cls=FAISS,
            k=2,
            input_keys=["input"],
        )
        few_shot_prompt = FewShotPromptTemplate(
            example_selector=example_selector,
            example_prompt=PromptTemplate.from_template(
                "User input: {input}\nAgent output:{output}"
            ),
            # Declare exactly the variables supplied at invoke() time below.
            # "output" belongs to the per-example prompt only; listing it
            # here makes it a required prompt input that is never provided.
            input_variables=["input", "top_k", "dialect"],
            prefix=AGENT_SYSTEM_PREFIX,
            suffix="User input: {input}\n",
        )
        full_prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate(prompt=few_shot_prompt),
                ("human", "{input}"),
                MessagesPlaceholder("agent_scratchpad"),
            ]
        )
        prompt_val = full_prompt.invoke(
            {
                "input": question,
                "top_k": 2,
                "dialect": "python_REPL_ast",
                "agent_scratchpad": [],
            }
        )
        return prompt_val.to_string()

    def create_structured_agent(self, question: str):
        """Create a CSV agent over the file most relevant to ``question``.

        Args:
            question (str): A question from the user

        Returns:
            An agent object (the result of ``create_csv_agent``)
        """
        file_number = self.resolve_which_is_the_relevant_file(question=question)
        csv_files, csv_user_contexts = self.get_metadata_by_filetype()
        file_name = csv_files[file_number]
        logger.info(
            f"The file number is {file_number} using it from "
            + f"{len(self.structured_filenames)} files : "
            + f"[{csv_files[file_number]} : {csv_user_contexts[file_number]}]"
        )

        file_path = os.path.join(RAW_PATH, self.index.index_id, file_name)
        prompt_val = self._define_few_shot_examples(
            file_path=file_path, question=question
        )
        # NOTE(review): newer langchain_experimental releases may require an
        # explicit opt-in flag for code execution here - confirm against the
        # pinned version.
        agent_csv = create_csv_agent(
            ChatOpenAI(
                temperature=self._TEMPERATURE,
                openai_api_key=OPENAI_API_KEY,
                model=self._ROUTER_MODEL,
            ),
            file_path,
            verbose=False,
            agent_type=AgentType.OPENAI_FUNCTIONS,
            prefix=prompt_val,
        )
        return agent_csv

    def initialize_memory_conversation(
        self, question: Optional[str] = None
    ) -> ConversationBufferMemory:
        """(Re)build the agent executor with a fresh conversation memory.

        Args:
            question (Optional[str]): Question used to pick the CSV file for
                the ``Query table`` tool. When omitted, the file is picked
                lazily from the tool input each time the tool is called.

        Returns:
            ConversationBufferMemory: A buffer memory for a conversation
        """
        self._config_memory_conversation(question=question)
        return self.memory

    def query_agent(self, input_text: str) -> tuple:
        """Run a query within the agent, configuring it on first use.

        NOTE(review): the executor is built once, so the ``Query table`` tool
        stays bound to the CSV file chosen for the first question until
        ``initialize_memory_conversation`` is called again.

        Args:
            input_text (str): A query text from the user

        Returns:
            tuple: Both, the output and the intermediate steps
                from the response object
        """
        if self.agent_chain is None:
            self._config_memory_conversation(question=input_text)

        # `.invoke` replaces the deprecated `Chain.__call__`.
        response = self.agent_chain.invoke({"input": input_text})
        logger.info(f"Memory: {self.memory.chat_memory}")
        return (
            response["output"],
            response["intermediate_steps"],
        )

    def _enlist_tools(self, *args) -> None:
        """Set tools to be used"""
        self.tools = list(args)

    def _query_table_lazily(self, question: str):
        """Pick the relevant CSV for ``question`` at call time and query it."""
        return self.create_structured_agent(question=question).invoke(question)

    def _config_memory_conversation(self, question: Optional[str] = None):
        """Configure the tools, prompt, memory and agent executor.

        Args:
            question (Optional[str]): When given, the CSV agent behind the
                ``Query table`` tool is created immediately for this
                question. When ``None``, creation is deferred until the tool
                is actually invoked.

        Raises:
            RuntimeError: If no index has been loaded yet, so the ``Read``
                tool does not exist.
        """
        if self.llama_tool is None:
            # Fail early with a clear message instead of passing None to
            # the agent as a tool.
            raise RuntimeError(
                "The index tool is not available; call "
                "load_index_if_not_loaded() before configuring the agent"
            )

        if question is None:
            table_func = self._query_table_lazily
        else:
            table_func = self.create_structured_agent(question=question).invoke

        self._enlist_tools(
            self.greetings_tool,
            self.apologize_tool,
            self.llama_tool,
            Tool.from_function(
                func=table_func,
                name="Query table",
                description="useful for when you require to get specific data running queries for data sources that contain metrics in a tabular structure",
            ),
        )
        prompt = ZeroShotAgent.create_prompt(
            tools=self.tools,
            prefix=CUSTOM_PREFIX,
            suffix=CUSTOM_SUFFIX,
            format_instructions=CUSTOM_AGENT_FORMAT,
            input_variables=["input", "chat_history", "agent_scratchpad"],
        )
        # output_key is set explicitly because the executor also returns
        # "intermediate_steps"; the memory must know which key to store.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history", output_key="output"
        )
        llm = ChatOpenAI(
            temperature=self._TEMPERATURE,
            openai_api_key=OPENAI_API_KEY,
            model=self.index.llm.model,
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        agent = ZeroShotAgent(llm_chain=llm_chain, tools=self.tools)
        self.agent_chain = AgentExecutor.from_agent_and_tools(
            agent=agent,
            tools=self.tools,
            memory=self.memory,
            handle_parsing_errors=True,
            return_intermediate_steps=True,
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment