Conversational RAG with LlamaIndex and Llama-2 (Colab notebook)
# Cell 1: Install dependencies
!pip install -q -U transformers llama-index accelerate pypdf einops bitsandbytes
!pip install -q llama-index-llms-huggingface
!pip install -q llama-index-embeddings-huggingface
# Cell 2: Import libraries and set up warnings
import warnings
warnings.filterwarnings('ignore')
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core.prompts import PromptTemplate
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.response.notebook_utils import display_source_node
import torch
from transformers import BitsAndBytesConfig
# Cell 3: Set up prompts and login to Hugging Face
system_prompt = """<|SYSTEM|>#
You are a helpful, respectful and honest assistant. Always consider the chat history when answering questions.
"""
query_wrapper_prompt = PromptTemplate("<|USER|>{query_str}<|ASSISTANT|>")
from huggingface_hub import login
login(token='your_huggingface_token_here')
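# Optional: preview how a user query gets wrapped before it reaches the model
# (a quick sanity check, not part of the original gist; the query string is just an example).
print(query_wrapper_prompt.format(query_str="What is in the documents?"))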
# Cell 4: Load documents
documents = SimpleDirectoryReader("/content/data").load_data()
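# Optional sanity check (not in the original gist): confirm the files in /content/data
# were actually picked up before building the index.
print(f"Loaded {len(documents)} document objects")
print(documents[0].text[:200])  # preview the first loaded document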
# Cell 5: Set up embedding model
embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-mpnet-base-v2")
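# Optional sanity check (assumption: all-mpnet-base-v2 produces 768-dimensional vectors):
# make sure the embedding model loads and returns a vector of the expected size.
sample_embedding = embed_model.get_text_embedding("Hello, world!")
print(len(sample_embedding))  # expected: 768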
# Cell 6: Set up LLM
llm = HuggingFaceLLM(
    context_window=4096,
    max_new_tokens=256,
    generate_kwargs={"do_sample": False},  # greedy decoding; temperature is ignored when do_sample=False
    system_prompt=system_prompt,
    query_wrapper_prompt=query_wrapper_prompt,
    tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
    model_name="meta-llama/Llama-2-7b-chat-hf",
    device_map="auto",
    tokenizer_kwargs={"max_length": 4096},
    model_kwargs={
        "torch_dtype": torch.float16,
        # 4-bit NF4 quantization via bitsandbytes so the 7B model fits in Colab GPU memory
        "quantization_config": BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            llm_int8_enable_fp32_cpu_offload=True,
        ),
    },
)
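# Optional smoke test (a minimal sketch, not part of the original notebook):
# run a single completion to confirm the quantized model loads and generates.
print(llm.complete("What is retrieval-augmented generation?"))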
# Cell 7: Configure LlamaIndex settings and build the index
# (ServiceContext is removed in recent llama-index releases; the global Settings object replaces it)
Settings.llm = llm
Settings.embed_model = embed_model
Settings.chunk_size = 2048
Settings.chunk_overlap = 50
index = VectorStoreIndex.from_documents(documents)
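# Optional retrieval check (a sketch; the query string is just an example):
# inspect which chunks the index returns before involving the LLM at all.
retriever = index.as_retriever(similarity_top_k=2)
for node in retriever.retrieve("What is this document about?"):
    display_source_node(node, source_length=200)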
# Cell 8: Set up chat memory and a context chat engine
# (a plain query engine is stateless; a chat engine is needed for the memory to actually be used)
memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
chat_engine = index.as_chat_engine(
    chat_mode="context",
    memory=memory,
    similarity_top_k=2,
)
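# Optional single-turn test (a sketch, not part of the original gist): verify the chat
# engine answers and that a follow-up question can draw on the conversation memory.
response = chat_engine.chat("Summarize the documents in two sentences.")
print(response)
followup = chat_engine.chat("What did I just ask you?")
print(followup)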
# Cell 9: Function to handle conversation
def chat_with_rag(query):
    # Stream the answer, then list the retrieved source chunks that grounded it
    response_stream = chat_engine.stream_chat(query)
    response_stream.print_response_stream()
    print("\nSources:")
    for node in response_stream.source_nodes:
        print(f"- {node.node.get_content()[:100]}...")
# Cell 10: Main conversation loop
print("Welcome to the Conversational RAG system. Type 'exit' to end the conversation.")
while True:
    user_input = input("You: ")
    if user_input.lower() == 'exit':
        print("Thank you for using the Conversational RAG system. Goodbye!")
        break
    chat_with_rag(user_input)