Skip to content

Instantly share code, notes, and snippets.

@thbar
Forked from cpursley/postgres_hybrid_search.ex
Created June 8, 2024 12:41
Show Gist options
  • Save thbar/9342328e2e9ec65975d8af9b435ef9dc to your computer and use it in GitHub Desktop.
Save thbar/9342328e2e9ec65975d8af9b435ef9dc to your computer and use it in GitHub Desktop.
Postgres Hybrid Search
defmodule PostgresHybridSearch do
@moduledoc """
Postgres Hybrid Search
Loosely based on:
- https://github.com/pgvector/pgvector-python/blob/master/examples/hybrid_search_rrf.py
- https://github.com/Azure-Samples/rag-postgres-openai-python/blob/e30ea96ca11ca6578ca38d3428594bd98d704900/src/fastapi_app/postgres_searcher.py#L2
- https://supabase.com/docs/guides/ai/hybrid-search
- https://github.com/toranb/rag-n-drop/blob/main/lib/demo/section.ex#L30
"""
import Ecto.Query
alias Ecto.Adapters.SQL
alias Postgrex.Result
@doc """
Runs a hybrid (full-text + vector similarity) search on the table backing `schema`.

The vector ranking and the full-text ranking are computed separately and fused with
Reciprocal Rank Fusion: every row scores `weight / (rrf_k + rank)` for each ranking it
appears in, and the two contributions are summed into `score`.

Setup:
1. Create a tsvector column in your table ("generated" column recommended based on one or more text columns)
2. Create a GIN index on the tsvector generated column
3. Create a HNSW index on the vector embedding column
4. Generate embeddings for your documents and store them in the vector column (e.g. using BERT, etc via Bumblebee or OpenAI's API)
5. Generate embeddings for your search query and pass to query_embedding

Arguments:
- `repo` - repo handed to `Ecto.Adapters.SQL.query!/3`
- `schema` - module whose `__schema__(:source)` returns the table name
- `%{tsvector_column: ..., query_string: ...}` - tsvector column name and the search text
- `%{vector_column: ..., query_embedding: ...}` - vector column name and the query embedding (a list)
- `filters` - `nil`, or a list of `%{"column" => c, "op" => o, "value" => v}` maps that are AND-ed together
- `select_fields` - `nil` (selects `*`), or a list of column names
- `match_count` - maximum number of rows returned (default `10`); every SQL stage is
  limited to `LEAST(match_count, 20) * 2` rows, so at most 40 rows can ever come back
- `full_text_weight`, `vector_weight` - numeric multipliers for each ranking (default `1.0`)
- `rrf_k` - integer smoothing constant of the fusion formula (default `50`)

Returns a list of maps keyed by column-name atoms, in the order produced by the database.
Raises if the database query fails.

Table, column, operator and select-field names are interpolated into the SQL text and
must therefore come from trusted code; only filter values, the query string and the
embedding are sent as bound parameters.

Example Usage:
```elixir
hybrid_search(
YourApp.Repo,
YourApp.Documents.Document,
%{tsvector_column: "content_tsvector", query_string: "What is a cat?"},
%{vector_column: "content_embedding", query_embedding: [0.1, 0.2, 0.3, ...]}
)
```
"""
def hybrid_search(
      repo,
      schema,
      %{tsvector_column: tsvector_column, query_string: query_string},
      %{vector_column: vector_column, query_embedding: query_embedding},
      filters \\ nil,
      select_fields \\ nil,
      match_count \\ 10,
      full_text_weight \\ 1.0,
      vector_weight \\ 1.0,
      rrf_k \\ 50
    )
    when is_binary(query_string) and is_list(query_embedding) and is_integer(match_count) and
           is_number(full_text_weight) and is_number(vector_weight) and is_integer(rrf_k) do
  table_name = schema.__schema__(:source)
  {where_filter, filter_params} = build_filter_clause(filters)
  vector_query = build_vector_query(table_name, vector_column, where_filter, match_count)
  full_text_query = build_full_text_query(table_name, tsvector_column, where_filter, match_count)

  hybrid_query =
    build_hybrid_query(
      table_name,
      vector_query,
      full_text_query,
      select_fields,
      match_count,
      full_text_weight,
      vector_weight,
      rrf_k
    )

  {sql, params} =
    build_query_and_args(query_string, query_embedding, hybrid_query, vector_query, full_text_query, filter_params)

  # The SQL already returns rows in ranking order. Re-sorting here with `>=` would compare
  # the scores structurally: the score is a Postgres `numeric`, which reaches Elixir as a
  # struct rather than a float, and term ordering of structs does not follow numeric order.
  repo
  |> SQL.query!(sql, params)
  |> result_to_map()
  |> Enum.take(match_count)
