Simple RAG for easily embedding documents and querying embeddings
"""
SIMPLE RAG!

This file provides a class, Collection, that makes it easy to add
retrieval-augmented generation (RAG) to an application.

There are so many overly complex RAG tools out there, such as LlamaIndex and
LangChain. Even Chroma can be overly complex for some use cases, and I've
run into issues (in Streamlit Share) where Chroma's dependency on SQLite
caused a conflict that I couldn't resolve. Argh!

So I created the class below to serve as a very simple, lightweight,
low-dependency way to add RAG to an application. I also wrote it to be a
fairly easy drop-in replacement for Chroma.

Its one major limitation is that it doesn't come with a database, so
everything is stored in memory. This is fine for small collections, but if
you have a large collection, or don't want to constantly re-embed the same
documents, you'll probably want a solution with a database, like Chroma.

The code below uses OpenAI embeddings. To use these, you'll need an OpenAI
API key set in the OPENAI_API_KEY environment variable. If you don't want to
use OpenAI embeddings, you can swap in any other embeddings. Just make sure
you use the same embedding function for both adding documents and querying
them.
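
For example, a subclass that overrides _embed_documents could embed with a
local model instead. This is just a sketch: it assumes the optional
sentence-transformers package and the "all-MiniLM-L6-v2" model, neither of
which this file requires:

from sentence_transformers import SentenceTransformer

class LocalCollection(Collection):
    def _embed_documents(self, documents: list[str]) -> np.ndarray | None:
        # Same signature as the OpenAI version below, but embeds locally.
        model = SentenceTransformer("all-MiniLM-L6-v2")
        return np.array(model.encode(documents))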

INSTALL:

pip install numpy scikit-learn openai

USAGE:

from simple_rag import Collection

collection = Collection()
collection.add(
    ids=["1", "2"],
    documents=["This is a document.", "This is another document."],
    metadatas=[{"url": "http://test.com"}, {"url": "http://test.com"}],
)
results = collection.query(query_texts=["Find a document"])
print(results)
>>> {'documents': [['This is a document.', 'This is another document.']],
>>> 'distances': [[0.1539412288910481, 0.17489997771983146]], 'metadatas':
>>> [[{'url': 'http://test.com'}, {'url': 'http://test.com'}]]}
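
You can also control how close a match has to be via min_distance, the
maximum cosine distance a document may have from the query and still be
returned (default 0.3). For example, to keep only closer matches:

results = collection.query(query_texts=["Find a document"], min_distance=0.2)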
"""
import os
import numpy as np
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity


class Collection:
    """A collection of documents with associated metadata and embeddings."""

    def __init__(self):
        self.documents = []
        self.ids = []
        self.metadatas = []
        self.embeddings = []

    def add(self, ids: list[str], documents: list[str], metadatas: list[dict]):
        """Adds documents to the collection."""
        self.ids.extend(ids)
        self.documents.extend(documents)
        self.metadatas.extend(metadatas)
        embeddings = self._embed_documents(documents)
        if embeddings is not None:
            self.embeddings.extend(embeddings)

    def query(self, query_texts: list[str], min_distance: float = 0.3) -> dict:
        """Queries the collection for documents similar to the query texts.

        Only documents whose cosine distance from the query is at most
        min_distance are returned.
        """
        query_embeddings = self._embed_documents(query_texts)
        results = {"documents": [], "distances": [], "metadatas": []}
        if query_embeddings is not None:
            for query_embedding in query_embeddings:
                # Cosine distance = 1 - cosine similarity, so smaller is closer.
                distances = (
                    1
                    - cosine_similarity([query_embedding], self.embeddings)[0]
                )
                relevant_indices = np.where(distances <= min_distance)[0]
                results["documents"].append(
                    [self.documents[i] for i in relevant_indices]
                )
                results["distances"].append(
                    [distances[i] for i in relevant_indices]
                )
                results["metadatas"].append(
                    [self.metadatas[i] for i in relevant_indices]
                )
        return results

    def _embed_documents(self, documents: list[str]) -> np.ndarray | None:
        """Embeds documents. If you want to use something other than OpenAI
        embeddings, you can change up this function."""
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        model = "text-embedding-ada-002"
        max_batch_size = 2048  # Max batch size for OpenAI embeddings
        embeddings = []
        for i in range(0, len(documents), max_batch_size):
            batch = documents[i : i + max_batch_size]  # noqa
            try:
                response = client.embeddings.create(model=model, input=batch)
                batch_embeddings = [data.embedding for data in response.data]
                embeddings.extend(batch_embeddings)
            except Exception as e:
                print(f"Error embedding documents: {e}")
                # Use None placeholders so positions stay aligned with the
                # documents that were passed in.
                embeddings.extend([None] * len(batch))
        if embeddings:
            return np.array(embeddings)
        else:
            return None
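

if __name__ == "__main__":
    # Minimal smoke test mirroring the USAGE example in the module docstring.
    # It calls the OpenAI embeddings API, so OPENAI_API_KEY must be set and
    # the run will incur a (tiny) API cost.
    collection = Collection()
    collection.add(
        ids=["1", "2"],
        documents=["This is a document.", "This is another document."],
        metadatas=[{"url": "http://test.com"}, {"url": "http://test.com"}],
    )
    print(collection.query(query_texts=["Find a document"]))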