Vietnamese RAG (Vistral + Multilingual E5) with LangChain and a Streamlit UI
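This is a single Streamlit script: it embeds queries with the multilingual E5 model, retrieves context from a pre-built Chroma index, and generates answers with a quantized GGUF build of Vistral-7B-Chat running through llama.cpp, streaming the output into a chat UI. Presumably it is launched with streamlit run on the saved file (the gist does not name it) and assumes a CUDA-capable GPU, since the embeddings are placed on "cuda" and all llama.cpp layers are offloaded, plus a local .env file supplying the LangSmith API key.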
import os
import streamlit as st
from streamlit_chat import message
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download

# LangChain Core imports
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

# LangChain Chroma imports
from langchain_chroma import Chroma
from langchain.embeddings import HuggingFaceEmbeddings

# LangChain LLM imports
from langchain.llms import LlamaCpp
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Environment setup
os.environ["LANGCHAIN_TRACING_V2"] = "true"
load_dotenv()
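# Note: LANGCHAIN_TRACING_V2 enables LangSmith tracing for the chain below;
# it additionally requires a LANGCHAIN_API_KEY, which load_dotenv() is
# expected to pull in from a local .env file.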
# Streamlit configuration
st.set_page_config(page_title="RAG for VN-LLMs", page_icon="🤗")
st.title("RAG for VN-LLMs 🔎")

# Initialize session state
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# * RAG Setup
# Load the embedding model and the Chroma vector store persisted on disk
embd = HuggingFaceEmbeddings(
    model_name="embaas/sentence-transformers-multilingual-e5-base",
    model_kwargs={"device": "cuda"},
)
vectorstore_disk = Chroma(
    persist_directory="chroma_db/",
    embedding_function=embd,
)
retriever = vectorstore_disk.as_retriever(search_kwargs={"k": 5})


def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
# * LLM Setup
if "llm" not in st.session_state:
    # * GGUF
    model_name = "uonlp/Vistral-7B-Chat-gguf"
    model_file = "ggml-vistral-7B-chat-q4_0.gguf"
    model_path = hf_hub_download(model_name, filename=model_file)
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    st.session_state.llm = LlamaCpp(
        model_path=model_path,
        n_gpu_layers=-1,
        n_batch=512,
        n_ctx=2048,
        f16_kv=True,
        callback_manager=callback_manager,
        verbose=False,
    )
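# Note: n_gpu_layers=-1 asks llama.cpp to offload every layer to the GPU,
# n_batch is the prompt-processing batch size, n_ctx=2048 sets the context
# window, and f16_kv keeps the key/value cache in half precision. Caching the
# model in st.session_state avoids reloading the GGUF file on every Streamlit
# rerun; the stdout callback mirrors streamed tokens to the terminal.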
def get_extracted_answer(input_text):
    answer_marker = "Câu trả lời:\n"
    answer_start = input_text.find(answer_marker) + len(answer_marker)
    return input_text[answer_start:].strip()
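# Note: get_extracted_answer() is a helper for post-processing a full,
# non-streamed completion: it keeps only the text after the "Câu trả lời:"
# ("Answer:") marker. The streaming path below does not call it.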
def get_response(query, chat_history):
    # Vietnamese QA prompt. In English: "You are an assistant for
    # question-answering tasks. Use only the retrieved information below to
    # answer the question. If you do not know the answer, simply say that you
    # do not know. Question: ... Context: ... Answer:"
    template = """
    Bạn là một trợ lý cho các nhiệm vụ hỏi-đáp.
    Chỉ sử dụng các thông tin đã truy xuất sau đây để trả lời câu hỏi.
    Nếu bạn không biết câu trả lời, chỉ cần nói rằng bạn không biết.
    Câu hỏi: {question} \n
    Ngữ cảnh: {context}\n
    Câu trả lời:
    """
    prompt = ChatPromptTemplate.from_template(template)
    chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | st.session_state.llm
        | StrOutputParser()
    )
    return chain.stream(query)
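# Note: the chain above wires retrieval into the prompt via LCEL. The dict
# fills {context} with the formatted top-k documents and passes the raw query
# through to {question}; the prompt is then rendered, sent to the LlamaCpp
# model, and parsed to a string, and chain.stream() yields the tokens for
# st.write_stream(). The chat_history argument is accepted but not used in
# the chain itself.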
# * Conversation handling
for msg in st.session_state.chat_history:
    if isinstance(msg, HumanMessage):
        with st.chat_message("Human"):
            st.markdown(msg.content)
    elif isinstance(msg, AIMessage):
        with st.chat_message("AI"):
            st.markdown(msg.content)

user_query = st.chat_input("Ask me anything")

# * User input processing
if user_query is not None and user_query != "":
    st.session_state.chat_history.append(HumanMessage(user_query))
    with st.chat_message("Human"):
        st.markdown(user_query)
    with st.chat_message("AI"):
        ai_response = st.write_stream(
            get_response(user_query, st.session_state.chat_history)
        )
    st.session_state.chat_history.append(AIMessage(ai_response))
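The app assumes a Chroma index has already been persisted to chroma_db/; the ingestion step is not part of the gist. Below is a minimal sketch of how such an index could be built with the same embedding model, kept as a separate one-off script; the corpus path, loader, and chunking parameters are assumptions, not something the gist specifies.

# build_index.py (hypothetical) - run once, separately from the Streamlit app
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain.embeddings import HuggingFaceEmbeddings

# Load a raw Vietnamese corpus (the path is a placeholder)
docs = TextLoader("data/corpus.txt", encoding="utf-8").load()

# Split into overlapping chunks; sizes are character counts and illustrative only
chunks = RecursiveCharacterTextSplitter(
    chunk_size=512, chunk_overlap=64
).split_documents(docs)

# Embed with the same multilingual E5 model used by the app and persist to chroma_db/
embd = HuggingFaceEmbeddings(
    model_name="embaas/sentence-transformers-multilingual-e5-base",
    model_kwargs={"device": "cuda"},
)
Chroma.from_documents(chunks, embd, persist_directory="chroma_db/")

Using the identical embedding model for ingestion and querying matters: Chroma only stores vectors, so query-time embeddings from a different model would not be comparable to the indexed ones.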