end
# Builds the SQL that ranks rows by distance (`<=>`) between `vector_column` and the
# embedding bound to `$1`. `filters` is an already-rendered condition string ("" for none)
# and is emitted as a WHERE clause. The row limit is `LEAST(match_count, 20) * 2`.
defp build_vector_query(table_name, vector_column, filters, match_count) do
  distance = "#{vector_column} <=> $1"

  [
    "SELECT id, RANK() OVER (ORDER BY #{distance}) AS rank",
    "FROM #{table_name}",
    maybe_where(filters),
    "ORDER BY #{distance}",
    "LIMIT LEAST(#{match_count}, 20) * 2",
    # Trailing empty element keeps the final newline of the statement.
    ""
  ]
  |> Enum.join("\n")
end
# Builds the SQL that ranks rows by `ts_rank_cd` of `tsvector_column` against the text
# bound to `$2` (parsed with `websearch_to_tsquery('english', ...)`). `filters` is an
# already-rendered condition string ("" for none) appended with AND to the match
# condition. The row limit is `LEAST(match_count, 20) * 2`.
defp build_full_text_query(table_name, tsvector_column, filters, match_count) do
  ordering = "ts_rank_cd(#{tsvector_column}, query) DESC"
  extra_conditions = maybe_where_and(filters)

  "SELECT id, RANK() OVER (ORDER BY #{ordering}) AS rank\n" <>
    "FROM #{table_name}, websearch_to_tsquery('english', $2) query\n" <>
    "WHERE #{tsvector_column} @@ query #{extra_conditions}\n" <>
    "ORDER BY #{ordering}\n" <>
    "LIMIT LEAST(#{match_count}, 20) * 2\n"
end
# Wraps the vector and full-text sub-queries in CTEs and fuses their rankings.
#
# Each row's `score` is the sum of `1.0 / (rrf_k + rank)` from each ranking, multiplied
# by that ranking's weight; a ranking the row does not appear in contributes 0.
# The selected columns are `score`, `id`, then whatever `build_select_fields_query/1`
# returns for `select_fields`. Rows are ordered by `score DESC` and limited to
# `LEAST(match_count, 20) * 2`.
defp build_hybrid_query(
       table_name,
       vector_query,
       full_text_query,
       select_fields,
       match_count,
       full_text_weight,
       vector_weight,
       rrf_k
     ) do
  # An empty select list (e.g. `[]` or `["id"]`) must not leave a dangling comma
  # in front of FROM, so the separator is only emitted together with the columns.
  extra_columns =
    case build_select_fields_query(select_fields) do
      "" -> ""
      columns -> ",\n#{columns}"
    end

  # The table is joined on the coalesced id: with a FULL OUTER JOIN a row found only by
  # the full-text search has a NULL vector_search.id, and joining on that column alone
  # would return NULL for every table column of such rows.
  """
  WITH vector_search AS (
  #{vector_query}
  ),
  fulltext_search AS (
  #{full_text_query}
  )
  SELECT
  COALESCE(1.0 / (#{rrf_k} + vector_search.rank), 0.0) * #{vector_weight} + COALESCE(1.0 / (#{rrf_k} + fulltext_search.rank), 0.0) * #{full_text_weight} AS score,
  COALESCE(vector_search.id, fulltext_search.id) AS id#{extra_columns}
  FROM vector_search
  FULL OUTER JOIN fulltext_search ON vector_search.id = fulltext_search.id
  LEFT JOIN #{table_name} ON COALESCE(vector_search.id, fulltext_search.id) = #{table_name}.id
  ORDER BY score DESC
  LIMIT LEAST(#{match_count}, 20) * 2
  """
