-
-
Save mdouze/b2e6c6144d4e06fca8287f5257f15fed to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"dataExplorerConfig": {}, | |
"bento_stylesheets": { | |
"bento/extensions/flow/main.css": true, | |
"bento/extensions/kernel_selector/main.css": true, | |
"bento/extensions/kernel_ui/main.css": true, | |
"bento/extensions/new_kernel/main.css": true, | |
"bento/extensions/system_usage/main.css": true, | |
"bento/extensions/theme/main.css": true | |
}, | |
"kernelspec": { | |
"display_name": "faiss", | |
"language": "python", | |
"name": "bento_kernel_faiss", | |
"metadata": { | |
"kernel_name": "bento_kernel_faiss", | |
"nightly_builds": true, | |
"fbpkg_supported": true, | |
"cinder_runtime": false, | |
"is_prebuilt": true | |
} | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3" | |
}, | |
"last_server_session_id": "dda44ac7-0726-4b1f-b2d1-192e30571c8c", | |
"last_kernel_id": "9fb3f046-5812-468a-880c-1a502cb68bb9", | |
"last_base_url": "https://devvm4950.lla0.facebook.com:8090/", | |
"last_msg_id": "95b404a6-226f69fd1bf78017499d5f0a_1600", | |
"outputWidgetContext": {} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "d1240614-3a13-430c-9ad0-ea65e20de909", | |
"showInput": false, | |
"customInput": null | |
}, | |
"source": [ | |
"This script demonstrates an asymmetric search use case: \n", | |
"the query vectors are in full precision and the database vectors are compressed as binary vectors. \n", | |
"This implementation is slow, it is mainly intended to show how much accuracy can be regained with asymmetric search." | |
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"collapsed": false, | |
"originalKey": "30a3f11d-3131-4a13-a72c-bf93149e8b48", | |
"code_folding": [], | |
"hidden_ranges": [], | |
"requestMsgId": "30a3f11d-3131-4a13-a72c-bf93149e8b48", | |
"customOutput": null, | |
"executionStartTime": 1652091792635, | |
"executionStopTime": 1652091792641 | |
}, | |
"source": [ | |
"import numpy as np\n", | |
"import faiss\n", | |
"from faiss.contrib.datasets import SyntheticDataset" | |
], | |
"execution_count": 42, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "96316054-488d-4b0b-b7cf-579364314c44", | |
"showInput": true, | |
"customInput": null, | |
"code_folding": [], | |
"hidden_ranges": [], | |
"collapsed": false, | |
"requestMsgId": "96316054-488d-4b0b-b7cf-579364314c44", | |
"customOutput": null, | |
"executionStartTime": 1652091106966, | |
"executionStopTime": 1652091107046 | |
}, | |
"source": [ | |
"# An example dataset.\n", | |
"ds = SyntheticDataset(64, 1000, 2000, 200)" | |
], | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "98f8611b-8064-4203-928c-c2818ef6bd63", | |
"showInput": true, | |
"customInput": null, | |
"code_folding": [], | |
"hidden_ranges": [], | |
"collapsed": false, | |
"requestMsgId": "98f8611b-8064-4203-928c-c2818ef6bd63", | |
"customOutput": null, | |
"executionStartTime": 1652091313645, | |
"executionStopTime": 1652091315915 | |
}, | |
"source": [ | |
"# make a binary dataset from that: ITQ rotation and dim resuction 64->32,\n", | |
"# followed by thresholding\n", | |
"binarizer = faiss.index_factory(64, \"ITQ32,LSH\")\n", | |
"xt = ds.get_train()\n", | |
"binarizer.train(xt)" | |
], | |
"execution_count": 19, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "e5363998-9508-4a0e-b9cb-c58118b3f1c9", | |
"showInput": true, | |
"customInput": null, | |
"code_folding": [], | |
"hidden_ranges": [], | |
"collapsed": false, | |
"requestMsgId": "e5363998-9508-4a0e-b9cb-c58118b3f1c9", | |
"customOutput": null, | |
"executionStartTime": 1652092905111, | |
"executionStopTime": 1652092905303 | |
}, | |
"source": [ | |
"# transform query vectors and database vectors \n", | |
"\n", | |
"xb_binary = binarizer.sa_encode(ds.get_database())\n", | |
"xq_binary = binarizer.sa_encode(ds.get_queries())" | |
], | |
"execution_count": 85, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "31ddf16b-35f3-4535-b35d-017ad4f56546", | |
"showInput": false, | |
"customInput": null | |
}, | |
"source": [ | |
"## Baseline symmetric search" | |
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "c701c0c9-664b-4cc7-8e21-ff4cde48c41c", | |
"showInput": true, | |
"customInput": null, | |
"code_folding": [], | |
"hidden_ranges": [], | |
"collapsed": false, | |
"requestMsgId": "c701c0c9-664b-4cc7-8e21-ff4cde48c41c", | |
"customOutput": null, | |
"executionStartTime": 1652091523755, | |
"executionStopTime": 1652091523921 | |
}, | |
"source": [ | |
"# Baseline: symmetric search in the binary domain \n", | |
"\n", | |
"index_binary = faiss.IndexBinaryFlat(32)\n", | |
"index_binary.add(xb_binary)\n", | |
"D, I = index_binary.search(xq_binary, 100) \n", | |
"\n", | |
"gt = ds.get_groundtruth()\n", | |
"for rank in 1, 10, 100: \n", | |
" recall = (I[:, :rank] == gt[:, :1]).sum() / len(gt)\n", | |
" print(f\"Recall at {rank}: {recall:.3f}\")" | |
], | |
"execution_count": 29, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Recall at 1: 0.075\nRecall at 10: 0.250\nRecall at 100: 0.680\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "399b52cb-a9a9-451f-b215-7d313eb735b0", | |
"showInput": false, | |
"customInput": null | |
}, | |
"source": [ | |
"## Asymmetric search" | |
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "1a76d077-f37f-4e1f-9f85-cf10aca675c9", | |
"showInput": true, | |
"customInput": null, | |
"code_folding": [], | |
"hidden_ranges": [], | |
"collapsed": false, | |
"requestMsgId": "1a76d077-f37f-4e1f-9f85-cf10aca675c9", | |
"customOutput": null, | |
"executionStartTime": 1652092357581, | |
"executionStopTime": 1652092357751 | |
}, | |
"source": [ | |
"# Prepare asymmetric search\n", | |
"pre_transform = binarizer.chain.at(0)\n", | |
"\n", | |
"xt_transformed = pre_transform.apply(xt)\n", | |
"xt_binarized = np.unpackbits(binarizer.sa_encode(xt), axis=1, bitorder='little')\n", | |
"d1 = xt_binarized.shape[1]\n", | |
"\n", | |
"# find the mean value represented by bits 0 and 1 for each of the 32 dimensions \n", | |
"mean_for_0s = np.zeros(d1, dtype='float32')\n", | |
"mean_for_1s = np.zeros(d1, dtype='float32')\n", | |
"for bit in range(d1): \n", | |
" mean_for_0s[bit] = xt_transformed[xt_binarized[:, bit] == 0, bit].mean()\n", | |
" mean_for_1s[bit] = xt_transformed[xt_binarized[:, bit] == 1, bit].mean()\n", | |
"" | |
], | |
"execution_count": 69, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "baa21949-0c0b-4e2e-9f74-fdc816167506", | |
"showInput": true, | |
"customInput": null, | |
"code_folding": [], | |
"hidden_ranges": [], | |
"collapsed": false, | |
"requestMsgId": "baa21949-0c0b-4e2e-9f74-fdc816167506", | |
"customOutput": null, | |
"executionStartTime": 1652093026071, | |
"executionStopTime": 1652093026080 | |
}, | |
"source": [ | |
"# decompress xb using the approximations \n", | |
"# of course we don't need access to the orginal database vectors\n", | |
"xb_decompressed = np.unpackbits(xb_binary, axis=1, bitorder='little') * (mean_for_1s - mean_for_0s) + mean_for_0s\n", | |
"" | |
], | |
"execution_count": 86, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "bdf04305-0f04-4c73-a2c4-644e0c7e8bf3", | |
"showInput": true, | |
"customInput": null, | |
"code_folding": [], | |
"hidden_ranges": [], | |
"collapsed": false, | |
"requestMsgId": "bdf04305-0f04-4c73-a2c4-644e0c7e8bf3", | |
"customOutput": null, | |
"executionStartTime": 1652093038780, | |
"executionStopTime": 1652093039127 | |
}, | |
"source": [ | |
"# perform queries (we need access to the original vectors) and evaluate \n", | |
"xq_transformed = pre_transform.apply(ds.get_queries())\n", | |
"D, I = faiss.knn(xq_transformed, xb_decompressed, 100) \n", | |
"\n", | |
"gt = ds.get_groundtruth()\n", | |
"for rank in 1, 10, 100: \n", | |
" recall = (I[:, :rank] == gt[:, :1]).sum() / len(gt)\n", | |
" print(f\"Recall at {rank}: {recall:.3f}\")" | |
], | |
"execution_count": 87, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Recall at 1: 0.080\nRecall at 10: 0.355\nRecall at 100: 0.890\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "3ecc6927-240a-4857-8ec7-0c6d8e60a2fa", | |
"showInput": false, | |
"customInput": null, | |
"code_folding": [], | |
"hidden_ranges": [] | |
}, | |
"source": [ | |
"So here the asymmetric search improves the accuracy by 10% on the recall @ 10" | |
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "fd781be5-bed2-4631-a632-08c115cb0a89", | |
"showInput": true, | |
"customInput": null | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment