-
-
Save mdouze/f3a05bff5186c1874a77356452297357 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 128, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 160, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def quantize_LUT(LUT, bias):\n", | |
" \"\"\" LUT quantization to uint8 and bias to uint16\n", | |
    "    LUT:\n", | |
    "    - 2D size (M, ksub): single matrix shared by all probes\n", | |
    "    - 3D size (nprobe, M, ksub): separate LUT per probe\n", | |
    "    bias: \n", | |
    "      - None: bias is 0 (only supported with a 2D LUT)\n", | |
    "      - size (nprobe): one bias per probe\n", | |
    "    Output: \n", | |
    "      - LUTq: uint8 version of the LUT\n", | |
    "      - biasq (or None): uint16 version of the bias\n", | |
    "      - a, b: scalars such that true distance ~= quantized distance / a + b\n", | |
" \"\"\"\n", | |
" \n", | |
" M, ksub = LUT.shape[-2:]\n", | |
" if bias is None: \n", | |
" assert LUT.ndim == 2\n", | |
" mins, maxs = LUT.min(axis=1, keepdims=True), LUT.max(axis=1, keepdims=True)\n", | |
" max_span = (maxs - mins).max()\n", | |
" max_span_dis = (maxs - mins).sum()\n", | |
" a = min(255 / max_span, 65535 / max_span_dis)\n", | |
" b = mins.sum()\n", | |
" LUTq = np.floor((LUT - mins) * a + 0.5).astype('uint8')\n", | |
" return LUTq, None, a, b\n", | |
" elif LUT.ndim == 2: \n", | |
" mins, maxs = LUT.min(axis=1, keepdims=True), LUT.max(axis=1, keepdims=True)\n", | |
" bias_min, bias_max = bias.min(), bias.max()\n", | |
" max_span_LUT = (maxs - mins).max()\n", | |
" # max possible distance after min subtraction\n", | |
" max_span_dis = (bias_max - bias_min) + (maxs - mins).sum()\n", | |
" a = min(255 / max_span_LUT, 65535 / max_span_dis) \n", | |
" LUTq = np.floor((LUT - mins) * a + 0.5).astype('uint8') \n", | |
" biasq = np.floor((bias - bias_min) * a + 0.5).astype('uint16')\n", | |
" b = mins.sum() + bias_min\n", | |
" return LUTq, biasq, a, b\n", | |
" else: \n", | |
" assert LUT.ndim == 3\n", | |
" nprobe = len(bias)\n", | |
" assert LUT.shape[0] == nprobe\n", | |
" mins, maxs = LUT.min(axis=2, keepdims=True), LUT.max(axis=2, keepdims=True)\n", | |
" # mins, maxs: size (nprobe, M, 1)\n", | |
" bias_min, bias_max = bias.min(), bias.max()\n", | |
" max_span_LUT = (maxs - mins).max()\n", | |
" span_dis_per_probe = (bias - bias_min) + (maxs[:, :, 0] - mins[:, :, 0]).sum(axis=1)\n", | |
" max_span_dis = span_dis_per_probe.max()\n", | |
" a = min(255 / max_span_LUT, 65535 / max_span_dis)\n", | |
" LUTq = np.floor((LUT - mins) * a + 0.5).astype('uint8') \n", | |
" \n", | |
" # biases absorb the LUT minima \n", | |
" bias2 = bias + mins[:, :, 0].sum(axis=1)\n", | |
" b = bias2.min()\n", | |
" biasq = np.floor((bias2 - b) * a + 0.5).astype('uint16')\n", | |
"\n", | |
" return LUTq, biasq, a, b\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 146, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# distance functions that use the LUTs\n", | |
"\n", | |
"def compute_dis_float(codes, LUT, bias): \n", | |
" nprobe, nt, M = codes.shape\n", | |
" dis = np.zeros((nprobe, nt), dtype='float32')\n", | |
" if bias is not None: \n", | |
" dis[:] = bias.reshape(-1, 1)\n", | |
" \n", | |
" if LUT.ndim == 2: \n", | |
" LUTp = LUT\n", | |
" \n", | |
" for p in range(nprobe): \n", | |
" if LUT.ndim == 3: \n", | |
" LUTp = LUT[p]\n", | |
" \n", | |
" for i in range(nt): \n", | |
" dis[p, i] += LUTp[np.arange(M), codes[p, i]].sum()\n", | |
" \n", | |
" return dis\n", | |
" \n", | |
"\n", | |
"def compute_dis_quant(codes, LUT, bias, a, b):\n", | |
" nprobe, nt, M = codes.shape\n", | |
" dis = np.zeros((nprobe, nt), dtype='uint16')\n", | |
" if bias is not None: \n", | |
" dis[:] = bias.reshape(-1, 1)\n", | |
" \n", | |
" if LUT.ndim == 2: \n", | |
" LUTp = LUT\n", | |
" \n", | |
" for p in range(nprobe): \n", | |
" if LUT.ndim == 3: \n", | |
" LUTp = LUT[p]\n", | |
" \n", | |
" for i in range(nt): \n", | |
" dis[p, i] += LUTp[np.arange(M), codes[p, i]].astype('uint16').sum()\n", | |
" \n", | |
" return dis / a + b" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 147, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# testbed\n", | |
"\n", | |
"ksub = 16\n", | |
"M = 20\n", | |
"nprobe = 10\n", | |
"nt = 200\n", | |
"\n", | |
"rs = np.random.RandomState(123)\n", | |
"codes = rs.randint(ksub, size=(nprobe, nt, M)).astype('uint8')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 148, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.0003627200721823546, 923, 1077)" | |
] | |
}, | |
"execution_count": 148, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Metric IP, no residual \n", | |
"LUT = rs.rand(M, ksub).astype('float32')\n", | |
"bias = None\n", | |
"\n", | |
"dis_ref = compute_dis_float(codes, LUT, bias)\n", | |
"LUTq, biasq, a, b = quantize_LUT(LUT, bias)\n", | |
"dis_new = compute_dis_quant(codes, LUTq, biasq, a, b)\n", | |
"\n", | |
"# error measure \n", | |
"(np.abs(dis_new - dis_ref).sum() / dis_ref.sum()).mean(), (dis_new > dis_ref).sum(), (dis_new < dis_ref).sum()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 149, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.00026479566539748734, 1121, 879)" | |
] | |
}, | |
"execution_count": 149, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Metric IP, residual \n", | |
"LUT = rs.rand(M, ksub).astype('float32')\n", | |
"bias = rs.rand(nprobe).astype('float32') \n", | |
"bias *= 10 # bias will typically have larger magnitude\n", | |
"\n", | |
"dis_ref = compute_dis_float(codes, LUT, bias)\n", | |
"LUTq, biasq, a, b = quantize_LUT(LUT, bias)\n", | |
"dis_new = compute_dis_quant(codes, LUTq, biasq, a, b)\n", | |
"\n", | |
"# error measure\n", | |
"(np.abs(dis_new - dis_ref).sum() / dis_ref.sum()).mean(), (dis_new > dis_ref).sum(), (dis_new < dis_ref).sum()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 150, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Metric L2, no residual \n", | |
"# same as IP no residual \n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 151, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.00023117881439414453, 1032, 968)" | |
] | |
}, | |
"execution_count": 151, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Metric L2, residual\n", | |
"LUT = rs.rand(nprobe, M, ksub).astype('float32')\n", | |
"bias = rs.rand(nprobe).astype('float32')\n", | |
"bias *= 10\n", | |
"\n", | |
"dis_ref = compute_dis_float(codes, LUT, bias)\n", | |
"LUTq, biasq, a, b = quantize_LUT(LUT, bias)\n", | |
"dis_new = compute_dis_quant(codes, LUTq, biasq, a, b)\n", | |
"\n", | |
    "# error measure\n", | |
"(np.abs(dis_new - dis_ref).sum() / dis_ref.sum()).mean(), (dis_new > dis_ref).sum(), (dis_new < dis_ref).sum()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment