-
-
Save mdouze/f87ac3b5c66a6a0c3cfb8fbb59ff52e8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"dataExplorerConfig": {}, | |
"bento_stylesheets": { | |
"bento/extensions/flow/main.css": true, | |
"bento/extensions/kernel_selector/main.css": true, | |
"bento/extensions/kernel_ui/main.css": true, | |
"bento/extensions/new_kernel/main.css": true, | |
"bento/extensions/system_usage/main.css": true, | |
"bento/extensions/theme/main.css": true | |
}, | |
"kernelspec": { | |
"display_name": "faiss", | |
"language": "python", | |
"name": "bento_kernel_faiss", | |
"cinder_runtime": true, | |
"ipyflow_runtime": false, | |
"metadata": { | |
"kernel_name": "bento_kernel_faiss", | |
"nightly_builds": true, | |
"fbpkg_supported": true, | |
"cinder_runtime": true, | |
"ipyflow_runtime": false, | |
"is_prebuilt": true | |
} | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3" | |
}, | |
"last_server_session_id": "8cdc2722-e2f4-49c2-82ce-7b40aed1711c", | |
"last_kernel_id": "e1b97110-fd19-4ff9-8883-7bdcc77b540d", | |
"last_base_url": "https://devvm4950.lla0.facebook.com:8090/", | |
"last_msg_id": "4fe82319-62cfc4e47fcf7dedc0ae8d6e_289", | |
"outputWidgetContext": {} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"collapsed": false, | |
"originalKey": "a535f3bf-5519-41e2-a352-34b43ccb951e", | |
"requestMsgId": "ca0aa03e-b19d-49ac-b77f-57175ee3ede6", | |
"customOutput": null, | |
"executionStartTime": 1669987916800, | |
"executionStopTime": 1669987916827 | |
}, | |
"source": [ | |
"import os\n", | |
"import numpy as np\n", | |
"import scipy.sparse\n", | |
"import faiss\n", | |
"from multiprocessing.pool import ThreadPool\n", | |
"from faiss.contrib import clustering\n", | |
"from faiss.contrib.datasets import SyntheticDataset" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "dec12320-0ae9-4f44-bb52-faa9a2835f1d", | |
"showInput": false, | |
"customInput": null | |
}, | |
"source": [ | |
"This notebook demonstrates how to cluster vectors that are composed of a dense\n",
"part of dimension d1 and a sparse part of dimension d2, where d2 >> d1.\n",
"The centroids are represented as full dense vectors.\n",
"\n",
"The implementation relies on the `clustering.DatasetAssign` object, which abstracts\n",
"away the representation of the vectors to cluster. The `clustering` module contains\n",
"a pure Python implementation of `kmeans` that can consume this `DatasetAssign`."
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "ed4f3166-4207-4a1b-bbf3-e4a1c1ae3e00", | |
"showInput": false, | |
"customInput": null | |
}, | |
"source": [ | |
"## Sparse-dense assignment class" | |
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "3db6459e-d306-48b8-9df7-73ab9c257c73", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "3db6459e-d306-48b8-9df7-73ab9c257c73", | |
"customOutput": null, | |
"executionStartTime": 1669989526646, | |
"executionStopTime": 1669989526725 | |
}, | |
"source": [ | |
"def sparse_dense_assign_to_dense(xdense, xsparse, xb, xq_norms=None, xb_norms=None):\n",
"    \"\"\"Exact nearest-neighbor assignment of the rows of hstack(xdense, xsparse) to the rows of xb.\n",
"\n",
"    xdense (nq, ddense) and xb (nb, ddense + dsparse) are dense arrays, xsparse\n",
"    (nq, dsparse) is a CSR matrix. The squared L2 distances are computed with\n",
"    matrix multiplications, using ||q - b||^2 = ||q||^2 - 2 <q, b> + ||b||^2.\n",
"\n",
"    xq_norms: squared norms of the rows of hstack(xdense, xsparse), shape (nq,)\n",
"        or (nq, 1). Computed from xdense and xsparse when None.\n",
"    xb_norms: squared norms of the rows of xb, shape (nb,). Computed when None.\n",
"\n",
"    Returns (D, I): for each query row, the squared distance to its nearest\n",
"    row of xb and the index of that row.\n",
"    \"\"\"\n",
"    nq, ddense = xdense.shape\n",
"    assert nq == xsparse.shape[0]\n",
"    dsparse = xsparse.shape[1]\n",
"    assert ddense + dsparse == xb.shape[1]\n",
"    if xq_norms is None:\n",
"        xq_norms = (xdense ** 2).sum(1) + np.asarray(xsparse.power(2).sum(1)).ravel()\n",
"    if xb_norms is None:\n",
"        xb_norms = (xb ** 2).sum(1)\n",
"    # ||q||^2 is the same for all candidates of a query, so it is added after the argmin\n",
"    d2 = xb_norms - 2 * (xdense @ xb[:, :ddense].T)\n",
"    d2 -= 2 * (xsparse @ xb[:, ddense:].T)\n",
"    I = d2.argmin(axis=1)\n",
"    D = d2[np.arange(nq), I] + np.asarray(xq_norms).ravel()\n",
"    return D, I\n",
"\n", | |
"\n", | |
"def sparse_dense_assign_to_dense_blocks(\n",
"    xdense, xsparse, xb, xq_norms=None, xb_norms=None, qbs=16384, bbs=16384, nt=None\n",
"):\n",
"    \"\"\"Blocked version of sparse_dense_assign_to_dense.\n",
"\n",
"    The queries are processed by blocks of qbs rows and xb by blocks of bbs\n",
"    rows, so at most a (qbs, bbs) distance matrix is materialized at a time.\n",
"    Query blocks can be processed in parallel by nt threads (nt=None: one per\n",
"    CPU; nt=0 or 1: sequential), which helps because scipy's sparse-dense\n",
"    matrix multiplication is single-threaded.\n",
"\n",
"    xq_norms: squared norms of the rows of hstack(xdense, xsparse), shape (nq,)\n",
"        or (nq, 1). xb_norms: squared norms of the rows of xb, shape (nb,).\n",
"        Both are computed here when None.\n",
"\n",
"    Returns (D, I): float32 squared distances and integer indices of the\n",
"    nearest row of xb for each query row.\n",
"    \"\"\"\n",
"    nq = xdense.shape[0]\n",
"    assert nq == xsparse.shape[0]\n",
"    nb = xb.shape[0]\n",
"\n",
"    if xq_norms is None:\n",
"        xq_norms = (xdense ** 2).sum(1) + np.asarray(xsparse.power(2).sum(1)).ravel()\n",
"    if xb_norms is None:\n",
"        xb_norms = (xb ** 2).sum(1)\n",
"\n",
"    # result arrays; each query block writes only to its own slice\n",
"    D = np.full(nq, np.inf, dtype=\"float32\")\n",
"    I = np.full(nq, -1, dtype=int)\n",
"\n",
"    def handle_query_block(i0):\n",
"        rows = slice(i0, i0 + qbs)\n",
"        xdense_block = xdense[rows]\n",
"        xsparse_block = xsparse[rows]\n",
"        xq_norms_block = xq_norms[rows]\n",
"        # views into D and I: assigning to them fills the result arrays\n",
"        Dblock = D[rows]\n",
"        Iblock = I[rows]\n",
"        for j0 in range(0, nb, bbs):\n",
"            Di, Ii = sparse_dense_assign_to_dense(\n",
"                xdense_block, xsparse_block,\n",
"                xb[j0 : j0 + bbs],\n",
"                xq_norms=xq_norms_block,\n",
"                xb_norms=xb_norms[j0 : j0 + bbs],\n",
"            )\n",
"            if j0 == 0:\n",
"                Iblock[:] = Ii\n",
"                Dblock[:] = Di\n",
"            else:\n",
"                # keep the closer of the previous best and this database block\n",
"                mask = Di < Dblock\n",
"                Iblock[mask] = Ii[mask] + j0\n",
"                Dblock[mask] = Di[mask]\n",
"\n",
"    block_starts = range(0, nq, qbs)\n",
"    if nt == 0 or nt == 1 or nq <= qbs:\n",
"        for i0 in block_starts:\n",
"            handle_query_block(i0)\n",
"    else:\n",
"        # the context manager releases the worker threads once map() has returned\n",
"        with ThreadPool(nt) as pool:\n",
"            pool.map(handle_query_block, block_starts)\n",
"\n",
"    return D, I\n",
""
], | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "b779f69b-fa8c-41b2-b054-7b7fccbbfca4", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "b779f69b-fa8c-41b2-b054-7b7fccbbfca4", | |
"customOutput": null, | |
"executionStartTime": 1669989534601, | |
"executionStopTime": 1669989534697 | |
}, | |
"source": [ | |
"class DatasetAssignDenseSparse(clustering.DatasetAssign):\n",
"    \"\"\"Wrapper for a matrix that is the concatenation of a dense and\n",
"    a sparse set of columns, x = hstack(xdense, xsparse).\n",
"\n",
"    xdense is a dense (n, ddense) array and xsparse a CSR matrix with n rows.\n",
"    The concatenated matrix is only materialized for subsets of rows\n",
"    (get_subset); assignment uses sparse_dense_assign_to_dense_blocks.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, xdense, xsparse):\n",
"        assert isinstance(xsparse, scipy.sparse.csr_matrix), \"xsparse must be a CSR matrix\"\n",
"        assert xdense.shape[0] == xsparse.shape[0], \"xdense and xsparse must have the same number of rows\"\n",
"        self.xsparse = xsparse\n",
"        self.xdense = xdense\n",
"        # squared L2 norm of each row of hstack(xdense, xsparse), shape (n, 1)\n",
"        self.squared_norms = (xdense ** 2).sum(1)[:, None] + np.array(\n",
"            xsparse.power(2).sum(1)\n",
"        )\n",
"\n",
"    def get_subset(self, indices):\n",
"        \"\"\"Return the selected rows of hstack(xdense, xsparse) as a dense array.\"\"\"\n",
"        return np.hstack((self.xdense[indices], self.xsparse[indices].toarray()))\n",
"\n",
"    def count(self):\n",
"        return self.xdense.shape[0]\n",
"\n",
"    def dim(self):\n",
"        return self.xdense.shape[1] + self.xsparse.shape[1]\n",
"\n",
"    def perform_search(self, centroids):\n",
"        \"\"\"Return (D, I): squared distance to and index of the nearest centroid.\"\"\"\n",
"        return sparse_dense_assign_to_dense_blocks(\n",
"            self.xdense, self.xsparse, centroids, xq_norms=self.squared_norms, nt=None\n",
"        )\n",
"\n",
"    def assign_to(self, centroids, weights=None):\n",
"        \"\"\"Assign each row to its nearest centroid.\n",
"\n",
"        Returns (I, D, sum_per_centroid), where sum_per_centroid[c] is the\n",
"        (weighted) sum of the rows assigned to centroid c.\n",
"        \"\"\"\n",
"        D, I = self.perform_search(centroids)\n",
"        I = I.ravel()\n",
"        D = D.ravel()\n",
"        n = self.count()\n",
"        if weights is None:\n",
"            weights = np.ones(n, dtype=\"float32\")\n",
"        nc = len(centroids)\n",
"        # (nc, n) CSC matrix with one entry per column: m[I[i], i] = weights[i]\n",
"        m = scipy.sparse.csc_matrix((weights, I, np.arange(n + 1)), shape=(nc, n))\n",
"        # use @ for both products: * is also matrix multiplication for scipy sparse\n",
"        # matrices, but it means element-wise product for scipy sparse arrays\n",
"        sum_per_centroid = np.hstack((\n",
"            np.asarray(m @ self.xdense),\n",
"            (m @ self.xsparse).toarray(),\n",
"        ))\n",
"        return I, D, sum_per_centroid"
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "30aff58d-3489-433e-a7b3-6a9505570472", | |
"showInput": false, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "30aff58d-3489-433e-a7b3-6a9505570472", | |
"customOutput": null, | |
"executionStartTime": 1669979495769, | |
"executionStopTime": 1669979495796 | |
}, | |
"source": [ | |
"## Test" | |
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "9b99006c-f957-4062-a4d8-5ab8bfb4912b", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "9b99006c-f957-4062-a4d8-5ab8bfb4912b", | |
"customOutput": null, | |
"executionStartTime": 1669989575097, | |
"executionStopTime": 1669989582002 | |
}, | |
"source": [ | |
"# number of vectors: the dense and sparse parts are concatenated row-wise,\n",
"# so both datasets must have the same number of training vectors\n",
"N_TRAIN = 100000\n",
"# generate the dense data (dimension d1 = 64)\n",
"ds = SyntheticDataset(64, N_TRAIN, 0, 0, seed=123)\n",
"# generate the sparse data (higher dimensional, d2 = 1000)\n",
"ds2 = SyntheticDataset(1000, N_TRAIN, 0, 0, seed=234)"
], | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "1e1f92f4-7bd4-49f1-93d5-f819976f93cd", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "1e1f92f4-7bd4-49f1-93d5-f819976f93cd", | |
"customOutput": null, | |
"executionStartTime": 1669987931634, | |
"executionStopTime": 1669987931642 | |
}, | |
"source": [ | |
"# dense part: the training vectors of ds\n",
"xdense = ds.get_train()\n",
"# display the shape as a quick sanity check\n",
"xdense.shape"
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "b9d17515-55fd-4b0e-915a-ba7418fa570d", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "b9d17515-55fd-4b0e-915a-ba7418fa570d", | |
"customOutput": null, | |
"executionStartTime": 1669989610627, | |
"executionStopTime": 1669989611163 | |
}, | |
"source": [ | |
"# sparse part: keep only the entries with magnitude >= SPARSE_THRESHOLD, which is\n",
"# intended to leave less than 1% of the entries non-zero (the last line displays\n",
"# the actual fraction). The zeroing is done in place; it is idempotent.\n",
"SPARSE_THRESHOLD = 0.9997\n",
"xsparse = ds2.get_train()\n",
"xsparse[np.abs(xsparse) < SPARSE_THRESHOLD] = 0\n",
"assert xsparse.shape[0] == xdense.shape[0], \"dense and sparse parts must have the same rows\"\n",
"np.count_nonzero(xsparse) / xsparse.size"
], | |
"execution_count": 20, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "d102e8ee-be8f-40d4-86f4-b0c1bd372ccc", | |
"showInput": false, | |
"customInput": null | |
}, | |
"source": [ | |
"### Reference run"
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "8af26ee9-bed6-4ecf-acc8-9f82c68e63d9", | |
"showInput": false, | |
"customInput": null | |
}, | |
"source": [ | |
"Reference run: the sparse part is simply treated as dense, concatenated to the dense part, and clustered with the standard dense `DatasetAssign`."
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "ab7ad5b9-7b5e-41cf-ac00-54c6a1bd803c", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "ab7ad5b9-7b5e-41cf-ac00-54c6a1bd803c", | |
"customOutput": null, | |
"executionStartTime": 1669987935271, | |
"executionStopTime": 1669987937794 | |
}, | |
"source": [ | |
"# the normal dense DatasetAssign on the concatenated matrix (sparse part stored densely)\n",
"data = clustering.DatasetAssign(np.hstack((xdense, xsparse)))\n",
"clusters, iteration_stats = clustering.kmeans(100, data, niter=10, return_stats=True)"
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"compute centroids\r", | |
" Iteration 9 (2.18 s, search 2.16 s): objective=3.44424e+06 imbalance=1.050 nsplit=0\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "af71eb0f-5307-442c-a93a-056d070920c9", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "af71eb0f-5307-442c-a93a-056d070920c9", | |
"customOutput": null, | |
"executionStartTime": 1669987940432, | |
"executionStopTime": 1669987940448 | |
}, | |
"source": [ | |
"# objective reported for the last k-means iteration of the reference run\n",
"iteration_stats[-1][\"obj\"]"
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "3444237.0" | |
}, | |
"metadata": { | |
"bento_obj_id": "139938904559184" | |
}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"originalKey": "c698a3ae-6068-456d-a1a7-03edd820c581", | |
"showInput": false, | |
"customInput": null | |
}, | |
"source": [ | |
"### Run with real sparse support\n", | |
"\n", | |
"Use the `DatasetAssignDenseSparse` object to handle the dense + sparse clustering." | |
], | |
"attachments": {} | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "370d9678-6cd5-423b-ab1c-9cfd3bc44e59", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "370d9678-6cd5-423b-ab1c-9cfd3bc44e59", | |
"customOutput": null, | |
"executionStartTime": 1669989114536, | |
"executionStopTime": 1669989124370 | |
}, | |
"source": [ | |
"# DatasetAssignDenseSparse expects the sparse part as a CSR matrix\n",
"xsparse_csr = scipy.sparse.csr_matrix(xsparse)\n",
"data = DatasetAssignDenseSparse(xdense, xsparse_csr)\n",
"clusters, iteration_stats = clustering.kmeans(\n",
"    100, data, niter=10, return_stats=True\n",
")"
], | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"compute centroids\r", | |
" Iteration 9 (8.44 s, search 8.42 s): objective=3.44421e+06 imbalance=1.049 nsplit=0\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"originalKey": "6fcfe188-bc58-4295-aa14-adea51bd43bd", | |
"showInput": true, | |
"customInput": null, | |
"collapsed": false, | |
"requestMsgId": "6fcfe188-bc58-4295-aa14-adea51bd43bd", | |
"customOutput": null, | |
"executionStartTime": 1669989129197, | |
"executionStopTime": 1669989129238 | |
}, | |
"source": [ | |
"# objective reported for the last iteration with DatasetAssignDenseSparse; the same\n",
"# vectors are clustered, so it is expected to be close to the reference objective\n",
"iteration_stats[-1][\"obj\"]"
], | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "3444213.0" | |
}, | |
"metadata": { | |
"bento_obj_id": "139938427419760" | |
}, | |
"execution_count": 15 | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment