-
-
Save mdouze/551ef6fa0722f2acf58fa2c6fce732d6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "a7de037f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import faiss\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"\n", | |
"from faiss.contrib.datasets import SyntheticDataset" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "599a59b0", | |
"metadata": {}, | |
"source": [ | |
"This demonstrates the equivalent of Faiss brute-force k-NN search written in plain PyTorch. \n", | |
"It may be easier to use than the Faiss GPU knn function, because no GPU resource object needs to be constructed." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "ebe996c4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# number of neighbors returned for each query\n", | |
"k = 13\n", | |
"\n", | |
"# small synthetic dataset; positional args are presumably (d, nt, nb, nq) -- TODO confirm\n", | |
"ds = SyntheticDataset(32, 0, 1234, 2345)\n", | |
"xq, xb = ds.get_queries(), ds.get_database()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "44c5978f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def torch_knn(xq, xb, k):\n", | |
"    \"\"\"Brute-force k-nearest-neighbor search with squared L2 distances.\n", | |
"\n", | |
"    Expands ||q - b||^2 = ||q||^2 + ||b||^2 - 2 <q, b>, which closely mimics\n", | |
"    what Faiss computes, but materializes the whole (nq, nb) distance matrix\n", | |
"    at once (no tiling), so very large inputs will run out of memory.\n", | |
"\n", | |
"    xq, xb: 2-D tensors with one vector per row (queries and database).\n", | |
"    Returns torch.topk's (values, indices): for each query, the k smallest\n", | |
"    squared distances and the matching database row indices.\n", | |
"    \"\"\"\n", | |
"    sq_norms_q = xq.pow(2).sum(dim=1)\n", | |
"    sq_norms_b = xb.pow(2).sum(dim=1)\n", | |
"    # column vector + row vector broadcasts to the full (nq, nb) matrix\n", | |
"    distances = sq_norms_q[:, None] + sq_norms_b - (2 * xq) @ xb.T\n", | |
"    return distances.topk(k, largest=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "49c2de6a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# ground truth: exact k-NN computed by faiss.knn (no GPU resources involved)\n", | |
"Dref, Iref = faiss.knn(\n", | |
"    xq, xb, k)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "b6039a9a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# same search with the pure-PyTorch implementation, on CPU tensors\n", | |
"Dtorch, Itorch = torch_knn(\n", | |
"    torch.from_numpy(xq),\n", | |
"    torch.from_numpy(xb),\n", | |
"    k)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "2bde1b31", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# neighbor ids must match exactly; distances only up to 5 decimals (float roundoff)\n", | |
"np.testing.assert_equal(actual=Iref, desired=Itorch)\n", | |
"np.testing.assert_almost_equal(actual=Dref, desired=Dtorch, decimal=5)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "68b61eee", | |
"metadata": {}, | |
"source": [ | |
"## On GPU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "8af5de33", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# GPU resource object that faiss.knn_gpu needs as its first argument\n", | |
"res = faiss.StandardGpuResources()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "cc773514", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# GPU reference: Faiss brute-force search run through the GPU resources above\n", | |
"Dref_gpu, Iref_gpu = faiss.knn_gpu(\n", | |
"    res, xq, xb, k)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "ef1d2eb6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# pure-PyTorch search with both inputs copied to the current CUDA device\n", | |
"Dtorch_gpu, Itorch_gpu = torch_knn(\n", | |
"    torch.from_numpy(xq).cuda(), torch.from_numpy(xb).cuda(), k\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"id": "e14c7bee", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# compare the GPU torch results with the Faiss GPU reference; results are\n", | |
"# copied back to the host first. The distance check must use Dtorch_gpu\n", | |
"# (not the earlier CPU Dtorch), otherwise the GPU distances are never verified.\n", | |
"np.testing.assert_equal(Iref_gpu, Itorch_gpu.cpu())\n", | |
"np.testing.assert_almost_equal(Dref_gpu, Dtorch_gpu.cpu(), decimal=5)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4bb48e45", | |
"metadata": {}, | |
"source": [ | |
"## Speed\n", | |
"How does the speed of the Torch distance function compare to the Faiss implementation?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"id": "ed7e68f4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# enable torch input tensors to Faiss knn function\n", | |
"import faiss.contrib.torch_utils" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 47, | |
"id": "2bfb4a23", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# larger problem for timing; args presumably (d, nt, nb, nq) -- TODO confirm.\n", | |
"# Queries and database are moved to the current CUDA device as torch tensors.\n", | |
"ds_large = SyntheticDataset(32, 0, 20000, 30000)\n", | |
"xq_large = torch.from_numpy(ds_large.get_queries()).to(\"cuda\")\n", | |
"xb_large = torch.from_numpy(ds_large.get_database()).to(\"cuda\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 48, | |
"id": "f1b7485e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"14.5 ms ± 47.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"# CUDA work may be queued asynchronously: synchronize inside the timed\n", | |
"# statement so every loop includes the GPU compute, not only the launch.\n", | |
"%timeit faiss.knn_gpu(res, xq_large, xb_large, k); torch.cuda.synchronize()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 49, | |
"id": "0f64bf38", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"105 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"# torch CUDA ops return before the kernels finish: synchronize inside the\n", | |
"# timed statement so the measurement covers the actual GPU work.\n", | |
"%timeit torch_knn(xq_large, xb_large, k); torch.cuda.synchronize()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b86e1dfb", | |
"metadata": {}, | |
"source": [ | |
"So the Torch version is a lot slower (and uses about 6 GB of GPU RAM in this case), but gradients can be propagated through it with autograd if needed. " | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment