Created
January 11, 2024 15:52
-
-
Save hayunjong83/1c800363fbbaa06c192d10bc051a6844 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
################################################################################################### | |
# # | |
# RAG(Retrieval Augmented Generation) example # | |
# <hayunjong83@gmail.com> # | |
# # | |
# Used: # | |
# Application : Streamlit # | |
# LLM Framework : LangChain # | |
# LLM Related : llama-2 7B, GPT4All(embedding model) # | |
# # | |
# Base code : # | |
# https://github.com/jmorganca/ollama/blob/main/examples/langchain-python-rag-document/README.md # | |
# # | |
################################################################################################### | |
import streamlit as st | |
import os | |
from langchain_community.llms import Ollama | |
from langchain.document_loaders import PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain import PromptTemplate | |
from langchain.chains import RetrievalQA | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.embeddings import GPT4AllEmbeddings | |
from langchain.vectorstores import Chroma | |
from langchain_community.embeddings.sentence_transformer import ( | |
SentenceTransformerEmbeddings, | |
) | |
from langchain.callbacks.manager import CallbackManager | |
from langchain.callbacks.base import BaseCallbackHandler | |
# Callback class used to display the LLM-generated answer in Streamlit as it streams.
# Reference: LangChain streaming discussion, https://discuss.streamlit.io/t/langchain-stream/43782
class StreamHandler(BaseCallbackHandler):
    """Accumulate streamed LLM tokens and redraw them in a display container.

    Args:
        container: Object exposing the rendering method named by
            ``display_method``; that method is called with the accumulated text.
        initial_text: Text the accumulated output starts with.
        display_method: Name of the method on ``container`` used to render
            the text (for example ``'markdown'`` or ``'write'``).
    """

    def __init__(self, container, initial_text="", display_method='markdown'):
        self.container = container
        self.text = initial_text
        self.display_method = display_method

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append ``token`` to the running text and re-render the whole text.

        Raises:
            ValueError: If ``container`` has no attribute named
                ``display_method``.
        """
        self.text += token
        render = getattr(self.container, self.display_method, None)
        if render is None:
            raise ValueError(f"Invalid display_method: {self.display_method}")
        # Redraw the full accumulated text, not just the new token.
        render(self.text)
# Title shown on the browser tab while the app is running.
st.set_page_config(page_title="RAG example")

# Heading displayed inside the Streamlit application.
st.title("Llama2 Chatbot using Ollama")
# PDF to index and the directory where the Chroma index is persisted.
document_path = "./hanwha.pdf"
vector_store_path = "./hanwha_eagles"
# Set to True to re-embed the PDF even when a persisted index already exists.
rerun = False


@st.cache_resource
def _load_vector_store(document_path, vector_store_path, rerun):
    """Build or reopen the Chroma vector store used for retrieval.

    Streamlit re-executes this script on every user interaction, so the
    result is cached with ``st.cache_resource``: the embedding model and the
    Chroma store are created once per server process (per distinct argument
    tuple) instead of once per chat message.

    Args:
        document_path: Path of the PDF to load and index.
        vector_store_path: Directory where the Chroma index is persisted.
        rerun: When true, index the PDF even if ``vector_store_path`` exists.

    Returns:
        A ``Chroma`` instance bound to ``vector_store_path``.
    """
    # One embedding object serves both the build and the reopen path.
    embedding = GPT4AllEmbeddings()

    if rerun or not os.path.isdir(vector_store_path):
        pages = PyPDFLoader(document_path).load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
        splits = text_splitter.split_documents(pages)
        return Chroma.from_documents(
            documents=splits,
            embedding=embedding,
            persist_directory=vector_store_path,
        )

    return Chroma(persist_directory=vector_store_path, embedding_function=embedding)


vector_store = _load_vector_store(document_path, vector_store_path, rerun)
# Prompt that tells the model to build its answer from the retrieved context.
template = """
Use the following pieces of context to answer the question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:
"""

# {context} and {question} are the two placeholders the template declares.
QA_CHAIN_PROMPT = PromptTemplate(
    template=template,
    input_variables=["context", "question"],
)
# Create the chat history list for this user session if it does not exist yet.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Render the chat history recorded so far.
for past in st.session_state["messages"]:
    st.chat_message(past["role"]).markdown(past["content"])
# Runs only when the user has submitted a new question.
if prompt := st.chat_input("What do you want to know?"):
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Show the user's question in the chat.
    with st.chat_message("user"):
        st.markdown(prompt)

    # Stream the assistant's answer into a placeholder as tokens arrive.
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        stream_handler = StreamHandler(message_placeholder, display_method='write')

        # `callbacks` replaces the deprecated `callback_manager` argument.
        llm = Ollama(
            model="llama2",
            callbacks=[stream_handler],
        )
        qa_chain = RetrievalQA.from_chain_type(
            llm,
            retriever=vector_store.as_retriever(),
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
        )
        # `invoke` replaces the deprecated direct call `qa_chain({...})`.
        full_response = qa_chain.invoke({"query": prompt})

    # Record the final answer so it is replayed on the next script run.
    st.session_state.messages.append({"role": "assistant", "content": full_response["result"]})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment