Skip to content

Instantly share code, notes, and snippets.

@mdouze
Last active November 2, 2022 09:32
Show Gist options
  • Save mdouze/b2e6c6144d4e06fca8287f5257f15fed to your computer and use it in GitHub Desktop.
Save mdouze/b2e6c6144d4e06fca8287f5257f15fed to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"dataExplorerConfig": {},
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,
"bento/extensions/kernel_ui/main.css": true,
"bento/extensions/new_kernel/main.css": true,
"bento/extensions/system_usage/main.css": true,
"bento/extensions/theme/main.css": true
},
"kernelspec": {
"display_name": "faiss",
"language": "python",
"name": "bento_kernel_faiss",
"metadata": {
"kernel_name": "bento_kernel_faiss",
"nightly_builds": true,
"fbpkg_supported": true,
"cinder_runtime": false,
"is_prebuilt": true
}
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
},
"last_server_session_id": "dda44ac7-0726-4b1f-b2d1-192e30571c8c",
"last_kernel_id": "9fb3f046-5812-468a-880c-1a502cb68bb9",
"last_base_url": "https://devvm4950.lla0.facebook.com:8090/",
"last_msg_id": "95b404a6-226f69fd1bf78017499d5f0a_1600",
"outputWidgetContext": {}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "markdown",
"metadata": {
"originalKey": "d1240614-3a13-430c-9ad0-ea65e20de909",
"showInput": false,
"customInput": null
},
"source": [
"This script demonstrates an asymmetric search use case: \n",
"the query vectors are in full precision and the database vectors are compressed as binary vectors. \n",
"This implementation is slow, it is mainly intended to show how much accuracy can be regained with asymmetric search."
],
"attachments": {}
},
{
"cell_type": "code",
"metadata": {
"collapsed": false,
"originalKey": "30a3f11d-3131-4a13-a72c-bf93149e8b48",
"code_folding": [],
"hidden_ranges": [],
"requestMsgId": "30a3f11d-3131-4a13-a72c-bf93149e8b48",
"customOutput": null,
"executionStartTime": 1652091792635,
"executionStopTime": 1652091792641
},
"source": [
"import numpy as np\n",
"import faiss\n",
"from faiss.contrib.datasets import SyntheticDataset"
],
"execution_count": 42,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "96316054-488d-4b0b-b7cf-579364314c44",
"showInput": true,
"customInput": null,
"code_folding": [],
"hidden_ranges": [],
"collapsed": false,
"requestMsgId": "96316054-488d-4b0b-b7cf-579364314c44",
"customOutput": null,
"executionStartTime": 1652091106966,
"executionStopTime": 1652091107046
},
"source": [
"# An example dataset.\n",
"ds = SyntheticDataset(64, 1000, 2000, 200)"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "98f8611b-8064-4203-928c-c2818ef6bd63",
"showInput": true,
"customInput": null,
"code_folding": [],
"hidden_ranges": [],
"collapsed": false,
"requestMsgId": "98f8611b-8064-4203-928c-c2818ef6bd63",
"customOutput": null,
"executionStartTime": 1652091313645,
"executionStopTime": 1652091315915
},
"source": [
"# make a binary dataset from that: ITQ rotation and dim resuction 64->32,\n",
"# followed by thresholding\n",
"binarizer = faiss.index_factory(64, \"ITQ32,LSH\")\n",
"xt = ds.get_train()\n",
"binarizer.train(xt)"
],
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "e5363998-9508-4a0e-b9cb-c58118b3f1c9",
"showInput": true,
"customInput": null,
"code_folding": [],
"hidden_ranges": [],
"collapsed": false,
"requestMsgId": "e5363998-9508-4a0e-b9cb-c58118b3f1c9",
"customOutput": null,
"executionStartTime": 1652092905111,
"executionStopTime": 1652092905303
},
"source": [
"# transform query vectors and database vectors \n",
"\n",
"xb_binary = binarizer.sa_encode(ds.get_database())\n",
"xq_binary = binarizer.sa_encode(ds.get_queries())"
],
"execution_count": 85,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"originalKey": "31ddf16b-35f3-4535-b35d-017ad4f56546",
"showInput": false,
"customInput": null
},
"source": [
"## Baseline symmetric search"
],
"attachments": {}
},
{
"cell_type": "code",
"metadata": {
"originalKey": "c701c0c9-664b-4cc7-8e21-ff4cde48c41c",
"showInput": true,
"customInput": null,
"code_folding": [],
"hidden_ranges": [],
"collapsed": false,
"requestMsgId": "c701c0c9-664b-4cc7-8e21-ff4cde48c41c",
"customOutput": null,
"executionStartTime": 1652091523755,
"executionStopTime": 1652091523921
},
"source": [
"# Baseline: symmetric search in the binary domain \n",
"\n",
"index_binary = faiss.IndexBinaryFlat(32)\n",
"index_binary.add(xb_binary)\n",
"D, I = index_binary.search(xq_binary, 100) \n",
"\n",
"gt = ds.get_groundtruth()\n",
"for rank in 1, 10, 100: \n",
" recall = (I[:, :rank] == gt[:, :1]).sum() / len(gt)\n",
" print(f\"Recall at {rank}: {recall:.3f}\")"
],
"execution_count": 29,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Recall at 1: 0.075\nRecall at 10: 0.250\nRecall at 100: 0.680\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"originalKey": "399b52cb-a9a9-451f-b215-7d313eb735b0",
"showInput": false,
"customInput": null
},
"source": [
"## Asymmetric search"
],
"attachments": {}
},
{
"cell_type": "code",
"metadata": {
"originalKey": "1a76d077-f37f-4e1f-9f85-cf10aca675c9",
"showInput": true,
"customInput": null,
"code_folding": [],
"hidden_ranges": [],
"collapsed": false,
"requestMsgId": "1a76d077-f37f-4e1f-9f85-cf10aca675c9",
"customOutput": null,
"executionStartTime": 1652092357581,
"executionStopTime": 1652092357751
},
"source": [
"# Prepare asymmetric search\n",
"pre_transform = binarizer.chain.at(0)\n",
"\n",
"xt_transformed = pre_transform.apply(xt)\n",
"xt_binarized = np.unpackbits(binarizer.sa_encode(xt), axis=1, bitorder='little')\n",
"d1 = xt_binarized.shape[1]\n",
"\n",
"# find the mean value represented by bits 0 and 1 for each of the 32 dimensions \n",
"mean_for_0s = np.zeros(d1, dtype='float32')\n",
"mean_for_1s = np.zeros(d1, dtype='float32')\n",
"for bit in range(d1): \n",
" mean_for_0s[bit] = xt_transformed[xt_binarized[:, bit] == 0, bit].mean()\n",
" mean_for_1s[bit] = xt_transformed[xt_binarized[:, bit] == 1, bit].mean()\n",
""
],
"execution_count": 69,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "baa21949-0c0b-4e2e-9f74-fdc816167506",
"showInput": true,
"customInput": null,
"code_folding": [],
"hidden_ranges": [],
"collapsed": false,
"requestMsgId": "baa21949-0c0b-4e2e-9f74-fdc816167506",
"customOutput": null,
"executionStartTime": 1652093026071,
"executionStopTime": 1652093026080
},
"source": [
"# decompress xb using the approximations \n",
"# of course we don't need access to the orginal database vectors\n",
"xb_decompressed = np.unpackbits(xb_binary, axis=1, bitorder='little') * (mean_for_1s - mean_for_0s) + mean_for_0s\n",
""
],
"execution_count": 86,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"originalKey": "bdf04305-0f04-4c73-a2c4-644e0c7e8bf3",
"showInput": true,
"customInput": null,
"code_folding": [],
"hidden_ranges": [],
"collapsed": false,
"requestMsgId": "bdf04305-0f04-4c73-a2c4-644e0c7e8bf3",
"customOutput": null,
"executionStartTime": 1652093038780,
"executionStopTime": 1652093039127
},
"source": [
"# perform queries (we need access to the original vectors) and evaluate \n",
"xq_transformed = pre_transform.apply(ds.get_queries())\n",
"D, I = faiss.knn(xq_transformed, xb_decompressed, 100) \n",
"\n",
"gt = ds.get_groundtruth()\n",
"for rank in 1, 10, 100: \n",
" recall = (I[:, :rank] == gt[:, :1]).sum() / len(gt)\n",
" print(f\"Recall at {rank}: {recall:.3f}\")"
],
"execution_count": 87,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Recall at 1: 0.080\nRecall at 10: 0.355\nRecall at 100: 0.890\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"originalKey": "3ecc6927-240a-4857-8ec7-0c6d8e60a2fa",
"showInput": false,
"customInput": null,
"code_folding": [],
"hidden_ranges": []
},
"source": [
"So here the asymmetric search improves the accuracy by 10% on the recall @ 10"
],
"attachments": {}
},
{
"cell_type": "code",
"metadata": {
"originalKey": "fd781be5-bed2-4631-a632-08c115cb0a89",
"showInput": true,
"customInput": null
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment