Created
March 6, 2020 10:58
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading faiss with AVX2 support.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import faiss\n",
    "from matplotlib import pyplot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Whitening the vectors and then searching with the L2 distance is equivalent to searching with the Mahalanobis distance (see https://en.wikipedia.org/wiki/Mahalanobis_distance)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate a random matrix with Gaussian distribution\n",
    "rs = np.random.RandomState(123)\n",
    "\n",
    "d = 10\n",
    "n = 10000\n",
    "\n",
    "# Gaussian, identity covariance, mean 0\n",
    "xnorm = rs.randn(n, d)\n",
    "\n",
    "# Gaussian, general covariance and mean\n",
    "x = xnorm @ rs.rand(d, d) + rs.rand(d)\n",
    "\n",
    "# convert to float32 for Faiss\n",
    "x = x.astype('float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train PCA with whitening (eigenvalue power -0.5)\n",
    "whiten = faiss.PCAMatrix(d, d, -0.5)\n",
    "whiten.train(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2.7029884e+05, 2.0929883e+04, 1.5178835e+04, 9.8838896e+03,\n",
       "       8.2295615e+03, 4.9401885e+03, 3.4985391e+03, 1.5123438e+03,\n",
       "       1.7264275e+02, 7.5637016e+01], dtype=float32)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# look at eigenvalues\n",
    "faiss.vector_to_array(whiten.eigenvalues)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# apply whitening\n",
    "xt = whiten.apply_py(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[   0, 5289, 5793, 9512],\n",
       "       [   1, 1570, 6202, 3440],\n",
       "       [   2, 2601, 8350, 7803]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# search on original vectors\n",
    "index = faiss.IndexFlatL2(d)\n",
    "index.add(x)\n",
    "index.search(x[:3], 4)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[   0, 1908, 1604, 4671],\n",
       "       [   1,  742, 3440, 3498],\n",
       "       [   2,  318, 2879, 9319]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# search on whitened vectors\n",
    "index = faiss.IndexFlatL2(d)\n",
    "index.add(xt)\n",
    "index.search(xt[:3], 4)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[   0, 1908, 1604, 4671],\n",
       "       [   1,  742, 3440, 3498],\n",
       "       [   2,  318, 2879, 9319]])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# search in Gaussian space with identity covariance\n",
    "index = faiss.IndexFlatL2(d)\n",
    "index.add(xnorm.astype('float32'))\n",
    "index.search(xnorm[:3].astype('float32'), 4)[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The results in the whitened space are the same as those obtained with identity covariance, as expected.\n",
    "\n",
    "## Degenerate example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2.9622341e+05 2.0840871e+04 1.6759289e+04 9.2783945e+03 8.6274736e+03\n",
      " 5.2020957e+03 3.0101091e+03 1.3651075e+03 1.4849922e+02 1.0448833e-02]\n"
     ]
    }
   ],
   "source": [
    "# make a rank-deficient matrix: two identical rows\n",
    "M = rs.rand(d, d)\n",
    "M[1] = M[0]\n",
    "\n",
    "x = xnorm @ M + rs.rand(d)\n",
    "x = x.astype('float32')\n",
    "whiten = faiss.PCAMatrix(d, d, -0.5)\n",
    "whiten.train(x)\n",
    "\n",
    "print(faiss.vector_to_array(whiten.eigenvalues))\n",
    "xt = whiten.apply_py(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[   0,  118, 5289, 1908],\n",
       "       [   1,  742, 3498, 7666],\n",
       "       [   2, 9319,  318, 4849]])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# search on whitened vectors (degenerate case)\n",
    "index = faiss.IndexFlatL2(d)\n",
    "index.add(xt)\n",
    "index.search(xt[:3], 4)[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the results are now different. The last eigenvalue should be exactly 0, but because of numerical approximation it comes out as +/- epsilon. If the sign happens to be negative, the result is even worse, since whitening raises each eigenvalue to the power -0.5. In practice, it is best to restrict the output dimension so that eigenvalues that are too small are excluded: build the `PCAMatrix` with an output dimension smaller than `d`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
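
The two claims made in the notebook (whitening followed by L2 is a Mahalanobis distance, and near-zero eigenvalues should be excluded) can be checked with a small NumPy-only sketch. This is not how Faiss implements `PCAMatrix`; it is an illustration under stated assumptions. The relative threshold `rel_tol` and the names `keep`, `d_out`, `W` and `cov_pinv` are hypothetical choices made for this example, and the data is regenerated here in float64 rather than reusing the notebook's arrays.

```python
import numpy as np

rs = np.random.RandomState(123)
d, n = 10, 10000

# Rank-deficient mixing matrix, in the spirit of the degenerate example:
# two identical rows, so the data spans fewer than d directions.
M = rs.rand(d, d)
M[1] = M[0]
x = rs.randn(n, d) @ M + rs.rand(d)  # kept in float64 for this check

# Eigendecomposition of the sample covariance (symmetric, so use eigh).
mean = x.mean(axis=0)
cov = np.cov(x, rowvar=False)
eigvals, eigvecs = np.linalg.eigh(cov)

# Keep only directions whose eigenvalue is non-negligible relative to the
# largest one. rel_tol is an illustrative assumption, not a Faiss default.
rel_tol = 1e-6
keep = eigvals > rel_tol * eigvals.max()
d_out = int(keep.sum())

# Whitening matrix: project on the kept eigenvectors, scale by eigenvalue^-0.5.
W = eigvecs[:, keep] / np.sqrt(eigvals[keep])
xt = (x - mean) @ W

# Squared L2 distance in the whitened space ...
l2_sq = np.sum((xt[0] - xt[1]) ** 2)

# ... versus the squared Mahalanobis distance, using the inverse covariance
# restricted to the kept subspace (a pseudo-inverse).
cov_pinv = W @ W.T
diff = x[0] - x[1]
maha_sq = diff @ cov_pinv @ diff

print("kept dimensions:", d_out, "of", d)
print("squared L2 in whitened space:", l2_sq)
print("squared Mahalanobis distance:", maha_sq)
```

The value of `d_out` obtained this way could then be used as the output dimension when building the Faiss transform, following the constructor call used in the notebook: `faiss.PCAMatrix(d, d_out, -0.5)`. A suitable threshold depends on the data and on the precision of the eigenvalues (Faiss reports them in float32), so it should be chosen by inspecting the eigenvalue spectrum.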