-
-
Save mdouze/08b3aaec6bf4a82994d89cb955ac1c57 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "b09d1c49", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import faiss\n", | |
"\n", | |
"\n", | |
"from faiss.contrib.datasets import SyntheticDataset\n", | |
"from faiss.contrib.inspect_tools import get_invlist" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "fb460716", | |
"metadata": {}, | |
"source": [ | |
"Make a synthetic dataset, construct an IVFPQ index. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "01d10f34", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ds = SyntheticDataset(32, 10000, 1000, 100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "2734735d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"index = faiss.index_factory(ds.d, \"IVF100,PQ4x8np\")\n", | |
"index.train(ds.get_train())\n", | |
"index.add(ds.get_database())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "34748b57", | |
"metadata": {}, | |
"source": [ | |
"Reference search results" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "2e261ac5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"index.nprobe = 4\n", | |
"Dref, Iref = index.search(ds.get_queries(), 10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "050ab9e2", | |
"metadata": {}, | |
"source": [ | |
"## Reproduce tables \n", | |
"\n", | |
"IVFPQ search is based on precomputed look-up tables. \n", | |
"This demonstrates how to compute them. Note that the c++ version optionally uses a slightly faster way of precomputing them, see https://github.com/facebookresearch/faiss/blob/main/faiss/IndexIVFPQ.cpp#L334" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "a845ab01", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# set some variables\n", | |
"xq = ds.get_queries()\n", | |
"nq, d = xq.shape\n", | |
"nprobe = index.nprobe" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "bc75737b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# coarse quantization\n", | |
"Dcoarse, Icoarse = index.quantizer.search(xq, nprobe)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "95bb9057", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(100, 4, 32)" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# compute residuals \n", | |
"residuals = xq[:, None, :] - index.quantizer.reconstruct_batch(Icoarse.ravel()).reshape(nq, nprobe, d)\n", | |
"residuals.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 55, | |
"id": "c74d8276", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# call compute_distance_tables on the residual tables \n", | |
"\n", | |
"pq = index.pq \n", | |
"dis_tab = np.zeros((nq, nprobe, pq.M, pq.ksub), dtype='float32')\n", | |
"dis_tab[:] = np.nan\n", | |
"pq.compute_distance_tables(\n", | |
" nq * nprobe, \n", | |
" faiss.swig_ptr(residuals), \n", | |
" faiss.swig_ptr(dis_tab) \n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "45d674a2", | |
"metadata": {}, | |
"source": [ | |
"## Search with precomputed table \n", | |
"\n", | |
"Pure Python implementation of search from look-up tables. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 49, | |
"id": "2558fa5c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# this is a schematic implementation of \n", | |
"\n", | |
"Dnew = []\n", | |
"Inew = []\n", | |
"K = 10 # number of results \n", | |
"for i in range(nq): \n", | |
" all_dis = [] # all distances for this query vector\n", | |
" all_ids = [] # all ids \n", | |
" for j in range(nprobe): \n", | |
" ids, codes = get_invlist(index.invlists, int(Icoarse[i, j]))\n", | |
" # codes is an array of size l by pq.M. If pq.nbits != 8 the encoding \n", | |
" # is a bit more complex, see \n", | |
" # https://github.com/facebookresearch/faiss/wiki/Python-C---code-snippets#how-can-i-get-access-to-non-8-bit-quantization-code-entries-in-pq--ivfpq--aq-\n", | |
" tab = dis_tab[i, j]\n", | |
" # distances for this inverted list\n", | |
" distances = np.sum([\n", | |
" tab[m, codes[:, m]]\n", | |
" for m in range(pq.M)\n", | |
" ], axis=0)\n", | |
" # collect results. In the C++ implementation the top-K results \n", | |
" # are maintained with a heap rather than stored completely\n", | |
" all_dis.append(distances)\n", | |
" all_ids.append(ids)\n", | |
" # get the top-K \n", | |
" all_dis = np.hstack(all_dis)\n", | |
" all_ids = np.hstack(all_ids)\n", | |
" order = np.argsort(all_dis)[:K]\n", | |
" Dnew.append(all_dis[order])\n", | |
" Inew.append(all_ids[order])\n", | |
" \n", | |
"Dnew = np.vstack(Dnew)\n", | |
"Inew = np.vstack(Inew)\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 56, | |
"id": "203379fa", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert (Inew == Iref).all()\n", | |
"np.testing.assert_allclose(Dref, Dnew, rtol=1e-5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "7362b79d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment