@mdouze
Created November 2, 2022 07:32
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "aabd9a5b",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import faiss \n",
"from faiss.contrib.datasets import SyntheticDataset"
]
},
{
"cell_type": "markdown",
"id": "bc497586",
"metadata": {},
"source": [
"This notebook demonstrates how to add and remove elements from an IVF index in a rolling fashion. \n",
"The key is to set the index's `DirectMap` type to `Hashtable` and to remove elements with an `IDSelectorArray`: the removal cost is then proportional to the number of elements to remove rather than to the total number of elements in the index. \n",
"\n",
"In this example we maintain a dataset of 500 elements and run `nstep` steps; at each step we search, remove the first 100 elements, and append 100 new ones at the back. The query vectors stay the same throughout. "
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "f421fe1b",
"metadata": {},
"outputs": [],
"source": [
"# generate a synthetic dataset: 32D vectors, 4000 training vectors,\n",
"# enough database vectors for all steps, 50 queries\n",
"\n",
"nstep = 10\n",
"ds = SyntheticDataset(32, 4000, 500 + 100 * nstep, 50)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "b286219e",
"metadata": {},
"outputs": [],
"source": [
"# train the index. The trained index (including its nprobe setting) stays\n",
"# fixed and is cloned in the steps below.\n",
"\n",
"trained_index = faiss.index_factory(ds.d, \"IVF50,PQ8np\")\n",
"trained_index.train(ds.get_train())\n",
"trained_index.nprobe = 5"
]
},
{
"cell_type": "markdown",
"id": "059074ad",
"metadata": {},
"source": [
"# Reference results\n",
"For each step, construct an index from scratch containing only the relevant vectors and search it. "
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "0624251c",
"metadata": {},
"outputs": [],
"source": [
"# compute reference results \n",
"\n",
"ref = []\n",
"xb = ds.get_database()\n",
"for step in range(nstep):\n",
" # construct the index\n",
" index = faiss.clone_index(trained_index)\n",
" subset = np.arange(500) + step * 100\n",
" index.add_with_ids(xb[subset], subset)\n",
" # search in the index \n",
" ref.append(index.search(ds.get_queries(), 10))\n",
" # throw away the index\n",
" del index"
]
},
{
"cell_type": "markdown",
"id": "9367ad6e",
"metadata": {},
"source": [
"# Rolling dataset\n",
"Here we maintain a single live index: at each step we search, then remove the oldest slice of 100 elements and add 100 new ones. "
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "d783ce3d",
"metadata": {},
"outputs": [],
"source": [
"index = faiss.clone_index(trained_index)\n",
"index.set_direct_map_type(faiss.DirectMap.Hashtable)\n",
"\n",
"xb = ds.get_database()\n",
"\n",
"# initial index state\n",
"subset = np.arange(500)\n",
"index.add_with_ids(xb[subset], subset)\n",
"\n",
"for step in range(nstep):\n",
"\n",
" # perform the search and compare with reference result\n",
" Dnew, Inew = index.search(ds.get_queries(), 10)\n",
" Dref, Iref = ref[step]\n",
" np.testing.assert_array_almost_equal(Dref, Dnew)\n",
" np.testing.assert_array_equal(Iref, Inew)\n",
" \n",
" # remove the slice to drop\n",
" to_remove = np.arange(100) + step * 100\n",
" sel = faiss.IDSelectorArray(len(to_remove), faiss.swig_ptr(to_remove))\n",
" # from 1.7.3 this can be simplified to \n",
" # sel = faiss.IDSelectorArray(to_remove)\n",
" index.remove_ids(sel)\n",
" \n",
" # add new slice \n",
" to_add = np.arange(100) + step * 100 + 500\n",
" index.add_with_ids(xb[to_add], to_add)"
]
}
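,
{
"cell_type": "markdown",
"id": "a1b2c3d4",
"metadata": {},
"source": [
"A final sanity check (added sketch, not in the original gist): after the last step the rolling index should hold exactly the final 500 ids. `ntotal` is the standard Faiss count of stored vectors, and with a `Hashtable` direct map `reconstruct()` works for any stored id. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1b2c3d5",
"metadata": {},
"outputs": [],
"source": [
"# the rolling index should now contain only the last 500 vectors\n",
"assert index.ntotal == 500\n",
"# the remaining ids are 100 * nstep .. 100 * nstep + 499; with the Hashtable\n",
"# direct map, reconstruct() can fetch any of them by id\n",
"_ = index.reconstruct(int(100 * nstep))"
]
}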
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}