
@inchoate
Last active August 10, 2024 12:32
Llama Index Filtered Query Engine Example

Selectively Querying Documents in Llama Index

Note: This gist has been updated to be far simpler than the original implementation, focusing on a more streamlined approach to selectively querying documents based on metadata.

Introduction

When working with Llama Index and other Retrieval-Augmented Generation (RAG) systems, most tutorials focus on ingesting and querying a single document. You typically read the document from a source, parse it, embed it, and store it in your vector store. Once there, querying is straightforward. But what if you have multiple documents and want to selectively query only one, such as Document #2 (doc_id=2), from your vector store?

This article demonstrates how to encapsulate the creation of a filtered query engine, which allows you to specify the nodes to query based on custom metadata. This approach provides a more structured and efficient way to retrieve relevant information, making it easier to manage and scale your querying process.

MetadataFilters: Customizing Your Query Scope

To enable selective querying, you need to associate metadata with each node during ingestion. This metadata can include any relevant information, such as document IDs, categories, or other attributes. By attaching metadata, you can later use MetadataFilters to narrow down the scope of your query. Attaching metadata in Llama Index is quite easy. When using the reader, just pass in the additional metadata fields:

# Read all HTML files in the directory using Unstructured
# To each file, add the company_id we fetched from the DB for this document.
documents = SimpleDirectoryReader(
    input_dir=source_directory,
    file_extractor={".html": UnstructuredReader()},
    file_metadata=lambda x: {"company_id": int(company_id)},
    required_exts=[".html"],
    recursive=True,
).load_data()
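Because `file_metadata` is called once per file with that file's path, you can also derive metadata from the path itself instead of closing over a single `company_id`. A minimal sketch, assuming a hypothetical `source/<company_id>/page.html` directory layout (the layout and function name are illustrative, not part of the gist):

```python
from pathlib import Path


def file_metadata(file_path: str) -> dict:
    # Hypothetical: read the company_id from the parent directory name,
    # so each file gets per-file metadata rather than one shared value.
    company_id = int(Path(file_path).parent.name)
    return {"company_id": company_id}


print(file_metadata("source/42/index.html"))  # {'company_id': 42}
```

Any callable with this signature can be passed as `file_metadata=file_metadata` to `SimpleDirectoryReader`; the returned dict is attached to every node parsed from that file.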

The goal here is to create a query engine that can dynamically filter nodes at query time based on metadata, rather than predefining the filter criteria when setting up the query engine. This approach offers greater flexibility and allows you to tailor the query to specific needs without rebuilding the query engine.

Example: Querying by company_id

Imagine you have a collection of webpages from various companies, each identified by a unique company_id. You want to query these documents to find out, for instance, "What products does {company_id} sell?" while setting the company_id at query time.

Updated Code for Creating a Filtered Query Engine

Here's the new, simplified code that encapsulates the logic for creating a filtered query engine:

import logging
import os
from urllib.parse import urlparse

from llama_index.core import VectorStoreIndex
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.settings import Settings
from llama_index.core.vector_stores import FilterOperator, MetadataFilter, MetadataFilters
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.postgres import PGVectorStore


def init_vector_store():
    """
    Creates a vector store from the postgres connection string.
    """
    original_conn_string = os.getenv("PG_CONNECTION_STRING")
    if not original_conn_string:
        raise ValueError("PG_CONNECTION_STRING environment variable is not set.")

    conn_string = original_conn_string.replace(
        urlparse(original_conn_string).scheme + "://", "postgresql+psycopg2://"
    )
    async_conn_string = original_conn_string.replace(
        urlparse(original_conn_string).scheme + "://", "postgresql+asyncpg://"
    )
    vector_store = PGVectorStore(
        connection_string=conn_string,
        async_connection_string=async_conn_string,
        schema_name=os.getenv("PGVECTOR_SCHEMA", "public"),
        table_name=os.getenv("PGVECTOR_TABLE", "company_embeddings"),
        embed_dim=int(os.getenv("EMBEDDING_DIM", 1024)),
    )
    logging.debug("Initialized vector store.")
    return vector_store


def init_embed_model(use_globally=True) -> BaseEmbedding:
    """
    Creates an embedding model from the HuggingFace library.
    """
    model_name = os.getenv("EMBEDDING_MODEL")
    if not model_name:
        raise ValueError("EMBEDDING_MODEL environment variable is not set.")
    embed_model = HuggingFaceEmbedding(model_name=model_name, trust_remote_code=True)

    if use_globally:
        Settings.embed_model = embed_model

    logging.info(f"Initialized embed model: {model_name}")
    return embed_model


def create_filtered_query_engine(company_id, vector_store, embed_model, top_k=10):
    """
    Creates a query engine that filters results based on a company_id.
    """
    filters = MetadataFilters(
        filters=[
            MetadataFilter(
                key="company_id",
                value=company_id,
                operator=FilterOperator.EQ,
            )
        ]
    )
    index = VectorStoreIndex.from_vector_store(vector_store=vector_store, embed_model=embed_model)
    vector_retriever = index.as_retriever(similarity_top_k=top_k, filters=filters)
    query_engine = RetrieverQueryEngine(retriever=vector_retriever)
    return query_engine


vector_store = init_vector_store()
embed_model = init_embed_model()

company_id = 65
query_engine = create_filtered_query_engine(company_id, vector_store, embed_model, top_k=10)
response = query_engine.query("What is the company's name?")
print(response.response)

company_id = 114
query_engine = create_filtered_query_engine(company_id, vector_store, embed_model, top_k=10)
response = query_engine.query("What is the company's name?")
print(response.response)
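One detail worth isolating from `init_vector_store`: `PGVectorStore` takes both a synchronous (psycopg2) and an asynchronous (asyncpg) SQLAlchemy URL, and the gist builds both by swapping the scheme of a single source connection string. A self-contained sketch of that rewriting (the helper name and example URL are illustrative only):

```python
from urllib.parse import urlparse


def rewrite_scheme(conn_string: str, new_scheme: str) -> str:
    # Replace only the URL scheme (e.g. "postgres" or "postgresql") with a
    # SQLAlchemy dialect+driver scheme, leaving user, host, and db untouched.
    scheme = urlparse(conn_string).scheme
    return conn_string.replace(scheme + "://", new_scheme + "://", 1)


src = "postgres://user:pw@localhost:5432/mydb"  # example value, not from the gist
print(rewrite_scheme(src, "postgresql+psycopg2"))
# postgresql+psycopg2://user:pw@localhost:5432/mydb
print(rewrite_scheme(src, "postgresql+asyncpg"))
# postgresql+asyncpg://user:pw@localhost:5432/mydb
```

Using `urlparse` to find the scheme (instead of hard-coding `"postgres://"`) makes the rewrite work whether the source string starts with `postgres://` or `postgresql://`.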

This approach simplifies the creation of a filtered query engine: the filter criteria are supplied at query time rather than baked in when the engine is set up, so the same vector store can serve any company without rebuilding the index.

By encapsulating the filtering logic in a single function, you also keep your codebase maintainable and easy to adapt as new metadata fields or filter conditions come along. Happy querying!
