{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "13aacdfd", | |
"metadata": {}, | |
"source": [ | |
"This notebook demonstrates how to find the top-k _farthest_ points from a query vector $q$ in L2 distance.\n", | |
"\n", | |
"(For max inner product or cosine similarity, this is trivially achieved by searching with the query $-q$.)\n", | |
"\n", | |
"We use the decomposition \n", | |
"$$\n", | |
"\\begin{eqnarray}\n", | |
"\\mathrm{argmax}_{x \\in \\mathcal{X}} \\| q - x \\| & = & \n", | |
"\\mathrm{argmax}_{x \\in \\mathcal{X}} \\| x \\|^2 -2\\left< x, q \\right> \\\\\n", | |
"& = & \n", | |
"\\mathrm{argmax}_{x \\in \\mathcal{X}} \\left< x', q' \\right> \\\\\n", | |
"\\end{eqnarray}\n", | |
"$$\n", | |
"where \n", | |
"$$ \n", | |
"x' = [ -2x , \\|x\\|^2 ] \\textrm{ and } q' = [ q, 1]\n", | |
"$$\n", | |
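"The first equality holds because squaring the distance does not change the ranking and $\\|q\\|^2$ does not depend on $x$. Therefore, it is sufficient to add one dimension to the vectors, store the $x'$ in a max inner product index, and search it with the $q'$." | |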
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "ce25a656", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import faiss\n", | |
"\n", | |
"from faiss.contrib.datasets import SyntheticDataset" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "503b15e5", | |
"metadata": {}, | |
"source": [ | |
"## Generate data " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "a28a40b1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# a 32-dimensional dataset with 200 database vectors and 100 queries (no training vectors)\n", | |
"ds = SyntheticDataset(32, 0, 200, 100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "bd042b9c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"xq = ds.get_queries()\n", | |
"xb = ds.get_database()\n", | |
"d = ds.d # data dimensionality\n", | |
"k = 10 # number of farthest points per query" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "71d8e4ec", | |
"metadata": {}, | |
"source": [ | |
"## Ground truth" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"id": "5c13747f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# all query-to-database distances (squared L2 by default)\n", | |
"dis = faiss.pairwise_distances(xq, xb)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "c4807808", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# indices of the k largest distances per query, farthest first\n", | |
"Iref = dis.argsort(axis=1)[:, -1:-k-1:-1]\n", | |
"Dref = np.take_along_axis(dis, Iref, axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e67e186e", | |
"metadata": {}, | |
"source": [ | |
"## How to use a max inner product index" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"id": "bbdab1da", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def augment_queries(xq):\n", | |
"    # q' = [q, 1]: append a constant 1 to every query\n", | |
"    extra_column = np.ones((len(xq), 1), dtype=xq.dtype)\n", | |
"    return np.hstack((xq, extra_column))\n", | |
"\n", | |
"def augment_database(xb):\n", | |
"    # x' = [-2x, ||x||^2]: scale by -2 and append the squared norm\n", | |
"    norms2 = (xb ** 2).sum(1)\n", | |
"    return np.hstack((-2 * xb, norms2[:, None]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"id": "eba8e2e6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"index = faiss.IndexFlatIP(d + 1)\n", | |
"index.add(augment_database(xb))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 47, | |
"id": "0c6af0c7", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"Dnew, Inew = index.search(augment_queries(xq), k)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 48, | |
"id": "c44f77d7", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# correct the distances by adding back the squared query norm\n", | |
"norms2_xq = (xq ** 2).sum(1)\n", | |
"Dnew += norms2_xq[:, None]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 49, | |
"id": "c78d7ea8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.testing.assert_array_equal(Iref, Inew)\n", | |
"np.testing.assert_array_almost_equal(Dref, Dnew, decimal=4)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "f903c9cb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
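
The decomposition used in the notebook can also be checked without Faiss. The sketch below is not part of the notebook: it uses only the Python standard library, and the helpers `augment_query`, `augment_point`, `dot` and `sq_dist`, as well as the sizes `d`, `nb` and `k`, are illustrative choices. It ranks random points by squared L2 distance to one query, ranks them again by inner product between the augmented vectors, and checks that both orderings agree and that adding $\|q\|^2$ to the inner product gives back the squared distance.

```python
import random


def augment_query(q):
    # q' = [q, 1]
    return list(q) + [1.0]


def augment_point(x):
    # x' = [-2x, ||x||^2]
    return [-2.0 * v for v in x] + [sum(v * v for v in x)]


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


def sq_dist(a, b):
    return sum((u - v) ** 2 for u, v in zip(a, b))


rng = random.Random(0)
d, nb, k = 8, 50, 5  # illustrative sizes, much smaller than in the notebook
xb = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(nb)]
q = [rng.gauss(0, 1) for _ in range(d)]

# reference: the k points with the largest squared L2 distance, farthest first
ref = sorted(range(nb), key=lambda i: -sq_dist(q, xb[i]))[:k]

# same ranking via maximum inner product on the augmented vectors
q_aug = augment_query(q)
scores = [dot(augment_point(x), q_aug) for x in xb]
new = sorted(range(nb), key=lambda i: -scores[i])[:k]

assert ref == new

# adding ||q||^2 back to the inner product recovers the squared distance
q_norm2 = dot(q, q)
for i in new:
    assert abs(scores[i] + q_norm2 - sq_dist(q, xb[i])) < 1e-9
```

The brute-force sort stands in for `IndexFlatIP` here, so the check only covers the algebra, not the behaviour of any particular index type.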