Note: This gist has been updated to be far simpler than the original implementation, focusing on a more streamlined approach to selectively querying documents based on metadata.
When working with Llama Index and other Retrieval-Augmented Generation (RAG) systems, most tutorials focus on ingesting and querying a single document. You typically read the document from a source, parse it, embed it, and store it in your vector store. Once there, querying is straightforward. But what if you have multiple documents and want to selectively query only one, such as Document #2 (doc_id=2), from your vector store?
This article demonstrates how to encapsulate the creation of a filtered query engine, which allows you to specify the nodes to query based on custom metadata. This approach provides a more structured and efficient way to retrieve relevant information, making it easier to manage and scale your querying process.
To enable selective querying, you need to associate metadata with each node during ingestion. This metadata can include any relevant information, such as document IDs, categories, or other attributes. By attaching metadata, you can later use MetadataFilters to narrow down the scope of your query. Attaching metadata in Llama Index is quite easy. When using the reader, just pass in the additional metadata fields:
from llama_index.core import SimpleDirectoryReader
from llama_index.readers.file import UnstructuredReader

# Read all HTML files in the directory using Unstructured.
# To each file, add the company_id we fetched from the DB for this document.
documents = SimpleDirectoryReader(
    input_dir=source_directory,
    file_extractor={".html": UnstructuredReader()},
    file_metadata=lambda x: {"company_id": int(company_id)},
    required_exts=[".html"],
    recursive=True,
).load_data()
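The file_metadata argument is simply a callable that maps a file path to a metadata dict; the reader attaches the returned dict to every document parsed from that file. The pure-Python sketch below illustrates the pattern in isolation (the Document class and load_with_metadata helper here are hypothetical stand-ins for illustration, not LlamaIndex's own types):

```python
from dataclasses import dataclass, field


@dataclass
class Document:
    """Hypothetical stand-in for a LlamaIndex document."""
    text: str
    metadata: dict = field(default_factory=dict)


def load_with_metadata(files: dict, file_metadata) -> list:
    """Mimic the reader: call file_metadata(path) once per file
    and attach the resulting dict to each document."""
    return [
        Document(text=text, metadata=file_metadata(path))
        for path, text in files.items()
    ]


# Tag every file with the same company_id, as in the reader snippet above.
docs = load_with_metadata(
    {"acme/index.html": "<p>ACME sells anvils</p>"},
    lambda path: {"company_id": 65},
)
print(docs[0].metadata)  # {'company_id': 65}
```

Because the callable receives the file path, you could also derive per-file metadata (e.g. parse an ID out of the directory name) instead of using a fixed value.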
The goal here is to create a query engine that can dynamically filter nodes at query time based on metadata, rather than predefining the filter criteria when setting up the query engine. This approach offers greater flexibility and allows you to tailor the query to specific needs without rebuilding the query engine.
Imagine you have a collection of webpages from various companies, each identified by a unique company_id. You want to query these documents to find out, for instance, "What products does {company_id} sell?" while setting the company_id at query time.
Here's the new, simplified code that encapsulates the logic for creating a filtered query engine:
import logging
import os
from urllib.parse import urlparse
from llama_index.core import VectorStoreIndex
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.settings import Settings
from llama_index.core.vector_stores import FilterOperator, MetadataFilter, MetadataFilters
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.postgres import PGVectorStore
def init_vector_store():
    """
    Creates a vector store from the Postgres connection string.
    """
    original_conn_string = os.getenv("PG_CONNECTION_STRING")
    if not original_conn_string:
        raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
    # Force the sync and async SQLAlchemy drivers, whatever the original scheme was.
    conn_string = original_conn_string.replace(
        urlparse(original_conn_string).scheme + "://", "postgresql+psycopg2://"
    )
    async_conn_string = original_conn_string.replace(
        urlparse(original_conn_string).scheme + "://", "postgresql+asyncpg://"
    )
    vector_store = PGVectorStore(
        connection_string=conn_string,
        async_connection_string=async_conn_string,
        schema_name=os.getenv("PGVECTOR_SCHEMA", "public"),
        table_name=os.getenv("PGVECTOR_TABLE", "company_embeddings"),
        embed_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
    )
    logging.debug("Initialized vector store.")
    return vector_store
def init_embed_model(use_globally=True) -> BaseEmbedding:
    """
    Creates an embedding model from the HuggingFace library.
    """
    config = {"model_name": os.getenv("EMBEDDING_MODEL")}
    embed_model = HuggingFaceEmbedding(**config, trust_remote_code=True)
    if use_globally:
        Settings.embed_model = embed_model
    logging.info(f"Initialized embed model: {config}")
    return embed_model
def create_filtered_query_engine(company_id, vector_store, embed_model, top_k=10):
    """
    Creates a query engine that filters results based on a company_id.
    """
    filters = MetadataFilters(
        filters=[
            MetadataFilter(
                key="company_id",
                value=company_id,
                operator=FilterOperator.EQ,
            )
        ]
    )
    index = VectorStoreIndex.from_vector_store(vector_store=vector_store, embed_model=embed_model)
    vector_retriever = index.as_retriever(similarity_top_k=top_k, filters=filters)
    query_engine = RetrieverQueryEngine(retriever=vector_retriever)
    return query_engine
vector_store = init_vector_store()
embed_model = init_embed_model()
company_id = 65
query_engine = create_filtered_query_engine(company_id, vector_store, embed_model, top_k=10)
response = query_engine.query("What is the company's name?")
print(response.response)
company_id = 114
query_engine = create_filtered_query_engine(company_id, vector_store, embed_model, top_k=10)
response = query_engine.query("What is the company's name?")
print(response.response)
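Conceptually, an EQ metadata filter just restricts retrieval to the nodes whose metadata matches before similarity ranking happens. Ignoring the vector-similarity step, the filtering behaves like this pure-Python sketch (filter_nodes and the plain-dict nodes are illustrative stand-ins, not LlamaIndex internals):

```python
def filter_nodes(nodes: list, key: str, value) -> list:
    """Equality filter over node metadata, analogous to FilterOperator.EQ."""
    return [n for n in nodes if n["metadata"].get(key) == value]


nodes = [
    {"text": "ACME sells anvils", "metadata": {"company_id": 65}},
    {"text": "Globex sells widgets", "metadata": {"company_id": 114}},
]

# Only company 65's nodes survive; the real retriever would then rank
# these survivors by vector similarity and return the top_k.
matches = filter_nodes(nodes, "company_id", 65)
print([n["text"] for n in matches])  # ['ACME sells anvils']
```

This is why the metadata attached at ingestion time matters: if a node was stored without a company_id, no equality filter will ever return it.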
This updated approach simplifies the creation of a filtered query engine, allowing you to dynamically choose which documents to query at query time rather than when the index is built. By encapsulating the filtering logic in a single function, you also make your codebase more maintainable and easier to adapt as new metadata fields or filter criteria come along. Happy querying!