@mdouze
Created October 2, 2023 20:27
{
"cells": [
{
"cell_type": "code",
"execution_count": 25,
"id": "b09d1c49",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import faiss\n",
"\n",
"\n",
"from faiss.contrib.datasets import SyntheticDataset\n",
"from faiss.contrib.inspect_tools import get_invlist"
]
},
{
"cell_type": "markdown",
"id": "fb460716",
"metadata": {},
"source": [
"Make a synthetic dataset (32 dimensions, 10000 training vectors, 1000 database vectors, 100 queries) and construct an IVFPQ index on it."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "01d10f34",
"metadata": {},
"outputs": [],
"source": [
"ds = SyntheticDataset(32, 10000, 1000, 100)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2734735d",
"metadata": {},
"outputs": [],
"source": [
"index = faiss.index_factory(ds.d, \"IVF100,PQ4x8np\")\n",
"index.train(ds.get_train())\n",
"index.add(ds.get_database())"
]
},
{
"cell_type": "markdown",
"id": "34748b57",
"metadata": {},
"source": [
"Reference search results, computed with the regular `index.search` and `nprobe = 4`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2e261ac5",
"metadata": {},
"outputs": [],
"source": [
"index.nprobe = 4\n",
"Dref, Iref = index.search(ds.get_queries(), 10)"
]
},
{
"cell_type": "markdown",
"id": "050ab9e2",
"metadata": {},
"source": [
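"## Reproduce tables\n",
"\n",
"IVFPQ search is based on precomputed look-up tables: for each (query, probed inverted list) pair, a table of shape `(pq.M, pq.ksub)` holds the distances between each sub-vector of the query residual and every PQ centroid.\n",
"This demonstrates how to compute them. Note that the C++ version optionally uses a slightly faster way of precomputing them, see https://github.com/facebookresearch/faiss/blob/main/faiss/IndexIVFPQ.cpp#L334"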
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a845ab01",
"metadata": {},
"outputs": [],
"source": [
"# set some variables\n",
"xq = ds.get_queries()\n",
"nq, d = xq.shape\n",
"nprobe = index.nprobe"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "bc75737b",
"metadata": {},
"outputs": [],
"source": [
"# coarse quantization\n",
"Dcoarse, Icoarse = index.quantizer.search(xq, nprobe)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "95bb9057",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(100, 4, 32)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# compute residuals of each query w.r.t. its nprobe nearest coarse centroids\n",
"residuals = xq[:, None, :] - index.quantizer.reconstruct_batch(Icoarse.ravel()).reshape(nq, nprobe, d)\n",
"residuals.shape"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "c74d8276",
"metadata": {},
"outputs": [],
"source": [
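"# call compute_distance_tables on the residuals\n",
"\n",
"pq = index.pq\n",
"# one table of shape (pq.M, pq.ksub) per (query, probed list) pair\n",
"dis_tab = np.zeros((nq, nprobe, pq.M, pq.ksub), dtype='float32')\n",
"dis_tab[:] = np.nan  # so that entries that are not written are easy to spot\n",
"# residuals is a contiguous float32 array, as required by swig_ptr\n",
"pq.compute_distance_tables(\n",
"    nq * nprobe,\n",
"    faiss.swig_ptr(residuals),\n",
"    faiss.swig_ptr(dis_tab)\n",
")"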
]
},
{
"cell_type": "markdown",
"id": "45d674a2",
"metadata": {},
"source": [
"## Search with precomputed table \n",
"\n",
"Pure Python implementation of search from look-up tables. "
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "2558fa5c",
"metadata": {},
"outputs": [],
"source": [
"# this is a schematic implementation of the search from the look-up tables\n",
"\n",
"Dnew = []\n",
"Inew = []\n",
"K = 10  # number of results\n",
"for i in range(nq):\n",
"    all_dis = []  # all distances for this query vector\n",
"    all_ids = []  # all ids\n",
"    for j in range(nprobe):\n",
"        ids, codes = get_invlist(index.invlists, int(Icoarse[i, j]))\n",
"        # codes is an array of shape (l, pq.M), where l is the size of the\n",
"        # inverted list. If pq.nbits != 8 the encoding is a bit more complex, see\n",
"        # https://github.com/facebookresearch/faiss/wiki/Python-C---code-snippets#how-can-i-get-access-to-non-8-bit-quantization-code-entries-in-pq--ivfpq--aq-\n",
"        tab = dis_tab[i, j]\n",
"        # distances for this inverted list: sum the table entries\n",
"        # selected by the codes, one per sub-quantizer\n",
"        distances = np.sum([\n",
"            tab[m, codes[:, m]]\n",
"            for m in range(pq.M)\n",
"        ], axis=0)\n",
"        # collect results. In the C++ implementation the top-K results\n",
"        # are maintained with a heap rather than stored completely\n",
"        all_dis.append(distances)\n",
"        all_ids.append(ids)\n",
"    # get the top-K\n",
"    all_dis = np.hstack(all_dis)\n",
"    all_ids = np.hstack(all_ids)\n",
"    order = np.argsort(all_dis)[:K]\n",
"    Dnew.append(all_dis[order])\n",
"    Inew.append(all_ids[order])\n",
"\n",
"Dnew = np.vstack(Dnew)\n",
"Inew = np.vstack(Inew)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "203379fa",
"metadata": {},
"outputs": [],
"source": [
"assert (Inew == Iref).all()\n",
"np.testing.assert_allclose(Dref, Dnew, rtol=1e-5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7362b79d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
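
The look-up-table search in the notebook above relies on the fact that, for the L2 metric, the squared distance between a residual and a PQ-decoded vector is the sum of per-sub-quantizer squared distances. Below is a minimal sketch in plain Python (standard library only, so it runs without faiss or NumPy) that checks this identity on a toy product quantizer. The centroids are random rather than trained, and `distance_table`, `adc_distance` and `decode` are hypothetical helper names, not faiss functions.

```python
import random


def distance_table(residual, centroids):
    """Build tab[m][k]: squared L2 distance between the m-th sub-vector
    of `residual` and centroid k of sub-quantizer m.

    centroids[m][k] is a list of dsub floats.
    """
    M = len(centroids)
    dsub = len(residual) // M
    tab = []
    for m in range(M):
        sub = residual[m * dsub:(m + 1) * dsub]
        tab.append([
            sum((a - b) ** 2 for a, b in zip(sub, c))
            for c in centroids[m]
        ])
    return tab


def adc_distance(tab, code):
    """Distance to an encoded vector: one table look-up per sub-quantizer."""
    return sum(tab[m][code[m]] for m in range(len(code)))


def decode(code, centroids):
    """Reconstruct the vector that `code` stands for (concatenated centroids)."""
    out = []
    for m, k in enumerate(code):
        out.extend(centroids[m][k])
    return out


# toy sizes (assumed for illustration): 4 sub-quantizers, 8 centroids each,
# sub-vectors of dimension 2
rng = random.Random(0)
M, ksub, dsub = 4, 8, 2
centroids = [
    [[rng.gauss(0, 1) for _ in range(dsub)] for _ in range(ksub)]
    for _ in range(M)
]
residual = [rng.gauss(0, 1) for _ in range(M * dsub)]
code = [rng.randrange(ksub) for _ in range(M)]

tab = distance_table(residual, centroids)
via_table = adc_distance(tab, code)
# explicit computation: decode the code, then take the squared L2 distance
explicit = sum((a - b) ** 2 for a, b in zip(residual, decode(code, centroids)))
```

In the notebook, `pq.compute_distance_tables` plays the role of `distance_table`, and the sum over `tab[m, codes[:, m]]` plays the role of `adc_distance`, applied to a whole inverted list at once.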