-
-
Save mdouze/a8c914eb8c5c8306194ea1da48a577d2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "a443d7d3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import faiss\n", | |
"from faiss.contrib.datasets import SyntheticDataset" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "69d65bf8", | |
"metadata": {}, | |
"source": [ | |
"This notebook demonstrates a pure-numpy equivalent of Faiss's brute-force k-nearest-neighbor (knn) search. \n", | |
"For simple use cases, it may be easier to use than the Faiss knn functions." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "e7f7cebd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# generate some dataset\n", | |
"\n", | |
"ds = SyntheticDataset(32, 0, 1234, 2345)\n", | |
"xq = ds.get_queries()\n", | |
"xb = ds.get_database()\n", | |
"k = 13" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "a0186ee8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def np_knn(xq, xb, k): \n", | |
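" # knn function in numpy. It computes squared L2 distances, closely mimicking what \n", | |
" # is computed in Faiss, but without the tiling (will OOM with too large matrices).\n", | |
" # Note: k must be smaller than xb.shape[0], because np.argpartition is called with kth=k.\n", | |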
" norms_xq = (xq ** 2).sum(axis=1)\n", | |
" norms_xb = (xb ** 2).sum(axis=1)\n", | |
" distances = norms_xq.reshape(-1, 1) + norms_xb -2 * xq @ xb.T \n", | |
" I = np.argpartition(distances, k, axis=1)[:, :k]\n", | |
" D = np.take_along_axis(distances, I, axis=1)\n", | |
" # unfortunately argpartition does not sort the partition, so we need another \n", | |
" # round of sorting\n", | |
" o = np.argsort(D, axis=1)\n", | |
" return np.take_along_axis(D, o, axis=1), np.take_along_axis(I, o, axis=1)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"id": "73e1c60e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# reference result \n", | |
"Dref, Iref = faiss.knn(xq, xb, k)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "545f325a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# numpy version\n", | |
"Dnp, Inp = np_knn(xq, xb, k)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"id": "dd028a9f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.testing.assert_equal(Iref, Inp)\n", | |
"np.testing.assert_almost_equal(Dref, Dnp, decimal=5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "9cc36f6d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.22 ms ± 46.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%timeit faiss.knn(xq, xb, k)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"id": "3f8dd21a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"31.8 ms ± 337 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%timeit np_knn(xq, xb, k)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6ea7751e", | |
"metadata": {}, | |
"source": [ | |
"OK, the numpy version is about 26x slower here (31.8 ms vs. 1.22 ms), but Faiss is not always easy to install, so this may turn out to be useful in some cases." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0d4b82bb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment