Created
February 12, 2024 14:17
-
-
Save TheMcSebi/f8bd7e1de6ae4464407f573ecb008fbe to your computer and use it in GitHub Desktop.
Perform a RAGatouille search and postprocess the results using an LLM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ragatouille import RAGPretrainedModel | |
import json, cmd, ollama # pip install ollama | |
# first, install ollama: | |
# curl -fsSL https://ollama.com/install.sh | sh | |
# second, download llama2:chat: | |
# ollama run llama2:chat | |
# Name of the Ollama model used for every LLM request in this file
llm = "llama2:chat"
# Passed to Ollama as the "num_ctx" option on each chat request
llm_context_size = 8192
# Number of documents retrieved per search unless the caller overrides it
DEFAULT_SOURCES_COUNT = 3
# Alternative index, kept for quick switching:
# index_src = ".ragatouille/colbert/indexes/obsidian_index"
index_src = ".ragatouille/colbert/indexes/wiki_index"
# Loaded once at import time and shared by all search requests
RAG = RAGPretrainedModel.from_index(index_src)
def perform_ollama_request(inp: str) -> str:
    """
    Sends a single user message to the configured LLM and returns its answer

    The reply is requested as a stream (temperature 0.2, context size taken
    from llm_context_size); the streamed pieces are collected and returned as
    one string with surrounding whitespace removed.
    """
    conversation = [{"role": "user", "content": inp}]
    request_options = {"temperature": 0.2, "num_ctx": llm_context_size}
    stream = ollama.chat(llm, messages=conversation, stream=True, options=request_options)
    # each streamed chunk carries the next piece of the answer text
    pieces = (part["message"]["content"] for part in stream)
    return "".join(pieces).strip()
def perform_search_request(query: str, sources_count: int = DEFAULT_SOURCES_COUNT) -> list[dict]:
    """
    Performs a search in the document database

    query: the search text
    sources_count: maximum number of results to request from the index

    Returns the result list produced by RAG.search unchanged (not a string);
    callers in this file read the "content" and "document_metadata" entries
    of each result.
    """
    return RAG.search(query, k=sources_count)
def perform_full_request(query: str) -> str:
    """
    Performs a query enhanced with documents retrieved from the database and runs it through a LLM for analysis

    query: the question to answer

    Returns the LLM answer followed by a numbered list of the sources that
    were put into the prompt. The numbers in the source list match the
    numbers of the passages inside the prompt.
    """
    results = perform_search_request(query)

    # build prompt
    # NOTE(review): the instruction text says "not contained in the query" where
    # "in the sources" is probably meant - left untouched, confirm with the author.
    instructions = "Please answer ONLY the following question by taking information only from the information sources after the query. If the required information is not contained in the query, answer that you don't know the resolution. If the sources contain similar information, explain them in regards to the query.\n\nQUERY: " + query + "\n\nINFORMATION SOURCES: \n\n"
    passages = []
    source_lines = []
    for i, r in enumerate(results, start=1):
        passages.append(f"{i}. " + r["content"] + "\n\n")
        # Results may lack metadata (or a "path" entry in it); fall back to the
        # document id, then to a placeholder, instead of failing with a KeyError.
        metadata = r.get("document_metadata") or {}
        origin = metadata.get("path") or r.get("document_id") or "(unknown source)"
        source_lines.append(f"{i}. {origin}\n")
    prompt = (instructions + "".join(passages)).strip()

    # perform llm analysis
    answer = perform_ollama_request(prompt)

    # add sources
    return (answer + "\n\n" + "".join(source_lines)).strip()
class RAGQuery(cmd.Cmd):
    """
    Interactive command line for querying the document database

    Every entered line is handled completely inside precmd:
    - "ollama <text>" sends the text straight to the LLM
    - "rag <text>" prints the raw search results as JSON
    - anything else runs the search and lets the LLM answer from the results

    End-of-input (Ctrl-D or exhausted stdin) leaves the loop; blank lines are ignored.
    """

    def __init__(self, completekey: str = "tab", stdin = None, stdout = None) -> None:
        super().__init__(completekey, stdin, stdout)
        self.prompt = "RAG> "
        self.stdout.write("Use this command line to enter a search query. \n\nIf prepending the query with 'ollama ', the query will only be ran through the attached large language model. \nPrepending 'rag ' will skip postprocessing of the resulting documents via a LLM.\n\n")

    def precmd(self, line: str) -> str:
        """
        Processes one input line and writes the response to stdout

        Returns "" after a handled query so the regular command dispatch has
        nothing left to do, or "EOF" so that do_EOF can stop the loop.
        """
        # cmdloop turns end-of-input into the literal line "EOF". Treating it as
        # a query would loop forever, because stdin keeps reporting end-of-input.
        if line == "EOF":
            return line
        # nothing to search for - avoid a pointless search and LLM request
        if not line.strip():
            return ""

        if line.startswith("ollama "):
            response = perform_ollama_request(line[7:])
        elif line.startswith("rag "):
            documents = perform_search_request(line[4:], sources_count=10)
            response = json.dumps(documents, indent=2)
        else:
            response = perform_full_request(line)
        self.stdout.write(response + "\n")
        return ""

    def do_EOF(self, arg: str) -> bool:
        """Leave the command line on end-of-input (Ctrl-D)"""
        # finish the prompt line so the shell prompt starts on a fresh line
        self.stdout.write("\n")
        return True
if __name__ == "__main__":
    # Non-interactive example, kept for reference:
    # user_input = "What is the average hair length of cats?"
    # results = perform_search_request(user_input)
    # for r in results[:3]:
    #     print(r["document_metadata"]["path"])
    #     print(r["content"])
    # answer = perform_full_request(user_input)
    # print(answer)

    # start the interactive command line
    RAGQuery().cmdloop()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment