-
-
Save bogdansolga/7658cc6b47c4b5fefc821f7d00c4336e to your computer and use it in GitHub Desktop.
Postgres Hybrid Search
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule PostgresHybridSearch do
  @moduledoc """
  Postgres Hybrid Search

  Combines a pgvector similarity search with a Postgres full-text search and merges
  the two rankings with Reciprocal Rank Fusion (RRF).

  Loosely based on:
  - https://github.com/pgvector/pgvector-python/blob/master/examples/hybrid_search_rrf.py
  - https://github.com/Azure-Samples/rag-postgres-openai-python/blob/e30ea96ca11ca6578ca38d3428594bd98d704900/src/fastapi_app/postgres_searcher.py#L2
  - https://supabase.com/docs/guides/ai/hybrid-search
  - https://github.com/toranb/rag-n-drop/blob/main/lib/demo/section.ex#L30
  """

  alias Ecto.Adapters.SQL
  alias Postgrex.Result

  @doc """
  Runs a hybrid (vector + full-text) search and returns the best-scoring rows.

  Setup:
  1. Create a tsvector column in your table ("generated" column recommended based on one or more text columns)
  2. Create a GIN index on the tsvector generated column
  3. Create an HNSW index on the vector embedding column
  4. Generate embeddings for your documents and store them in the vector column (e.g. using BERT, etc via Bumblebee or OpenAI's API)
  5. Generate embeddings for your search query and pass them as `query_embedding`

  Arguments:
  - `repo` - the Ecto repo the query is run against.
  - `schema` - the Ecto schema whose source table is searched; the table must have an `id` column.
  - `%{tsvector_column: ..., query_string: ...}` - the tsvector column and the raw search text
    (parsed with `websearch_to_tsquery('english', ...)`).
  - `%{vector_column: ..., query_embedding: ...}` - the embedding column and the query embedding.
    An empty embedding (`[]`) skips the vector search, so results are ranked by full-text only.
  - `filters` - `nil` or a list of `%{"column" => column, "op" => op, "value" => value}` maps.
    They are ANDed together and applied to both searches; values are sent as bound parameters.
  - `select_fields` - `nil` to return every column of the table, or a list of column names
    (`"id"` is always returned).
  - `match_count` - maximum number of results.
  - `full_text_weight` / `vector_weight` - multipliers applied to each search's RRF term.
  - `rrf_k` - the RRF constant `k` in `1 / (k + rank)`.

  Returns a list of maps with atom keys (`:score`, `:id` and the selected columns), highest
  score first. `:score` is a Postgres `numeric`, which Postgrex decodes as a `Decimal`.
  At most `min(match_count, 40)` rows are returned, since each search only contributes
  `LEAST(match_count, 20) * 2` candidates.

  Security: the table name, column names, filter columns/operators and `select_fields` are
  interpolated into the SQL verbatim. Only pass trusted, developer-controlled values for
  them. Only the query string, the embedding and filter values are bound parameters.

  Example Usage:
  ```elixir
  hybrid_search(
    YourApp.Repo,
    YourApp.Documents.Document,
    %{tsvector_column: "content_tsvector", query_string: "What is a cat?"},
    %{vector_column: "content_embedding", query_embedding: [0.1, 0.2, 0.3, ...]}
  )
  ```
  """
  @spec hybrid_search(
          module(),
          module(),
          map(),
          map(),
          [map()] | nil,
          [String.t()] | nil,
          integer(),
          float(),
          float(),
          integer()
        ) :: [map()]
  def hybrid_search(
        repo,
        schema,
        %{tsvector_column: tsvector_column, query_string: query_string},
        %{vector_column: vector_column, query_embedding: query_embedding},
        filters \\ nil,
        select_fields \\ nil,
        match_count \\ 10,
        full_text_weight \\ 1.0,
        vector_weight \\ 1.0,
        rrf_k \\ 50
      )
      when is_binary(query_string) and is_list(query_embedding) and is_integer(match_count) and
             is_float(full_text_weight) and is_float(vector_weight) and is_integer(rrf_k) do
    table_name = schema.__schema__(:source)

    # Bound parameters are positional: [embedding (only when present), query_string | filter values].
    # Placeholder numbers are derived from that layout so every $n has exactly one value.
    vector_params = if query_embedding == [], do: [], else: [query_embedding]
    text_index = length(vector_params) + 1
    {where_filter, filter_params} = build_filter_clause(filters, text_index + 1)

    vector_query =
      if vector_params == [],
        do: build_empty_vector_query(table_name),
        else: build_vector_query(table_name, vector_column, where_filter, match_count)

    full_text_query =
      build_full_text_query(
        table_name,
        tsvector_column,
        where_filter,
        match_count,
        "$#{text_index}"
      )

    sql =
      build_hybrid_query(
        table_name,
        vector_query,
        full_text_query,
        select_fields,
        match_count,
        full_text_weight,
        vector_weight,
        rrf_k
      )

    params = vector_params ++ [query_string | filter_params]

    # The SQL already orders by score DESC. The scores are Decimal structs, so they must not be
    # re-sorted with >=/2 (that compares struct fields, not numeric values).
    repo
    |> SQL.query!(sql, params)
    |> result_to_map()
    |> Enum.take(match_count)
  end

  # Nearest-neighbour candidates ranked by vector distance; the embedding is bound to $1.
  defp build_vector_query(table_name, vector_column, where_filter, match_count) do
    """
    SELECT id, RANK() OVER (ORDER BY #{vector_column} <=> $1) AS rank
    FROM #{table_name}
    #{maybe_where(where_filter)}
    ORDER BY #{vector_column} <=> $1
    LIMIT LEAST(#{match_count}, 20) * 2
    """
  end

  # Same column shape as build_vector_query/4 but yields no rows. It is used when no embedding
  # is given, so the hybrid SQL stays valid and only full-text ranks contribute.
  defp build_empty_vector_query(table_name) do
    "SELECT id, NULL::bigint AS rank FROM #{table_name} WHERE false"
  end

  # Full-text candidates ranked by ts_rank_cd; `text_param` is the placeholder holding the query string.
  defp build_full_text_query(table_name, tsvector_column, where_filter, match_count, text_param) do
    """
    SELECT id, RANK() OVER (ORDER BY ts_rank_cd(#{tsvector_column}, query) DESC) AS rank
    FROM #{table_name}, websearch_to_tsquery('english', #{text_param}) query
    WHERE #{tsvector_column} @@ query #{maybe_where_and(where_filter)}
    ORDER BY ts_rank_cd(#{tsvector_column}, query) DESC
    LIMIT LEAST(#{match_count}, 20) * 2
    """
  end

  # Merges both candidate sets with weighted RRF and joins back to the table for the output columns.
  defp build_hybrid_query(
         table_name,
         vector_query,
         full_text_query,
         select_fields,
         match_count,
         full_text_weight,
         vector_weight,
         rrf_k
       ) do
    score =
      "COALESCE(1.0 / (#{rrf_k} + vector_search.rank), 0.0) * #{vector_weight} + " <>
        "COALESCE(1.0 / (#{rrf_k} + fulltext_search.rank), 0.0) * #{full_text_weight} AS score"

    columns =
      Enum.join(
        [
          score,
          "COALESCE(vector_search.id, fulltext_search.id) AS id"
          | build_select_fields(table_name, select_fields)
        ],
        ",\n  "
      )

    # Join on the coalesced id: rows found only by the full-text search have a NULL vector_search.id.
    """
    WITH vector_search AS (
    #{vector_query}
    ),
    fulltext_search AS (
    #{full_text_query}
    )
    SELECT
      #{columns}
    FROM vector_search
    FULL OUTER JOIN fulltext_search ON vector_search.id = fulltext_search.id
    LEFT JOIN #{table_name} ON #{table_name}.id = COALESCE(vector_search.id, fulltext_search.id)
    ORDER BY score DESC
    LIMIT LEAST(#{match_count}, 20) * 2
    """
  end

  # Turns filter maps into an "a = $n AND b > $m" fragment plus the matching bound values,
  # numbering placeholders from `first_index`.
  defp build_filter_clause(nil, _first_index), do: {"", []}

  defp build_filter_clause(filters, first_index) when is_list(filters) do
    {clauses, params} =
      filters
      |> Enum.with_index(first_index)
      |> Enum.map(fn {%{"column" => column, "op" => op, "value" => value}, idx} ->
        {"#{column} #{op} $#{idx}", value}
      end)
      |> Enum.unzip()

    {Enum.join(clauses, " AND "), params}
  end

  defp maybe_where(""), do: ""
  defp maybe_where(filters) when is_binary(filters) and filters != "", do: "WHERE #{filters}"

  defp maybe_where_and(""), do: ""
  defp maybe_where_and(filters) when is_binary(filters) and filters != "", do: "AND #{filters}"

  # Output columns besides score/id. `table.*` (not `*`) keeps the CTEs' id/rank columns out of
  # the result, where duplicate names would overwrite each other in result_to_map/1.
  defp build_select_fields(table_name, nil), do: ["#{table_name}.*"]

  defp build_select_fields(_table_name, select_fields) when is_list(select_fields) do
    # id is already selected as the coalesced id
    Enum.reject(select_fields, &(&1 == "id"))
  end

  # Converts a Postgrex result into a list of maps keyed by column-name atoms.
  defp result_to_map(%Result{columns: columns, rows: rows}) do
    cols = Enum.map(columns, &String.to_atom/1)

    Enum.map(rows, fn row ->
      cols
      |> Enum.zip(row)
      |> Map.new()
    end)
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment