RAG using LlamaIndex's DocumentSummaryIndexLLMRetriever and mistral-7b-instruct LLM
# WORKS!!!
import winsound
import torch
import time
import traceback
#####################################################################
# MISC CLASSES
#####################################################################
# Example Nodes as a knowledge base
from llama_index.schema import Node

test_nodes = [
    Node(text="The Earth revolves around the Sun."),
    Node(text="The Moon orbits the Earth."),
    Node(text="Gravity is the force of attraction between two bodies."),
    Node(text="Photosynthesis is the process by which green plants make their own food."),
    Node(text="Electricity can be generated from renewable sources such as wind or solar energy.")
    # ... you can add more nodes as needed
]
class NotificationType:
    WARNING = "C:\\Windows\\Media\\Windows Exclamation.wav"
    SUCCESS = "C:\\Windows\\Media\\Speech On.wav"

def play_notification_sound(notification_type):
    # Each notification type is the path to its .wav file, so play it directly
    winsound.PlaySound(notification_type, winsound.SND_FILENAME)
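# e.g. play_notification_sound(NotificationType.SUCCESS)
# (the .wav paths above assume a stock Windows install)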
import re
from itertools import zip_longest
from typing import Dict, List, Tuple
def parse_choice_select_answer_fn(
    answer: str, num_choices: int, raise_error: bool = False
) -> Tuple[List[int], List[float]]:
    """Parse choice select answer function."""
    answer_lines = answer.split("\n")
    answer_nums = []
    answer_relevances = []
    print("parsing lines: \n" + str(answer_lines))
    # Temporary storage for document numbers
    temp_doc_nums = []
    # Temporary storage for relevance scores
    temp_relevances = []
    for answer_line in answer_lines:
        # Check for document lines and extract the number
        doc_match = re.match(r'Document (\d+):', answer_line)
        if doc_match:
            temp_doc_nums.append(int(doc_match.group(1)))
        # Check for the relevance score line and split the scores
        relevance_match = re.match(r'Relevance score: (.+)', answer_line)
        if relevance_match:
            scores_str = relevance_match.group(1)
            temp_relevances = [float(score.strip()) for score in scores_str.split(',')]
    # Match documents with relevance scores; missing scores default to 0.0
    for doc_num, relevance in zip_longest(temp_doc_nums, temp_relevances, fillvalue=0.0):
        answer_nums.append(int(doc_num))
        answer_relevances.append(float(relevance))
    log_opt = ""
    if not answer_nums:
        if raise_error:
            raise ValueError("No valid answer numbers found.")
        # No choices parsed from the reply - fall back to the first three choices
        print("No valid answer numbers found. Adding top 3.")
        log_opt = "(top 3)"
        answer_nums = list(range(1, min(num_choices, 3) + 1))
        answer_relevances = [0.0] * len(answer_nums)
    # Replace any zero scores with descending placeholder values (10, 9, 8, ...)
    relevance_override = 10.0
    for i, score in enumerate(answer_relevances):
        if score == 0:
            answer_relevances[i] = relevance_override
            relevance_override -= 1.0
    print("\nanswer_nums " + log_opt + ": " + str(answer_nums))
    print("answer_relevances " + log_opt + ": " + str(answer_relevances) + "\n")
    return answer_nums, answer_relevances
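# Hypothetical example of the choice-select reply this parser expects:
#   Document 1:
#   Document 3:
#   Relevance score: 9.0, 7.5
# -> parse_choice_select_answer_fn(reply, num_choices=5) == ([1, 3], [9.0, 7.5])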
# NOTE: hypothetical wrapper class (these methods were free-floating in the
# original); it tags a Node with a name in its metadata. Unused below.
class NamedNode(Node):
    def __init__(self, name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._metadata = {'name': name}  # Initialize metadata with name

    @property
    def metadata(self):
        # Return the metadata dictionary
        return self._metadata

    def __getattr__(self, name: str):
        # Since metadata is now a regular attribute, defer everything else
        # to the base class lookup
        return super().__getattr__(name)
def list_all_doc_key_values(index_summary: 'DocumentSummaryIndex') -> List[Dict[str, str]]:
    """
    List all document key-value pairs from the provided DocumentSummaryIndex instance.
    Returns a list of dictionaries, each containing 'doc_id' and 'summary_id'.
    """
    all_key_values = []
    # The doc_id -> summary_id mapping lives on the underlying index struct
    for doc_id, summary_id in index_summary.index_struct.doc_id_to_summary_id.items():
        all_key_values.append({'doc_id': doc_id, 'summary_id': summary_id})
        print(f"doc_id: {doc_id}, summary_id: {summary_id}")
    return all_key_values
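# e.g. list_all_doc_key_values(doc_summary_index) prints and returns every
# doc_id -> summary_id pair once the index is built further below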
#####################################################################
# PROGRAM
#####################################################################

#####################
#### Fields and Definitions
#####################
model_path = "./models/mistral-7b-instruct-v0.1.Q4_K_M.gguf"
model_embeddings_path = "./sentence-transformers/all-mpnet-base-v2"
data_path = "D:/Git/Unseen/Assets/Code/_Core/Webcam/MediaPipe"  # "D:/Git/EscapeRoom3DGitLab/Assets/Scripts/Puzzle 8" # "C:/Users/richd/Desktop/test-rag" # "D:/Git/ebook-GPT-translator-refined"
config_llm_path = "./models/Mistral-7B-Instruct-v0.1/config.json"
use_gpt = False
use_doc_summary = False

SUMMARY_QUERY = (
    """You are a Unity developer. Write a class summary:
Class Definition: Name, base class, and interfaces (if applicable).
Class Role: Define the class's functionality. If applicable, also note any significant class references, dependents or dependencies.
Methods: List all method names, with attributes in square brackets if applicable.
Features: Take note of any special features such as Photon RPC calls"""
)
#####################
#### Load/init the local gguf LLM (via llama cpp)
#####################
from llama_index.llms import LlamaCPP
from llama_index.llms.llama_utils import (
    messages_to_prompt,
    completion_to_prompt,
)
print("\033[95m\nLoading Model...\n\033[0m") | |
if not use_gpt: | |
llm = LlamaCPP( | |
# You can pass in the URL to a GGML model to download it automatically | |
model_url=None, | |
# optionally, you can set the path to a pre-downloaded model instead of model_url | |
model_path=model_path, | |
temperature=0.1, | |
max_new_tokens=256, | |
# llama2 has a context window of 4096 tokens, but we set it lower to allow for some wiggle room | |
context_window=3900, | |
# kwargs to pass to __call__() | |
generate_kwargs={}, | |
# kwargs to pass to __init__() | |
# set to at least 1 to use GPU | |
model_kwargs={"n_gpu_layers": 50}, | |
# transform inputs into Llama2 format | |
messages_to_prompt=messages_to_prompt, | |
completion_to_prompt=completion_to_prompt, | |
verbose=True, | |
) | |
else: | |
import openai | |
import os | |
openai.api_key = os.environ["OPENAI_API_KEY"] | |
from llama_index.llms import OpenAI | |
llm = OpenAI(temperature=0.1, model="gpt-4") | |
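# Note: messages_to_prompt / completion_to_prompt emit Llama-2 style
# [INST] ... [/INST] prompts, a format Mistral-7B-Instruct also follows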
#####################
### Embeddings and service context
### FOR LLAMA CPP (GGUF COMPAT): https://gpt-index.readthedocs.io/en/latest/examples/llm/llama_2_llama_cpp.html
#####################
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

print("\033[95m\nEmbeddings...\n\033[0m")
embed_model = HuggingFaceEmbeddings(model_name=model_embeddings_path)
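# all-mpnet-base-v2 is a local sentence-transformers model (768-dim embeddings),
# so retrieval embeddings need no API key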
#####################
### Set things up for indexer
### REF: https://blog.llamaindex.ai/a-new-document-summary-index-for-llm-powered-qa-systems-9a32ece2f9ec
#####################
print("\033[95m\nIndexing...\n\033[0m")
from llama_index import (
    SimpleDirectoryReader,
    LLMPredictor,  # added for doc summary
    VectorStoreIndex,
    ServiceContext,
    set_global_service_context
)

# Defines the llm and embed models and the chunk size to retrieve
service_context = ServiceContext.from_defaults(
    chunk_size=2048,  # 1024
    llm=llm,
    embed_model=embed_model
)
set_global_service_context(service_context)  # Necessary? Maybe not
#####################
### Set up response synth - generates a response from the llm for a user query and a given set of text chunks
### REF: https://gpt-index.readthedocs.io/en/latest/module_guides/querying/response_synthesizers/root.html
#####################
from llama_index.response_synthesizers import ResponseMode, get_response_synthesizer

# For synthesizing summaries for the docs https://blog.llamaindex.ai/...
response_synthesizer = get_response_synthesizer(response_mode=ResponseMode.TREE_SUMMARIZE)
# EG. USAGE: response = response_synthesizer.synthesize("query text", nodes=[Node(text="text"), ...])

### Optional user verification of summaries
# test_query_llm_response = input("\nTest response synth on test data: ")
# if len(test_query_llm_response) > 0:
#     response = response_synthesizer.synthesize(
#         test_query_llm_response,
#         nodes=[Node(text="Electricity can be generated from renewable sources such as wind or solar energy")]
#     )
#####################
### Index the docs
#####################
from llama_index import VectorStoreIndex, SimpleDirectoryReader, DocumentSummaryIndex, StorageContext
from llama_index.indices.loading import load_index_from_storage

print("\033[95m\nGetting documents...\n\033[0m")
reader = SimpleDirectoryReader(data_path, recursive=True, exclude=['*.meta'])
documents = reader.load_data()
# print("Num docs: " + str(len(documents)))
doc_summary_index = None
try:
    storage_context = StorageContext.from_defaults(persist_dir="index")
    doc_summary_index = load_index_from_storage(storage_context)
    print("\033[95m\nLoaded index from storage...\n\033[0m")
except Exception:
    doc_summary_index = DocumentSummaryIndex.from_documents(
        documents,
        service_context=service_context,
        summary_query=SUMMARY_QUERY,
        response_synthesizer=response_synthesizer,  # will use the response synth to generate llm summaries of retrieved chunks
        show_progress=True
    )
    doc_summary_index.storage_context.persist("index")
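# On later runs, the try-branch above loads the persisted index from ./index
# instead of re-summarizing every document with the LLM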
### Optional user verification of summaries
doc_id = "_"
while len(doc_id) > 0:
    doc_id = input("\nEnter doc-ID to print doc summary: ")
    if not doc_id:
        break
    try:
        summary = doc_summary_index.get_document_summary(doc_id)
        print(summary)
    except Exception as e:
        print("GET DOC SUMMARY FAILED " + str(e))
#####################
### Set up docs retriever
### REF: https://gpt-index.readthedocs.io/en/latest/examples/index_structs/doc_summary/DocSummary.html
#####################
from llama_index.indices.document_summary import DocumentSummaryIndexLLMRetriever

retriever = DocumentSummaryIndexLLMRetriever(
    doc_summary_index,
    # choice_select_prompt=choice_select_prompt,
    # format_node_batch_fn=format_node_batch_fn,
    # choice_batch_size=10,
    # choice_top_k=5,  # 5 gave a great answer!
    parse_choice_select_answer_fn=parse_choice_select_answer_fn,
    service_context=service_context
)
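# The LLM retriever presents batches of document summaries to the LLM and asks
# it to choose the relevant ones; the custom parse_choice_select_answer_fn above
# turns that reply into (choice numbers, relevance scores)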
# The retriever will retrieve a set of relevant nodes for a given index.

# Optional user verification of retrieval matching
retrieval_match = "_"
while len(retrieval_match) > 0:
    retrieval_match = input("\nEnter string to test retrieval: ")
    if not retrieval_match:
        break
    try:
        retrieved_nodes = retriever.retrieve(retrieval_match)
        try:
            print("retrieved_nodes: " + str(len(retrieved_nodes)))
            print("score: " + str(retrieved_nodes[0].score))
            print("text: " + retrieved_nodes[0].node.get_text())
        except Exception as e:
            print("PRINTING RETRIEVED NODES FAILED: " + str(e))
    except Exception as e:
        print(f"An exception occurred: {e}")
        traceback.print_exc()
#####################
### Set up query engine...
### Response/summarization mode can include auto-iterative prompt refinement:
### refine, compact, tree_summarize, etc.
### REF (docs retrieval): https://gpt-index.readthedocs.io/en/latest/examples/index_structs/doc_summary/DocSummary.html
#####################
print("\033[95m\nQuery engine...\n\033[0m")
from llama_index.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine(
    retriever=retriever,
    response_synthesizer=response_synthesizer
)
# query_engine = doc_summary_index.as_query_engine(response_mode=ResponseMode.TREE_SUMMARIZE, use_async=True)
play_notification_sound(NotificationType.SUCCESS)
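# Query flow from here on: the retriever selects source nodes, then the
# tree-summarize response synthesizer composes the final answer from them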
#####################
### Prompt time!
#####################
while True:
    prompt = input("\nEnter prompt: ")
    if "list_ids" in str(prompt):
        print("\033[95m\nListing IDs...\n\033[0m")
        list_all_doc_key_values(doc_summary_index)
    elif "doc_id" in str(prompt):
        print("\033[95m\nQuerying docstore...\n\033[0m")
        doc_id = str(prompt).split("doc_id:")[1].split("\n")[0]
        doc_id = doc_id.strip()
        doc_summary = doc_summary_index.get_document_summary(doc_id)
        print(doc_summary)
    else:
        print("\033[95m\nGenerating output from prompt...\n\033[0m")
        response = query_engine.query(prompt)
        canStream = False
        try:
            # Streaming only works with a streaming-enabled synthesizer;
            # otherwise fall through to the plain print below
            response.print_response_stream()
            canStream = True
        except Exception:
            print("cannot stream")
        if not canStream:
            print("\nResponse: ")
            print(str(response))