Skip to content

Instantly share code, notes, and snippets.

@hayunjong83
Created January 11, 2024 15:52
Show Gist options
  • Save hayunjong83/1c800363fbbaa06c192d10bc051a6844 to your computer and use it in GitHub Desktop.
Save hayunjong83/1c800363fbbaa06c192d10bc051a6844 to your computer and use it in GitHub Desktop.
###################################################################################################
# #
# RAG(Retrieval Augmented Generation) example #
# <hayunjong83@gmail.com> #
# #
# Used: #
# Application : Streamlit #
# LLM Framework : LangChain #
# LLM Related : llama-2 7B, GPT4All(embedding model) #
# #
# Base code : #
# https://github.com/jmorganca/ollama/blob/main/examples/langchain-python-rag-document/README.md #
# #
###################################################################################################
import streamlit as st
import os
from langchain_community.llms import Ollama
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import GPT4AllEmbeddings
from langchain.vectorstores import Chroma
from langchain_community.embeddings.sentence_transformer import (
SentenceTransformerEmbeddings,
)
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.base import BaseCallbackHandler
# Callback handler that lets Streamlit show the LLM's answer incrementally (streaming).
# Reference: LangChain streaming in Streamlit — https://discuss.streamlit.io/t/langchain-stream/43782
class StreamHandler(BaseCallbackHandler):
    """Collect streamed LLM tokens and redraw the running text in a Streamlit container.

    Args:
        container: Object whose attribute named ``display_method`` is called with
            the accumulated text on every new token.
        initial_text: Text to start from before any token arrives.
        display_method: Name of the container method used for rendering
            (for example ``'markdown'`` or ``'write'``).
    """

    def __init__(self, container, initial_text="", display_method='markdown'):
        self.container = container
        self.text = initial_text
        self.display_method = display_method

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append *token* to the running text, then re-render the whole text.

        Raises:
            ValueError: if ``container`` has no attribute named ``display_method``
                (the text is still extended before the error is raised).
        """
        self.text = self.text + token
        render = getattr(self.container, self.display_method, None)
        if render is None:
            raise ValueError(f"Invalid display_method: {self.display_method}")
        render(self.text)
# Tab title shown in the web browser while the app runs.
# (st.set_page_config must be the first Streamlit command the script executes.)
st.set_page_config(page_title = "RAG example")
# Title displayed at the top of the Streamlit app.
st.title("Llama2 Chatbot using Ollama")
document_path = "./hanwha.pdf"
vector_store_path = "./hanwha_eagles"
# Set to True to force re-indexing of document_path even if a persisted store exists.
rerun = False


@st.cache_resource
def _load_vector_store(pdf_path, store_dir, rebuild):
    """Return a Chroma vector store over *pdf_path*, persisted in *store_dir*.

    Streamlit re-executes this whole script on every user interaction, so the
    result is cached with ``st.cache_resource``: the PDF is split and embedded
    (or the persisted store reopened) once per server process instead of on
    every chat turn.

    Args:
        pdf_path: Path of the PDF document to index.
        store_dir: Directory in which Chroma persists the index.
        rebuild: If True, re-index *pdf_path* even when *store_dir* already exists.

    Returns:
        A ``Chroma`` vector store using ``GPT4AllEmbeddings``.
    """
    # One embedding model instance is shared by whichever path is taken below.
    embeddings = GPT4AllEmbeddings()
    store_exists = os.path.isdir(store_dir)
    if store_exists and not rebuild:
        return Chroma(persist_directory = store_dir, embedding_function = embeddings)
    if store_exists:
        # Chroma.from_documents adds to an existing persisted collection, so drop
        # the old collection first; otherwise every rebuild duplicates all chunks.
        Chroma(persist_directory = store_dir, embedding_function = embeddings).delete_collection()
    data = PyPDFLoader(pdf_path).load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size = 500, chunk_overlap = 20)
    splits = text_splitter.split_documents(data)
    return Chroma.from_documents(documents = splits,
                                 embedding = embeddings,
                                 persist_directory = store_dir)


vector_store = _load_vector_store(document_path, vector_store_path, rerun)
# Prompt that makes the LLM build its answer from the retrieved search results
# ({context}) and admit when it does not know rather than invent an answer.
template = """
Use the following pieces of context to answer the question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:
"""

# Both placeholders used in `template` are declared as the prompt's inputs.
QA_CHAIN_PROMPT = PromptTemplate(template=template, input_variables=["context", "question"])
# Initialise the list that stores this user session's chat history
# (st.session_state survives Streamlit's script reruns).
if "messages" not in st.session_state:
    st.session_state.messages = []

# Redraw the conversation so far on every rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle a newly submitted question.
if prompt := st.chat_input("What do you want to know?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Show the user's question.
    with st.chat_message("user"):
        st.markdown(prompt)

    # Stream the assistant's answer into a placeholder as tokens arrive.
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        stream_handler = StreamHandler(message_placeholder, display_method='write')
        # `callbacks` replaces LangChain's deprecated `callback_manager` argument.
        llm = Ollama(
            model = "llama2",
            callbacks = [stream_handler],
        )
        qa_chain = RetrievalQA.from_chain_type(
            llm,
            retriever=vector_store.as_retriever(),
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
        )
        # Chain.__call__ is deprecated; `invoke` returns the same output dict.
        full_response = qa_chain.invoke({"query": prompt})
        answer = full_response["result"]
        # Render the final answer exactly as the history replay above will
        # (markdown), and make sure it is shown even if no tokens were streamed.
        message_placeholder.markdown(answer)
    st.session_state.messages.append({"role": "assistant", "content": answer})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment