Created February 16, 2024 16:30
-
-
Save shikanime/1f87e791d1ffc7f6247350b1105176d3 to your computer and use it in GitHub Desktop.
Simple Retrieval using LCEL
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from operator import itemgetter

from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.prompts.prompt import PromptTemplate
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.chat_message_histories.firestore import (
    FirestoreChatMessageHistory,
)
from langchain_community.retrievers.google_vertex_ai_search import (
    GoogleVertexAIMultiTurnSearchRetriever,
)
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableField, ConfigurableFieldSpec
from langchain_google_vertexai.chat_models import ChatVertexAI
# Prompt used to condense a follow-up turn plus the prior conversation into a
# single standalone question (in French) before retrieval.
# Placeholders: {history} (prior turns) and {input} (the new user message).
condense_question_template = "\n".join(
    [
        "Given the following conversation and a follow up question, rephrase the "
        "follow up question to be a standalone question, in French language.",
        "",
        "Chat History:",
        "{history}",
        "Follow Up Input: {input}",
        "Standalone question:",
    ]
)
# String (non-chat) prompt that rewrites a follow-up into a standalone question;
# its input variables are inferred from the f-string placeholders.
condense_question_prompt = PromptTemplate.from_template(
    template=condense_question_template,
    template_format="f-string",
)
# System instruction for the answering step. The answer must be in French and
# formatted as markdown.
# {context} is the slot that receives the retrieved documents, formatted and
# concatenated by the documents-combining chain.
combine_docs_template = (
    "Use the following optional pieces of information to fulfill the user's "
    "request in French and in markdown format.\n"
    "\n"
    "Potentially Useful Information:\n"
    "{context}"
)
# Chat prompt for the answering step:
# - the system turn carries the instructions plus the retrieved documents
#   ({context});
# - the human turn is the user's raw message ({input}).
# The (role, template) tuples build the same System/Human message templates
# that the explicit *MessagePromptTemplate.from_template calls would.
combine_docs_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", combine_docs_template),
        ("human", "{input}"),
    ]
)
# Per-document formatter rendering each document as "Page <page>: <text>".
# NOTE(review): create_retriever does not pass this to
# create_stuff_documents_chain, so it is currently unused. Wiring it in would
# presumably fail for documents whose metadata lacks a "page" key — confirm
# what metadata the Vertex AI Search retriever returns first.
document_prompt = PromptTemplate.from_template(
    template="Page {page}: {page_content}",
)
def create_retriever(
    project: str,
    location: str = "us-central1",
    *,
    search_location: str = "global",
    search_data_store: str,
) -> RunnableWithMessageHistory:
    """Build a conversational retrieval-augmented chat chain on Vertex AI.

    Despite its name, this returns a complete question-answering chain, not a
    bare retriever. On every turn it:

    1. rephrases the user's follow-up ``input`` into a standalone question,
       using the stored conversation ``history``;
    2. retrieves documents for that standalone question through a
       ``MultiQueryRetriever`` wrapping Vertex AI Search;
    3. answers ``input`` with the retrieved documents stuffed into the prompt.

    Conversation history is persisted in Firestore, keyed by
    ``(user_id, session_id)``.

    Args:
        project: Google Cloud project ID, used both for the Gemini chat model
            and for Vertex AI Search.
        location: Vertex AI region hosting the Gemini chat model.
        search_location: Location of the Vertex AI Search data store.
        search_data_store: ID of the Vertex AI Search data store to query.

    Returns:
        A ``RunnableWithMessageHistory``. Invoke it with ``{"input": ...}`` and a
        config whose ``configurable`` section provides ``user_id`` and
        ``session_id``, and optionally ``temperature``. The result is a dict
        containing the inputs plus ``standalone_question``,
        ``source_documents`` and the answer text under ``output``.
    """
    llm = ChatVertexAI(
        project=project,
        location=location,
        model_name="gemini-pro",
        max_output_tokens=4096,
        temperature=0.5,
        streaming=True,
        # Fold the system prompt into the human turn. Presumably required
        # because gemini-pro did not accept system messages at the time.
        convert_system_message_to_human=True,
    ).configurable_fields(
        # Let callers override the temperature per call via
        # config["configurable"]["temperature"].
        temperature=ConfigurableField(
            id="temperature",
            name="LLM Temperature",
            description="The temperature of the LLM",
        )
    )

    base_retriever = GoogleVertexAIMultiTurnSearchRetriever(
        project_id=project,
        location_id=search_location,
        data_store_id=search_data_store,
    )
    # The LLM generates alternative phrasings of the question; the documents
    # retrieved for each phrasing are merged.
    retriever = MultiQueryRetriever.from_llm(llm=llm, retriever=base_retriever)

    # Step 1: condense the follow-up into a standalone question.
    # RunnableWithMessageHistory injects "history" as a list of message objects.
    # condense_question_prompt is a plain string template, so render the list
    # as "Human: ...\nAI: ..." text. Without this, the list's repr would be
    # interpolated into the prompt.
    condense_inputs = {
        "history": lambda inputs: get_buffer_string(inputs["history"]),
        "input": itemgetter("input"),
    }
    standalone_question_chain = RunnablePassthrough.assign(
        standalone_question=condense_inputs
        | condense_question_prompt
        | llm
        | StrOutputParser(),
    )

    # Step 2: retrieve with the standalone question, not the raw follow-up.
    retriever_chain = RunnablePassthrough.assign(
        source_documents=itemgetter("standalone_question") | retriever,
    )

    # Step 3: answer the original input. The retrieved documents fill the
    # system prompt's {context} slot, and the answer text lands under "output".
    answer_chain = RunnablePassthrough.assign(
        output={
            "context": itemgetter("source_documents"),
            "input": itemgetter("input"),
        }
        | create_stuff_documents_chain(llm, combine_docs_prompt),
    )

    chain = standalone_question_chain | retriever_chain | answer_chain

    def get_session_history(user_id: str, session_id: str) -> FirestoreChatMessageHistory:
        # Parameter names must match the ConfigurableFieldSpec ids declared
        # in history_factory_config below.
        return FirestoreChatMessageHistory(
            collection_name="chat-history",
            user_id=user_id,
            session_id=session_id,
        )

    return RunnableWithMessageHistory(
        chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="history",
        # The chain returns a dict. Record only the answer text as the AI turn.
        output_messages_key="output",
        history_factory_config=[
            ConfigurableFieldSpec(
                id="user_id",
                annotation=str,
                name="User ID",
                description="Unique identifier for the user.",
                is_shared=True,
            ),
            ConfigurableFieldSpec(
                id="session_id",
                annotation=str,
                name="Session ID",
                description="Unique identifier for the session.",
                is_shared=True,
            ),
        ],
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.