-
-
Save mdouze/8f43d37037d0ca19327539c0f8227f8e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"dataExplorerConfig": {}, | |
"bento_stylesheets": { | |
"bento/extensions/flow/main.css": true, | |
"bento/extensions/kernel_selector/main.css": true, | |
"bento/extensions/kernel_ui/main.css": true, | |
"bento/extensions/new_kernel/main.css": true, | |
"bento/extensions/system_usage/main.css": true, | |
"bento/extensions/theme/main.css": true | |
}, | |
"kernelspec": { | |
"display_name": "faiss", | |
"language": "python", | |
"name": "bento_kernel_faiss", | |
"cinder_runtime": true, | |
"ipyflow_runtime": false, | |
"metadata": { | |
"kernel_name": "bento_kernel_faiss", | |
"nightly_builds": true, | |
"fbpkg_supported": true, | |
"cinder_runtime": true, | |
"is_prebuilt": true | |
} | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3" | |
}, | |
"last_server_session_id": "a7f8086e-3ffb-4f04-a2ce-9ffee45d1195", | |
"last_kernel_id": "04538af2-8e8f-415e-9181-4292da304f29", | |
"last_base_url": "https://devvm4950.lla0.facebook.com:8090/", | |
"last_msg_id": "bce63df2-6e0e4559b0b25d6328287199_907", | |
"outputWidgetContext": {} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"collapsed": false, | |
"originalKey": "88d7cc7f-346a-435c-84da-80511d627406", | |
"requestMsgId": "3fab6f33-7629-4ad0-b7f7-bf8c4e162419", | |
"customOutput": null, | |
"executionStartTime": 1668618830063, | |
"executionStopTime": 1668618830137 | |
}, | |
"source": [ | |
"import faiss\n", | |
"import numpy as np\n", | |
"from faiss.contrib.datasets import SyntheticDataset" | |
], | |
"execution_count": 35, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "da48f7c7-8998-43fa-8e1d-be121ad1ba83", | |
"showInput": false, | |
"customInput": null | |
}, | |
"source": [ | |
"This notebook demonstrates how to manually train an IVFPQ index wrapped in an OPQ pre-processor. \n", | |
"This is useful, for example, when pre-trained centroids for the data distribution are already available. \n", | |
"\n", | |
"This is also implemented in the function [train_ivf_index_with_2level](https://github.com/facebookresearch/faiss/blob/main/contrib/clustering.py#L86)." | |
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "a717594a-9c7a-445a-b695-d0d05bded846", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "fe830167-13b2-4c03-bada-d27f9ed0ad6e", | |
"customOutput": null, | |
"executionStartTime": 1668618891221, | |
"executionStopTime": 1668618891276 | |
}, | |
"source": [ | |
"# 32-dimensional synthetic data: 2000 training, 1000 database and 5 query vectors\n", | |
"ds = SyntheticDataset(d=32, nt=2000, nb=1000, nq=5)" | |
], | |
"execution_count": 36, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "15ff17fd-d717-40b7-b422-1d655cb56152", | |
"showInput": false, | |
"customInput": null | |
}, | |
"source": [ | |
"## Reference training" | |
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "0fd64847-e422-423d-9941-4d6824f6d9a0", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "6df9bbf3-3306-4279-b57f-6feb7dd6c61b", | |
"customOutput": null, | |
"executionStartTime": 1668618893012, | |
"executionStopTime": 1668618893081 | |
}, | |
"source": [ | |
"# OPQ4: OPQ rotation for 4 sub-quantizers, IVF100: 100 inverted lists,\n", | |
"# PQ4x8np: 4 sub-quantizers of 8 bits each, without polysemous training\n", | |
"index = faiss.index_factory(ds.d, \"OPQ4,IVF100,PQ4x8np\")" | |
], | |
"execution_count": 37, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "fa64e321-3770-4939-b01f-494c5c13a44a", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "4ca127fa-c1c3-454a-890b-339b4d60450c", | |
"customOutput": null, | |
"executionStartTime": 1668618893840, | |
"executionStopTime": 1668618895312 | |
}, | |
"source": [ | |
"# a single train() call fits the OPQ transform, the coarse quantizer and the PQ\n", | |
"index.train(ds.get_train())" | |
], | |
"execution_count": 38, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "06fc60ce-c9a6-4bc3-8b77-5b2db1ca38b6", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "e0c25fa2-b78c-4086-9fd5-61f0a82fd104", | |
"customOutput": null, | |
"executionStartTime": 1668618896680, | |
"executionStopTime": 1668618896686 | |
}, | |
"source": [ | |
"# reference codes for the queries: 5 bytes per vector (1 byte of IVF list id + 4 PQ bytes)\n", | |
"index.sa_encode(ds.get_queries())" | |
], | |
"execution_count": 39, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "array([[ 80, 67, 20, 182, 212],\n [ 70, 168, 24, 124, 40],\n [ 28, 55, 4, 47, 200],\n [ 91, 211, 76, 98, 2],\n [ 40, 5, 43, 44, 125]], dtype=uint8)" | |
}, | |
"metadata": { | |
"bento_obj_id": "140012483718560" | |
}, | |
"execution_count": 39 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "eb03733f-b3b3-4918-b23f-a675e7a0d4c0", | |
"showInput": false, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "b966c252-0380-45db-be57-7c7b343da4c9", | |
"customOutput": null, | |
"executionStartTime": 1668618046246, | |
"executionStopTime": 1668618046304 | |
}, | |
"source": [ | |
"## Manual training" | |
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "a5c910e6-f6cb-463e-b133-42855e72d371", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "b2f2cd36-a64b-4cb0-b52d-0921dd69c550", | |
"customOutput": null, | |
"executionStartTime": 1668618899370, | |
"executionStopTime": 1668618899435 | |
}, | |
"source": [ | |
"# same factory string as the reference index; its components are trained one by one below\n", | |
"index2 = faiss.index_factory(ds.d, \"OPQ4,IVF100,PQ4x8np\")" | |
], | |
"execution_count": 40, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "77e541c7-969d-4298-a34d-c4f657954015", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "f28567a4-5c29-4bc5-ab96-f297b54f0ba3", | |
"customOutput": null, | |
"executionStartTime": 1668618901301, | |
"executionStopTime": 1668618903027 | |
}, | |
"source": [ | |
"# Manually train the pre-transform: chain.at(0) is the OPQ transform, downcast\n", | |
"# from the generic VectorTransform to its concrete class before training it.\n", | |
"vt = faiss.downcast_VectorTransform(index2.chain.at(0))\n", | |
"vt.train(ds.get_train())" | |
], | |
"execution_count": 41, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "6127f2b2-946e-48cb-bc48-fd60cb80ad1e", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "f4a4fcbf-0d7d-4aca-8b7e-6abb65565ff1", | |
"customOutput": null, | |
"executionStartTime": 1668618903816, | |
"executionStopTime": 1668618903872 | |
}, | |
"source": [ | |
"# Map the training set through the trained OPQ transform; the IVF stage\n", | |
"# is trained on these transformed vectors, not on the raw ones.\n", | |
"xt_transformed = vt.apply(ds.get_train())" | |
], | |
"execution_count": 42, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "4e576cf0-ef36-4dba-b753-f1bb1f6e9392", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "9a40c97a-652f-40e1-acf3-ed15909efbd8", | |
"customOutput": null, | |
"executionStartTime": 1668618904877, | |
"executionStopTime": 1668618904921 | |
}, | |
"source": [ | |
"# train the coarse quantizer with k-means on the transformed training vectors,\n", | |
"# reusing the IVF's own number of clustering iterations\n", | |
"index_ivf = faiss.downcast_index(index2.index)\n", | |
"km = faiss.Kmeans(index_ivf.d, index_ivf.nlist, niter=index_ivf.cp.niter)\n", | |
"km.train(xt_transformed)\n", | |
"# After k-means the centroids are added to the quantizer. This also applies to\n", | |
"# more complex quantizers like HNSW. reset() first so that re-running this cell\n", | |
"# does not add a second copy of the centroids.\n", | |
"index_ivf.quantizer.reset()\n", | |
"index_ivf.quantizer.add(km.centroids)\n", | |
"assert index_ivf.quantizer.ntotal == index_ivf.nlist" | |
], | |
"execution_count": 43, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "7c02d549-2d78-47b0-8780-bdf66c1ff15b", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "3dfb6ee6-4464-4abf-9d73-cd51585bd096", | |
"customOutput": null, | |
"executionStartTime": 1668618906410, | |
"executionStopTime": 1668618906421 | |
}, | |
"source": [ | |
"# Train the PQ. The quantizer already holds nlist centroids, so IndexIVF.train\n", | |
"# skips the coarse k-means and only trains the product quantizer.\n", | |
"index_ivf.train(xt_transformed)" | |
], | |
"execution_count": 44, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "2471e3fc-4950-4ab2-b3e2-eb8e9f008715", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "ef072391-8d5b-488f-a5f8-398d765a0d13", | |
"customOutput": null, | |
"executionStartTime": 1668618908082, | |
"executionStopTime": 1668618908149 | |
}, | |
"source": [ | |
"# Check / set the is_trained flags. Both stages were trained by hand, so verify\n", | |
"# them before forcing the flag on the IndexPreTransform wrapper itself.\n", | |
"assert vt.is_trained\n", | |
"assert index_ivf.is_trained\n", | |
"index2.is_trained = True" | |
], | |
"execution_count": 45, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "3a1102f2-5e7f-4b13-a738-6f6de0b37dd8", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "44b6bbf4-6b8b-48e5-b67b-e436a8a9e0eb", | |
"customOutput": null, | |
"executionStartTime": 1668618908956, | |
"executionStopTime": 1668618908967 | |
}, | |
"source": [ | |
"# Check that the encoding is the same as the reference index. OPQ training is\n", | |
"# not guaranteed to be reproducible (see below), so report instead of asserting.\n", | |
"codes_manual = index2.sa_encode(ds.get_queries())\n", | |
"print(\"same codes as reference:\", np.array_equal(codes_manual, index.sa_encode(ds.get_queries())))\n", | |
"codes_manual" | |
], | |
"execution_count": 46, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "array([[ 80, 67, 20, 182, 212],\n [ 70, 168, 24, 124, 40],\n [ 28, 55, 4, 47, 200],\n [ 91, 211, 76, 98, 2],\n [ 40, 5, 43, 44, 125]], dtype=uint8)" | |
}, | |
"metadata": { | |
"bento_obj_id": "140012502714528" | |
}, | |
"execution_count": 46 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "121b80f7-4305-491f-8882-b65affe4cbc1", | |
"showInput": false, | |
"customInput": null | |
}, | |
"source": [ | |
"Note that the OPQ training depends on LAPACK functions that are not always reproducible, see [Reproducibility with multiple threads](https://github.com/facebookresearch/faiss/wiki/Threads-and-asynchronous-calls#reproducibility-with-multiple-threads). \n", | |
"Therefore, it is partly a matter of chance that the result is identical in this case. \n", | |
"The k-means training, however, is reproducible." | |
], | |
"attachments": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment