Use Ollama with a local LLM to query websites
import argparse
import sys

import requests
from langchain.llms import Ollama
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import GPT4AllEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA


def main(args):
    """
    Run the retrieval QA chain.

    Args:
        args.website (str): URL of the website to retrieve data from.
        args.question (str): The question to ask about the website content.
        args.model (str): The model to be used by Ollama.
        args.base_url (str): The base URL of the Ollama server.
        args.chunk_size (int): The number of characters per document chunk.
        args.chunk_overlap (int): The number of characters to overlap between chunks.
        args.top_matches (int): The number of top matching document chunks to retrieve.
        args.system (str): The LLM system message.
        args.temp (float): The LLM temperature setting.

    Initializes Ollama with the specified model and base URL, loads data from
    the given website, splits the text for efficient processing, creates a
    vector store from the document splits, and finally performs retrieval-based
    question answering using the specified question.
    """
    check_server_availability(args.base_url)

    # Initialize the Ollama LLM client
    ollama = Ollama(
        base_url=args.base_url,
        model=args.model,
        system=args.system,
        temperature=args.temp,
        num_ctx=2048
    )

    # Load the website content and split it into overlapping chunks
    loader = WebBaseLoader(args.website)
    data = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap)
    all_splits = text_splitter.split_documents(data)

    # Embed the chunks into a Chroma vector store
    vectorstore = Chroma.from_documents(documents=all_splits, embedding=GPT4AllEmbeddings())

    # Answer the question using the top matching chunks as context
    qachain = RetrievalQA.from_chain_type(ollama, retriever=vectorstore.as_retriever(search_kwargs={"k": args.top_matches}))
    print("\n### model response:")
    print(qachain({"query": args.question}))


def check_server_availability(base_url):
    """
    Check if the Ollama server is running at the specified base URL.
    """
    try:
        response = requests.get(base_url)
        if response.status_code == 200:
            print(f"Successfully connected to the Ollama server at {base_url}.")
        else:
            print(f"Failed to connect to the Ollama server at {base_url}. Exiting.")
            sys.exit(1)
    except requests.ConnectionError:
        print(f"Could not connect to the Ollama server at {base_url}. Exiting.")
        sys.exit(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a retrieval QA chain on a specified website.")
    parser.add_argument("website", help="The website URL to retrieve data from")
    parser.add_argument("question", help="The question to ask about the website's content")
    parser.add_argument("--model", default="zephyr:latest", help="The model to use (default: zephyr:latest)")
    parser.add_argument("--base_url", default="http://localhost:11434", help="The base URL of the Ollama server (default: http://localhost:11434)")
    parser.add_argument("--chunk_size", type=int, default=200, help="The document character chunk size (default: 200)")
    parser.add_argument("--chunk_overlap", type=int, default=50, help="The amount of chunk overlap (default: 50)")
    parser.add_argument("--top_matches", type=int, default=4, help="The number of top matching document chunks to retrieve (default: 4)")
    parser.add_argument("--system", default="You are a helpful assistant.", help="The system message provided to the LLM")
    parser.add_argument("--temp", type=float, default=0.0, help="The model temperature setting (default: 0.0)")
    args = parser.parse_args()
    main(args)
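
A minimal usage sketch. The filename query_website.py is an assumption (save the script under any name), and the package list reflects the script's imports plus their usual companions: WebBaseLoader needs beautifulsoup4, Chroma needs chromadb, and GPT4AllEmbeddings needs gpt4all.

    pip install langchain requests beautifulsoup4 chromadb gpt4all
    python query_website.py https://example.com "What is this page about?" --model zephyr:latest --top_matches 4

With the defaults, this connects to a local Ollama server on port 11434, fetches the page, chunks it into 200-character pieces with 50 characters of overlap, and answers the question from the 4 best-matching chunks.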