@tuanlda78202
Created June 11, 2024 03:09
Vietnamese RAG (Vistral + Multilingual E5) with Langchain and Streamlit UI
import os
import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
# LangChain Core imports
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
# LangChain vector store and embedding imports
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
# LangChain LLM imports
from langchain_community.llms import LlamaCpp
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# Environment setup (enables LangSmith tracing; expects LANGCHAIN_API_KEY in .env)
os.environ["LANGCHAIN_TRACING_V2"] = "true"
load_dotenv()
# Streamlit configuration
st.set_page_config(page_title="RAG for VN-LLMs", page_icon="🤗")
st.title("RAG for VN-LLMs 🔎")
# Initialize session state
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# * RAG Setup
# Load the multilingual E5 embedding model
embd = HuggingFaceEmbeddings(
    model_name="embaas/sentence-transformers-multilingual-e5-base",
    model_kwargs={"device": "cuda"},
)
# Load the Chroma index persisted on disk (built beforehand; see the ingestion sketch below)
vectorstore_disk = Chroma(
    persist_directory="chroma_db/",
    embedding_function=embd,
)
# Retrieve the top-5 most similar chunks for each query
retriever = vectorstore_disk.as_retriever(search_kwargs={"k": 5})
def format_docs(docs):
    # Join the retrieved documents into a single context string
    return "\n\n".join(doc.page_content for doc in docs)
# * LLM Setup
if "llm" not in st.session_state:
# * GGUF
model_name = "uonlp/Vistral-7B-Chat-gguf"
model_file = "ggml-vistral-7B-chat-q4_0.gguf"
model_path = hf_hub_download(model_name, filename=model_file)
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
st.session_state.llm = LlamaCpp(
model_path=model_path,
n_gpu_layers=-1,
n_batch=512,
n_ctx=2048,
f16_kv=True,
callback_manager=callback_manager,
verbose=False,
)
def get_extracted_answer(input_text):
    # Helper (currently unused): keep only the text after the answer marker
    answer_marker = "Câu trả lời:\n"  # "Câu trả lời" = "Answer"
    answer_start = input_text.find(answer_marker) + len(answer_marker)
    return input_text[answer_start:].strip()
def get_response(query, chat_history):
    # Note: chat_history is accepted but not yet injected into the prompt.
    # Vietnamese QA prompt; in English: "You are an assistant for question-answering
    # tasks. Use only the following retrieved information to answer the question.
    # If you do not know the answer, just say that you do not know.
    # Question: ... / Context: ... / Answer:"
    template = """
    Bạn là một trợ lý cho các nhiệm vụ hỏi-đáp.
    Chỉ sử dụng các thông tin đã truy xuất sau đây để trả lời câu hỏi.
    Nếu bạn không biết câu trả lời, chỉ cần nói rằng bạn không biết.
    Câu hỏi: {question} \n
    Ngữ cảnh: {context}\n
    Câu trả lời:
    """
    prompt = ChatPromptTemplate.from_template(template)
    chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | st.session_state.llm
        | StrOutputParser()
    )
    return chain.stream(query)
# * Conversation handling
for message in st.session_state.chat_history:
    if isinstance(message, HumanMessage):
        with st.chat_message("Human"):
            st.markdown(message.content)
    elif isinstance(message, AIMessage):
        with st.chat_message("AI"):
            st.markdown(message.content)
user_query = st.chat_input("Ask me anything")
# * User input processing
if user_query is not None and user_query != "":
st.session_state.chat_history.append(HumanMessage(user_query))
with st.chat_message("Human"):
st.markdown(user_query)
with st.chat_message("AI"):
ai_response = st.write_stream(
get_response(user_query, st.session_state.chat_history)
)
st.session_state.chat_history.append(AIMessage(ai_response))
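
The app assumes a Chroma index has already been persisted at chroma_db/. The original gist does not include that ingestion step, so the following is only a minimal sketch: the data/ folder, the *.txt glob, and the chunking parameters are illustrative assumptions, and the embedding model must match the one used at query time.

# build_db.py -- one-off ingestion sketch (illustrative, not part of the original gist)
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# Load plain-text documents from a local folder (path and glob are assumptions)
docs = DirectoryLoader("data/", glob="**/*.txt", loader_cls=TextLoader).load()

# Split into overlapping chunks small enough for the 2048-token context set above
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)

# Embed with the same multilingual E5 model the app queries with, then persist to disk
embd = HuggingFaceEmbeddings(
    model_name="embaas/sentence-transformers-multilingual-e5-base",
    model_kwargs={"device": "cuda"},
)
Chroma.from_documents(chunks, embd, persist_directory="chroma_db/")

With the index in place, the UI starts with "streamlit run app.py" (assuming the Streamlit script above is saved as app.py).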