@mdouze
Created January 11, 2023 14:41
{
"cells": [
{
"cell_type": "markdown",
"id": "13aacdfd",
"metadata": {},
"source": [
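"This notebook demonstrates how to find the top-k _farthest_ points from a query vector $q$ in L2 distance.\n",
"\n",
"(For inner product or cosine similarity, this is trivially achieved by searching for the nearest neighbors of $-q$ with a max inner product index, since maximizing $\\left< x, -q \\right>$ minimizes $\\left< x, q \\right>$.)\n",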
"\n",
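"We use the following decomposition. The argmax is unchanged by squaring the distance and by dropping the term $\\|q\\|^2$, which is constant for a given query:\n",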
"$$\n",
"\\begin{eqnarray}\n",
"\\mathrm{argmax}_{x \\in \\mathcal{X}} \\| q - x \\| & = & \n",
"\\mathrm{argmax}_{x \\in \\mathcal{X}} \\| x \\|^2 -2\\left< x, q \\right> \\\\\n",
"& = & \n",
"\\mathrm{argmax}_{x \\in \\mathcal{X}} \\left< x', q' \\right>\n",
"\\end{eqnarray}\n",
"$$\n",
"where \n",
"$$ \n",
"x' = [ -2x , \\|x\\|^2 ] \\textrm{ and } q' = [ q, 1]\n",
"$$\n",
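"Therefore, it suffices to store the augmented vectors $x'$ (one extra dimension) in a max inner product index and to search it with the augmented queries $q'$. The returned inner products differ from the squared L2 distances only by the constant $\\|q\\|^2$, which can be added back afterwards."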
]
},
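{
"cell_type": "markdown",
"id": "identity-check-md",
"metadata": {},
"source": [
"A quick numerical sanity check of the identity above on small random vectors (the dimension, number of points and seed are arbitrary choices). Adding the constant $\\|q\\|^2$ to $\\left< x', q' \\right>$ should recover $\\| q - x \\|^2$ up to floating-point error, so the farthest point is the one with the largest augmented inner product."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "identity-check",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(123)\n",
"q = rng.standard_normal(8)\n",
"X = rng.standard_normal((50, 8))  # 50 points in 8 dimensions\n",
"\n",
"# augmented vectors x' = [-2x, ||x||^2] and q' = [q, 1]\n",
"X_aug = np.hstack((-2 * X, (X ** 2).sum(1, keepdims=True)))\n",
"q_aug = np.append(q, 1.0)\n",
"\n",
"# ||q - x||^2 = ||q||^2 + <x', q'> for every point\n",
"sq_dist = ((X - q) ** 2).sum(1)\n",
"assert np.allclose(sq_dist, q @ q + X_aug @ q_aug)\n",
"\n",
"# hence the farthest point maximizes <x', q'>\n",
"assert np.argmax(sq_dist) == np.argmax(X_aug @ q_aug)"
]
},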
{
"cell_type": "code",
"execution_count": 3,
"id": "ce25a656",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import faiss\n",
"\n",
"from faiss.contrib.datasets import SyntheticDataset"
]
},
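{
"cell_type": "markdown",
"id": "ip-farthest-md",
"metadata": {},
"source": [
"For inner product, the farthest points need no augmentation: the points with the smallest $\\left< x, q \\right>$ are those with the largest $\\left< x, -q \\right>$, so it is enough to search $-q$ in an `IndexFlatIP`. A minimal sketch on random data (sizes and seed are arbitrary). For cosine similarity the same applies once the vectors are L2-normalized, e.g. with `faiss.normalize_L2`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ip-farthest",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import faiss\n",
"\n",
"rng = np.random.default_rng(0)\n",
"xb_ip = rng.standard_normal((200, 32)).astype('float32')\n",
"xq_ip = rng.standard_normal((5, 32)).astype('float32')\n",
"\n",
"index_ip = faiss.IndexFlatIP(32)\n",
"index_ip.add(xb_ip)\n",
"\n",
"# largest <x, -q>  ==  smallest <x, q>\n",
"D_ip, I_ip = index_ip.search(-xq_ip, 10)\n",
"\n",
"# brute-force reference: the 10 smallest inner products, in increasing order\n",
"ip = xq_ip @ xb_ip.T\n",
"np.testing.assert_allclose(-D_ip, np.sort(ip, axis=1)[:, :10], rtol=1e-4, atol=1e-4)"
]
},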
{
"cell_type": "markdown",
"id": "503b15e5",
"metadata": {},
"source": [
"## Generate data "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a28a40b1",
"metadata": {},
"outputs": [],
"source": [
"# a 32-dimensional dataset with no training vectors, 200 database vectors and 100 queries\n",
"ds = SyntheticDataset(32, 0, 200, 100)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "bd042b9c",
"metadata": {},
"outputs": [],
"source": [
"xq = ds.get_queries()\n",
"xb = ds.get_database()\n",
"d = ds.d # data dimensionality\n",
"k = 10 # number of farthest points per query"
]
},
{
"cell_type": "markdown",
"id": "71d8e4ec",
"metadata": {},
"source": [
"## Ground truth"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "5c13747f",
"metadata": {},
"outputs": [],
"source": [
"# squared L2 distances between every query and every database vector, shape (nq, nb)\n",
"dis = faiss.pairwise_distances(xq, xb)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "c4807808",
"metadata": {},
"outputs": [],
"source": [
"# indices of the k largest distances per query, farthest first\n",
"Iref = dis.argsort(axis=1)[:, -1:-k-1:-1]\n",
"Dref = np.take_along_axis(dis, Iref, axis=1)"
]
},
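{
"cell_type": "markdown",
"id": "dis-check-md",
"metadata": {},
"source": [
"A brute-force NumPy check of the distance matrix, reusing `xq`, `xb` and `dis` from the cells above. Because `faiss.pairwise_distances` returns squared L2 distances, `Dref` also holds squared distances; this does not change which points are farthest."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dis-check",
"metadata": {},
"outputs": [],
"source": [
"# squared L2 distances via broadcasting, shape (nq, nb)\n",
"dis_np = ((xq[:, None, :] - xb[None, :, :]) ** 2).sum(axis=2)\n",
"np.testing.assert_allclose(dis, dis_np, rtol=1e-4, atol=1e-4)"
]
},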
{
"cell_type": "markdown",
"id": "e67e186e",
"metadata": {},
"source": [
"## Farthest-point search with a max inner product index"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "bbdab1da",
"metadata": {},
"outputs": [],
"source": [
"def augment_queries(xq):\n",
"    # q' = [q, 1]: append a column of ones\n",
"    extra_column = np.ones((len(xq), 1), dtype=xq.dtype)\n",
"    return np.hstack((xq, extra_column))\n",
"\n",
"def augment_database(xb):\n",
"    # x' = [-2x, ||x||^2]: scale by -2 and append the squared norms\n",
"    norms2 = (xb ** 2).sum(1)\n",
"    return np.hstack((-2 * xb, norms2[:, None]))"
]
},
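{
"cell_type": "markdown",
"id": "augment-check-md",
"metadata": {},
"source": [
"A sanity check of the two helpers on the notebook's data: the augmented inner products plus the squared query norms should reproduce the squared distances `dis` computed in the ground-truth section."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "augment-check",
"metadata": {},
"outputs": [],
"source": [
"ip_aug = augment_queries(xq) @ augment_database(xb).T  # <x', q'>, shape (nq, nb)\n",
"dis_from_aug = ip_aug + (xq ** 2).sum(1)[:, None]\n",
"np.testing.assert_allclose(dis_from_aug, dis, rtol=1e-4, atol=1e-4)"
]
},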
{
"cell_type": "code",
"execution_count": 35,
"id": "eba8e2e6",
"metadata": {},
"outputs": [],
"source": [
"index = faiss.IndexFlatIP(d + 1)\n",
"index.add(augment_database(xb))"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "0c6af0c7",
"metadata": {},
"outputs": [],
"source": [
"Dnew, Inew = index.search(augment_queries(xq), k)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "c44f77d7",
"metadata": {},
"outputs": [],
"source": [
"# recover the squared L2 distances by adding back the squared query norms\n",
"norms2_xq = (xq ** 2).sum(1)\n",
"Dnew += norms2_xq[:, None]"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "c78d7ea8",
"metadata": {},
"outputs": [],
"source": [
"np.testing.assert_array_equal(Iref, Inew)\n",
"np.testing.assert_array_almost_equal(Dref, Dnew, decimal=4)"
]
},
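{
"cell_type": "markdown",
"id": "search-farthest-md",
"metadata": {},
"source": [
"The whole procedure can be wrapped into one function. `search_farthest` is a hypothetical helper name; this sketch assumes the exact `IndexFlatIP` used above and the `augment_queries` / `augment_database` functions, and returns squared L2 distances sorted farthest first, like `Dref`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "search-farthest",
"metadata": {},
"outputs": [],
"source": [
"def search_farthest(xb, xq, k):\n",
"    # index the augmented database vectors in a max inner product index\n",
"    index = faiss.IndexFlatIP(xb.shape[1] + 1)\n",
"    index.add(augment_database(xb))\n",
"    # largest <x', q'>  ==  farthest points in L2\n",
"    D, I = index.search(augment_queries(xq), k)\n",
"    # add back ||q||^2 to get squared L2 distances\n",
"    D += (xq ** 2).sum(1)[:, None]\n",
"    return D, I\n",
"\n",
"D2, I2 = search_farthest(xb, xq, k)\n",
"np.testing.assert_array_equal(I2, Iref)\n",
"np.testing.assert_array_almost_equal(D2, Dref, decimal=4)"
]
},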
{
"cell_type": "code",
"execution_count": null,
"id": "f903c9cb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}