GitHub gist mdouze/430a67fbe0937482a1fd537e14c51af0
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aabd9a5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import faiss\n",
    "from faiss.contrib.datasets import SyntheticDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc497586",
   "metadata": {},
   "source": [
    "This script demonstrates how to add and remove elements from an IVF index in a rolling fashion.\n",
    "The key is to set the index's `DirectMap` type to `Hashtable` and to remove elements with an `IDSelectorArray`. The removal cost is then proportional to the number of elements removed rather than to the number of elements in the index.\n",
    "\n",
    "In this example we maintain an index of 500 elements over `nstep` steps. At each step we search the index, then remove the 100 oldest elements and add 100 new ones at the end. The query vectors stay the same throughout."
   ]
  },
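  {
   "cell_type": "markdown",
   "id": "1f0c3a01",
   "metadata": {},
   "source": [
    "The cost claim above can be checked with a rough timing sketch. It builds one larger index per direct-map type and times the removal of the same 100 ids. Assumptions: the database size (`nb = 200000`) is arbitrary, timings depend on the machine, and the `NoMap` baseline uses an `IDSelectorBatch` (a hash-set based selector) so that it is not slowed down by the linear `is_member` lookup of `IDSelectorArray`. With `NoMap`, `remove_ids` scans every inverted list, so its time should grow with the index size; with `Hashtable` it should not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f0c3a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough timing sketch (sizes are arbitrary, timings are machine dependent)\n",
    "import time\n",
    "import numpy as np\n",
    "import faiss\n",
    "from faiss.contrib.datasets import SyntheticDataset\n",
    "\n",
    "nb = 200000\n",
    "ds_big = SyntheticDataset(32, 4000, nb, 10)\n",
    "base = faiss.index_factory(ds_big.d, 'IVF50,PQ8np')\n",
    "base.train(ds_big.get_train())\n",
    "ids = np.arange(nb, dtype='int64')\n",
    "to_remove = np.arange(100, dtype='int64')\n",
    "\n",
    "for name, map_type, selector_class in [\n",
    "        ('NoMap', faiss.DirectMap.NoMap, faiss.IDSelectorBatch),\n",
    "        ('Hashtable', faiss.DirectMap.Hashtable, faiss.IDSelectorArray)]:\n",
    "    idx = faiss.clone_index(base)\n",
    "    idx.set_direct_map_type(map_type)\n",
    "    idx.add_with_ids(ds_big.get_database(), ids)\n",
    "    # to_remove must stay alive while the selector points into it\n",
    "    sel = selector_class(len(to_remove), faiss.swig_ptr(to_remove))\n",
    "    t0 = time.perf_counter()\n",
    "    nremoved = idx.remove_ids(sel)\n",
    "    t1 = time.perf_counter()\n",
    "    print(f'{name}: removed {nremoved}, ntotal={idx.ntotal}, {(t1 - t0) * 1e3:.2f} ms')"
   ]
  },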
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "f421fe1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate a dataset: d=32, 4000 training vectors,\n",
    "# 500 + 100 * nstep database vectors and 50 queries\n",
    "\n",
    "nstep = 10\n",
    "ds = SyntheticDataset(32, 4000, 500 + 100 * nstep, 50)"
   ]
  },
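  {
   "cell_type": "markdown",
   "id": "1f0c3a03",
   "metadata": {},
   "source": [
    "The database size follows from the id windows. A small sketch, assuming the rolling scheme described above: at step `s` the live ids are `[100 s, 100 s + 500)`, and the last step adds ids up to `100 (nstep - 1) + 599`, so `500 + 100 * nstep` database vectors are exactly enough."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f0c3a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: the window of ids that should be live in the index at each step\n",
    "for step in range(nstep):\n",
    "    lo, hi = 100 * step, 100 * step + 500\n",
    "    print(f'step {step}: live ids [{lo}, {hi})')\n",
    "\n",
    "# the last step adds ids up to 100 * (nstep - 1) + 599, i.e. the last database vector\n",
    "assert 100 * (nstep - 1) + 600 == ds.get_database().shape[0]"
   ]
  },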
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "b286219e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train the index. The trained index stays fixed in what follows:\n",
    "# the reference and the rolling indexes are both clones of it.\n",
    "\n",
    "trained_index = faiss.index_factory(ds.d, \"IVF50,PQ8np\")\n",
    "trained_index.train(ds.get_train())\n",
    "trained_index.nprobe = 5"
   ]
  },
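  {
   "cell_type": "markdown",
   "id": "1f0c3a05",
   "metadata": {},
   "source": [
    "Both the reference and the rolling indexes start from `faiss.clone_index(trained_index)`. A quick check, assuming the cells above have run, that such a clone is already trained, contains no vectors, and carries over the `nprobe` setting:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f0c3a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: a clone keeps the training (and nprobe) but contains no vectors\n",
    "probe = faiss.clone_index(trained_index)\n",
    "assert probe.is_trained\n",
    "assert probe.ntotal == 0\n",
    "assert probe.nprobe == trained_index.nprobe\n",
    "del probe"
   ]
  },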
  {
   "cell_type": "markdown",
   "id": "059074ad",
   "metadata": {},
   "source": [
    "# Reference results\n",
    "For each step, build an index from scratch that contains only the 500 vectors live at that step, and search it. These results are the ground truth for the rolling index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "0624251c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute reference results: at each step, build an index from scratch\n",
    "\n",
    "ref = []\n",
    "xb = ds.get_database()\n",
    "for step in range(nstep):\n",
    "    # construct the index with the 500 vectors that are live at this step\n",
    "    index = faiss.clone_index(trained_index)\n",
    "    subset = np.arange(500, dtype='int64') + step * 100\n",
    "    index.add_with_ids(xb[subset], subset)\n",
    "    # search in the index\n",
    "    ref.append(index.search(ds.get_queries(), 10))\n",
    "    # throw away the index\n",
    "    del index"
   ]
  },
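  {
   "cell_type": "markdown",
   "id": "1f0c3a07",
   "metadata": {},
   "source": [
    "A light sanity check on the reference results, assuming the cells above have run: each entry of `ref` is a `(D, I)` pair of shape `(nq, 10)`, and every returned id (ignoring the `-1` padding Faiss uses when fewer than 10 results are found) lies in that step's window."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f0c3a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: check that every reference result comes from the right id window\n",
    "for step, (Dref, Iref) in enumerate(ref):\n",
    "    assert Dref.shape == Iref.shape == (ds.nq, 10)\n",
    "    found = Iref[Iref >= 0]\n",
    "    assert np.all((found >= 100 * step) & (found < 100 * step + 500))"
   ]
  },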
  {
   "cell_type": "markdown",
   "id": "9367ad6e",
   "metadata": {},
   "source": [
    "# Rolling dataset\n",
    "Here we maintain a single index. At each step we search it and compare the results with the reference, then remove the oldest slice of 100 elements and add the next 100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "d783ce3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "index = faiss.clone_index(trained_index)\n",
    "# a hash-table direct map lets remove_ids look up each id directly\n",
    "index.set_direct_map_type(faiss.DirectMap.Hashtable)\n",
    "\n",
    "xb = ds.get_database()\n",
    "\n",
    "# initial index state: ids 0..499\n",
    "subset = np.arange(500, dtype='int64')\n",
    "index.add_with_ids(xb[subset], subset)\n",
    "\n",
    "for step in range(nstep):\n",
    "\n",
    "    # perform the search and compare with the reference result\n",
    "    Dnew, Inew = index.search(ds.get_queries(), 10)\n",
    "    Dref, Iref = ref[step]\n",
    "    np.testing.assert_array_almost_equal(Dref, Dnew)\n",
    "    np.testing.assert_array_equal(Iref, Inew)\n",
    "\n",
    "    # remove the oldest slice of 100 ids\n",
    "    # (int64 so that swig_ptr yields the idx_t pointer IDSelectorArray expects)\n",
    "    to_remove = np.arange(100, dtype='int64') + step * 100\n",
    "    sel = faiss.IDSelectorArray(len(to_remove), faiss.swig_ptr(to_remove))\n",
    "    # from Faiss 1.7.3 this can be simplified to\n",
    "    # sel = faiss.IDSelectorArray(to_remove)\n",
    "    nremoved = index.remove_ids(sel)\n",
    "    assert nremoved == len(to_remove)\n",
    "\n",
    "    # add the next slice of 100 ids at the end\n",
    "    to_add = np.arange(100, dtype='int64') + step * 100 + 500\n",
    "    index.add_with_ids(xb[to_add], to_add)"
   ]
  },
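  {
   "cell_type": "markdown",
   "id": "1f0c3a09",
   "metadata": {},
   "source": [
    "The loop above compares the index with the reference before each update, so the state after the last update is never checked. A minimal sketch that inspects it directly, assuming the loop has completed: it reads the ids stored in the inverted lists through the low-level `InvertedLists` accessors (`list_size`, `get_ids`), checks that they form the window `[100 * nstep, 100 * nstep + 500)`, then uses the hash-table direct map to reconstruct one vector by id."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f0c3a10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: inspect the final state of the rolling index\n",
    "assert index.ntotal == 500\n",
    "\n",
    "# collect the ids stored in all inverted lists\n",
    "invlists = faiss.extract_index_ivf(index).invlists\n",
    "live = []\n",
    "for list_no in range(invlists.nlist):\n",
    "    list_size = invlists.list_size(list_no)\n",
    "    if list_size > 0:\n",
    "        ids_ptr = invlists.get_ids(list_no)\n",
    "        live.append(faiss.rev_swig_ptr(ids_ptr, list_size).copy())\n",
    "        invlists.release_ids(list_no, ids_ptr)\n",
    "live = np.sort(np.concatenate(live))\n",
    "np.testing.assert_array_equal(live, np.arange(500) + 100 * nstep)\n",
    "\n",
    "# the hash-table direct map also supports lookup by id, e.g. to reconstruct a vector\n",
    "x = index.reconstruct(int(live[0]))\n",
    "print(x.shape)"
   ]
  },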
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7068aaa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}