end
# Turns a list of filter maps into `{condition_sql, params}`.
#
# Each filter is `%{"column" => column, "op" => op, "value" => value}` and becomes
# `column op $n`; conditions are joined with " AND " and the values are returned in
# placeholder order. `nil` and `[]` both yield `{"", []}`.
#
# NOTE(security): `column` and `op` are interpolated into the SQL text as-is. Only the
# value is sent as a bound parameter, so column and operator must never come from
# untrusted input.
defp build_filter_clause(nil), do: {"", []}

defp build_filter_clause(filters) when is_list(filters) do
  filters
  # Start indexing from 3 because 1 and 2 placeholders are already used in vector and full-text queries
  |> Enum.with_index(3)
  # Both accumulators must start as lists: clauses are prepended and reversed later.
  |> Enum.reduce({[], []}, fn {%{"column" => column, "op" => op, "value" => value}, idx}, {clauses, params} ->
    {["#{column} #{op} $#{idx}" | clauses], [value | params]}
  end)
  |> format_clause()
end

# Restores insertion order of the accumulated clauses/params and joins the clauses.
defp format_clause({[], _params}), do: {"", []}

defp format_clause({clauses, params}) do
  {clauses |> Enum.reverse() |> Enum.join(" AND "), Enum.reverse(params)}
end
# Renders a condition string as a full WHERE clause; an empty string stays empty.
defp maybe_where(filters) when is_binary(filters) do
  case filters do
    "" -> ""
    conditions -> "WHERE " <> conditions
  end
end
# Renders a condition string as an "AND ..." continuation of an existing WHERE clause;
# an empty string stays empty.
defp maybe_where_and(filters) when is_binary(filters) do
  case filters do
    "" -> ""
    conditions -> "AND " <> conditions
  end
end
# Renders the extra select list: `nil` selects everything ("*"), a list is joined with
# ", " after dropping "id" (the query already selects id on its own).
defp build_select_fields_query(nil), do: "*"

defp build_select_fields_query(select_fields) when is_list(select_fields) do
  wanted = for field <- select_fields, field != "id", do: field
  Enum.join(wanted, ", ")
end
# Picks the SQL statement to run and the positional arguments that go with it:
#
# * query string and non-empty embedding -> hybrid query, `[embedding, query_string | filters]`
# * non-empty embedding only             -> vector query, `[embedding | filters]`
# * query string only                    -> full-text query, `[query_string | filters]`
# * neither                              -> raises ArgumentError
#
# The sub-queries are written for the hybrid layout ($1 embedding, $2 query string,
# $3.. filter values). When one of the two leading arguments is left out, the
# placeholders after it are renumbered so they line up with the shorter argument list.
# Rows returned by the single-mode queries only carry `id` and `rank`.
defp build_query_and_args(query_string, query_vector, hybrid_query, vector_query, full_text_query, filter_params) do
  has_vector? = is_list(query_vector) and query_vector != []
  has_text? = not is_nil(query_string)

  cond do
    has_text? and has_vector? ->
      {hybrid_query, [query_vector, query_string | filter_params]}

    has_vector? ->
      # $2 (query string) is absent: filter placeholders $3.. become $2..
      {shift_placeholders(vector_query, 3), [query_vector | filter_params]}

    has_text? ->
      # $1 (embedding) is absent: $2 becomes $1 and filter placeholders follow.
      {shift_placeholders(full_text_query, 2), [query_string | filter_params]}

    true ->
      raise ArgumentError, "Both search query_string and query_vector are empty"
  end
end

# Decrements every `$n` placeholder in `sql` whose index is at least `from`.
defp shift_placeholders(sql, from) do
  Regex.replace(~r/\$(\d+)/, sql, fn placeholder, digits ->
    index = String.to_integer(digits)
    if index >= from, do: "$#{index - 1}", else: placeholder
  end)
end
# Converts a query result into a list of maps, one per row, keyed by the column names
# as atoms. When two columns share a name the later column's value wins.
#
# NOTE(review): column names are turned into atoms with String.to_atom/1, and atoms are
# never garbage collected - keep the set of selectable column names bounded.
defp result_to_map(%Result{columns: columns, rows: rows}) do
  keys = for name <- columns, do: String.to_atom(name)

  for row <- rows do
    Map.new(Enum.zip(keys, row))
  end
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment