@mdouze
Created November 16, 2022 17:21
{
"metadata": {
"dataExplorerConfig": {},
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,
"bento/extensions/kernel_ui/main.css": true,
"bento/extensions/new_kernel/main.css": true,
"bento/extensions/system_usage/main.css": true,
"bento/extensions/theme/main.css": true
},
"kernelspec": {
"display_name": "faiss",
"language": "python",
"name": "bento_kernel_faiss",
"cinder_runtime": true,
"ipyflow_runtime": false,
"metadata": {
"kernel_name": "bento_kernel_faiss",
"nightly_builds": true,
"fbpkg_supported": true,
"cinder_runtime": true,
"is_prebuilt": true
}
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
},
"last_server_session_id": "a7f8086e-3ffb-4f04-a2ce-9ffee45d1195",
"last_kernel_id": "04538af2-8e8f-415e-9181-4292da304f29",
"last_base_url": "https://devvm4950.lla0.facebook.com:8090/",
"last_msg_id": "bce63df2-6e0e4559b0b25d6328287199_907",
"outputWidgetContext": {}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"metadata": {
"collapsed": false,
"originalKey": "88d7cc7f-346a-435c-84da-80511d627406",
"requestMsgId": "3fab6f33-7629-4ad0-b7f7-bf8c4e162419",
"customOutput": null,
"executionStartTime": 1668618830063,
"executionStopTime": 1668618830137
},
"source": [
"import faiss\n",
"import numpy as np\n",
"from faiss.contrib.datasets import SyntheticDataset"
],
"execution_count": 35,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"originalKey": "da48f7c7-8998-43fa-8e1d-be121ad1ba83",
"showInput": false,
"customInput": null
},
"source": [
"This notebook demonstrates how to manually train an IVFPQ index wrapped in an OPQ pre-processor.\n",
"This can be useful, for example, when pre-trained centroids for the data distribution are already available.\n",
"\n",
"This is also implemented in the function [train_ivf_index_with_2level](https://github.com/facebookresearch/faiss/blob/main/contrib/clustering.py#L86)."
],
"attachments": {}
},
{
"cell_type": "code",
"metadata": {
"originalKey": "a717594a-9c7a-445a-b695-d0d05bded846",
"showInput": true,
"customInput": null,
"collapsed": false,
"requestMsgId": "fe830167-13b2-4c03-bada-d27f9ed0ad6e",
"customOutput": null,
"executionStartTime": 1668618891221,
"executionStopTime": 1668618891276
},
"source": [
"ds = SyntheticDataset(32, 2000, 1000, 5)"
],
"execution_count": 36,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"originalKey": "15ff17fd-d717-40b7-b422-1d655cb56152",
"showInput": false,
"customInput": null
},
"source": [
"## Reference training"
],
"attachments": {}
},
{
"cell_type": "code",
"metadata": {
"originalKey": "0fd64847-e422-423d-9941-4d6824f6d9a0",
"showInput": true,
"customInput": null,
"collapsed": false,
"requestMsgId": "6df9bbf3-3306-4279-b57f-6feb7dd6c61b",
"customOutput": null,
"executionStartTime": 1668618893012,
"executionStopTime": 1668618893081
},
"source": [
"index = faiss.index_factory(ds.d, \"OPQ4,IVF100,PQ4x8np\") "
],
"execution_count": 37,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "fa64e321-3770-4939-b01f-494c5c13a44a",
"showInput": true,
"customInput": null,
"collapsed": false,
"requestMsgId": "4ca127fa-c1c3-454a-890b-339b4d60450c",
"customOutput": null,
"executionStartTime": 1668618893840,
"executionStopTime": 1668618895312
},
"source": [
"index.train(ds.get_train())"
],
"execution_count": 38,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "06fc60ce-c9a6-4bc3-8b77-5b2db1ca38b6",
"showInput": true,
"customInput": null,
"collapsed": false,
"requestMsgId": "e0c25fa2-b78c-4086-9fd5-61f0a82fd104",
"customOutput": null,
"executionStartTime": 1668618896680,
"executionStopTime": 1668618896686
},
"source": [
"index.sa_encode(ds.get_queries())"
],
"execution_count": 39,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": "array([[ 80, 67, 20, 182, 212],\n [ 70, 168, 24, 124, 40],\n [ 28, 55, 4, 47, 200],\n [ 91, 211, 76, 98, 2],\n [ 40, 5, 43, 44, 125]], dtype=uint8)"
},
"metadata": {
"bento_obj_id": "140012483718560"
},
"execution_count": 39
}
]
},
{
"cell_type": "markdown",
"metadata": {
"originalKey": "eb03733f-b3b3-4918-b23f-a675e7a0d4c0",
"showInput": false,
"customInput": null,
"collapsed": false,
"requestMsgId": "b966c252-0380-45db-be57-7c7b343da4c9",
"customOutput": null,
"executionStartTime": 1668618046246,
"executionStopTime": 1668618046304
},
"source": [
"## Manual training"
],
"attachments": {}
},
{
"cell_type": "code",
"metadata": {
"originalKey": "a5c910e6-f6cb-463e-b133-42855e72d371",
"showInput": true,
"customInput": null,
"collapsed": false,
"requestMsgId": "b2f2cd36-a64b-4cb0-b52d-0921dd69c550",
"customOutput": null,
"executionStartTime": 1668618899370,
"executionStopTime": 1668618899435
},
"source": [
"index2 = faiss.index_factory(ds.d, \"OPQ4,IVF100,PQ4x8np\") "
],
"execution_count": 40,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "77e541c7-969d-4298-a34d-c4f657954015",
"showInput": true,
"customInput": null,
"collapsed": false,
"requestMsgId": "f28567a4-5c29-4bc5-ab96-f297b54f0ba3",
"customOutput": null,
"executionStartTime": 1668618901301,
"executionStopTime": 1668618903027
},
"source": [
"# manually train the pretransform \n",
"vt = faiss.downcast_VectorTransform(index2.chain.at(0))\n",
"vt.train(ds.get_train())"
],
"execution_count": 41,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "6127f2b2-946e-48cb-bc48-fd60cb80ad1e",
"showInput": true,
"customInput": null,
"collapsed": false,
"requestMsgId": "f4a4fcbf-0d7d-4aca-8b7e-6abb65565ff1",
"customOutput": null,
"executionStartTime": 1668618903816,
"executionStopTime": 1668618903872
},
"source": [
"# apply the transformation to the training set\n",
"xt_transformed = vt.apply(ds.get_train())"
],
"execution_count": 42,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "4e576cf0-ef36-4dba-b753-f1bb1f6e9392",
"showInput": true,
"customInput": null,
"collapsed": false,
"requestMsgId": "9a40c97a-652f-40e1-acf3-ed15909efbd8",
"customOutput": null,
"executionStartTime": 1668618904877,
"executionStopTime": 1668618904921
},
"source": [
"# train the coarse quantizer\n",
"index_ivf = faiss.downcast_index(index2.index)\n",
"km = faiss.Kmeans(index_ivf.d, index_ivf.nlist, niter=index_ivf.cp.niter)\n",
"km.train(xt_transformed)\n",
"index_ivf.quantizer.add(km.centroids) \n",
"# The k-means centroids are added to the quantizer after training. This also applies to more complex quantizers such as HNSW."
],
"execution_count": 43,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "7c02d549-2d78-47b0-8780-bdf66c1ff15b",
"showInput": true,
"customInput": null,
"collapsed": false,
"requestMsgId": "3dfb6ee6-4464-4abf-9d73-cd51585bd096",
"customOutput": null,
"executionStartTime": 1668618906410,
"executionStopTime": 1668618906421
},
"source": [
"# train the PQ (the coarse quantizer already holds nlist centroids, so only the PQ is trained here)\n",
"index_ivf.train(xt_transformed)"
],
"execution_count": 44,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "2471e3fc-4950-4ab2-b3e2-eb8e9f008715",
"showInput": true,
"customInput": null,
"collapsed": false,
"requestMsgId": "ef072391-8d5b-488f-a5f8-398d765a0d13",
"customOutput": null,
"executionStartTime": 1668618908082,
"executionStopTime": 1668618908149
},
"source": [
"# check / set the is_trained flags\n",
"assert index_ivf.is_trained\n",
"index2.is_trained = True"
],
"execution_count": 45,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "3a1102f2-5e7f-4b13-a738-6f6de0b37dd8",
"showInput": true,
"customInput": null,
"collapsed": false,
"requestMsgId": "44b6bbf4-6b8b-48e5-b67b-e436a8a9e0eb",
"customOutput": null,
"executionStartTime": 1668618908956,
"executionStopTime": 1668618908967
},
"source": [
"# check that the encoding is the same \n",
"index2.sa_encode(ds.get_queries())"
],
"execution_count": 46,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": "array([[ 80, 67, 20, 182, 212],\n [ 70, 168, 24, 124, 40],\n [ 28, 55, 4, 47, 200],\n [ 91, 211, 76, 98, 2],\n [ 40, 5, 43, 44, 125]], dtype=uint8)"
},
"metadata": {
"bento_obj_id": "140012502714528"
},
"execution_count": 46
}
]
},
{
"cell_type": "markdown",
"metadata": {
"originalKey": "121b80f7-4305-491f-8882-b65affe4cbc1",
"showInput": false,
"customInput": null
},
"source": [
"Note that OPQ training depends on LAPACK functions that are not always reproducible, see [Reproducibility with multiple threads](https://github.com/facebookresearch/faiss/wiki/Threads-and-asynchronous-calls#reproducibility-with-multiple-threads).\n",
"Therefore, it is a matter of chance that the result is the same in this case.\n",
"The k-means training is reproducible, though."
],
"attachments": {}
},
{
"cell_type": "code",
"metadata": {
"originalKey": "7853a3f0-f2a3-4755-b647-6b0f7dacbfeb",
"showInput": true,
"customInput": null
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
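
The 5-column `uint8` arrays returned by `sa_encode` in the notebook can be sanity-checked without Faiss. The sketch below is a plain-Python illustration, not Faiss code. It assumes that the standalone code of an IVFPQ index is the inverted-list id, stored in the smallest number of whole bytes that can hold `nlist - 1`, followed by the PQ code; this assumption is consistent with the outputs shown in the notebook. The helper names (`coarse_code_size`, `pq_code_size`, `sa_code_size`) are hypothetical and chosen for this example only.

```python
def coarse_code_size(nlist: int) -> int:
    """Whole bytes needed to store a list id in [0, nlist) (assumed layout)."""
    nbytes = 0
    largest_id = nlist - 1
    while largest_id > 0:
        nbytes += 1
        largest_id >>= 8
    return nbytes


def pq_code_size(m: int, nbits: int) -> int:
    """Whole bytes needed for m sub-quantizer indices of nbits each."""
    return (m * nbits + 7) // 8


def sa_code_size(nlist: int, m: int, nbits: int) -> int:
    """Assumed size of one standalone IVFPQ code: list id, then PQ code."""
    return coarse_code_size(nlist) + pq_code_size(m, nbits)


# Codes copied from the two sa_encode outputs in the notebook
# (reference training and manual training).
reference_codes = [
    [80, 67, 20, 182, 212],
    [70, 168, 24, 124, 40],
    [28, 55, 4, 47, 200],
    [91, 211, 76, 98, 2],
    [40, 5, 43, 44, 125],
]
manual_codes = [
    [80, 67, 20, 182, 212],
    [70, 168, 24, 124, 40],
    [28, 55, 4, 47, 200],
    [91, 211, 76, 98, 2],
    [40, 5, 43, 44, 125],
]

# "OPQ4,IVF100,PQ4x8np": nlist = 100, 4 sub-quantizers of 8 bits each.
expected_width = sa_code_size(nlist=100, m=4, nbits=8)

# Every code has the expected width, and the first byte is a valid list id.
assert all(len(row) == expected_width for row in reference_codes)
assert all(row[0] < 100 for row in reference_codes)

# The manually trained index produces the same codes as the reference index.
assert manual_codes == reference_codes
```

Inside the notebook itself, the same comparison can be done on the arrays directly with `np.array_equal`, which avoids comparing the printed outputs by eye.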