Skip to content

Instantly share code, notes, and snippets.

@mdouze
Created February 7, 2022 17:23
Show Gist options
  • Save mdouze/f3a05bff5186c1874a77356452297357 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 160,
"metadata": {},
"outputs": [],
"source": [
"def quantize_LUT(LUT, bias):\n",
" \"\"\" LUT quantization to uint8 and bias to uint16\n",
" LUT:\n",
" - 2D size (M, ksub): single matrix per probe\n",
" - 3D size (nprobe, M, ksub): separate LUT per probe\n",
" bias: \n",
" - None: bias is 0\n",
" - size (nprobe): one bias per probe\n",
" Output: \n",
" - LUTq uint8 version of the LUT\n",
" - biasq (or None): uint16 version of the LUT\n",
" - a, b: scalars to approximate the true distance\n",
" \"\"\"\n",
" \n",
" M, ksub = LUT.shape[-2:]\n",
" if bias is None: \n",
" assert LUT.ndim == 2\n",
" mins, maxs = LUT.min(axis=1, keepdims=True), LUT.max(axis=1, keepdims=True)\n",
" max_span = (maxs - mins).max()\n",
" max_span_dis = (maxs - mins).sum()\n",
" a = min(255 / max_span, 65535 / max_span_dis)\n",
" b = mins.sum()\n",
" LUTq = np.floor((LUT - mins) * a + 0.5).astype('uint8')\n",
" return LUTq, None, a, b\n",
" elif LUT.ndim == 2: \n",
" mins, maxs = LUT.min(axis=1, keepdims=True), LUT.max(axis=1, keepdims=True)\n",
" bias_min, bias_max = bias.min(), bias.max()\n",
" max_span_LUT = (maxs - mins).max()\n",
" # max possible distance after min subtraction\n",
" max_span_dis = (bias_max - bias_min) + (maxs - mins).sum()\n",
" a = min(255 / max_span_LUT, 65535 / max_span_dis) \n",
" LUTq = np.floor((LUT - mins) * a + 0.5).astype('uint8') \n",
" biasq = np.floor((bias - bias_min) * a + 0.5).astype('uint16')\n",
" b = mins.sum() + bias_min\n",
" return LUTq, biasq, a, b\n",
" else: \n",
" assert LUT.ndim == 3\n",
" nprobe = len(bias)\n",
" assert LUT.shape[0] == nprobe\n",
" mins, maxs = LUT.min(axis=2, keepdims=True), LUT.max(axis=2, keepdims=True)\n",
" # mins, maxs: size (nprobe, M, 1)\n",
" bias_min, bias_max = bias.min(), bias.max()\n",
" max_span_LUT = (maxs - mins).max()\n",
" span_dis_per_probe = (bias - bias_min) + (maxs[:, :, 0] - mins[:, :, 0]).sum(axis=1)\n",
" max_span_dis = span_dis_per_probe.max()\n",
" a = min(255 / max_span_LUT, 65535 / max_span_dis)\n",
" LUTq = np.floor((LUT - mins) * a + 0.5).astype('uint8') \n",
" \n",
" # biases absorb the LUT minima \n",
" bias2 = bias + mins[:, :, 0].sum(axis=1)\n",
" b = bias2.min()\n",
" biasq = np.floor((bias2 - b) * a + 0.5).astype('uint16')\n",
"\n",
" return LUTq, biasq, a, b\n"
]
},
{
"cell_type": "code",
"execution_count": 146,
"metadata": {},
"outputs": [],
"source": [
"# distance functions that use the LUTs\n",
"\n",
"def compute_dis_float(codes, LUT, bias): \n",
" nprobe, nt, M = codes.shape\n",
" dis = np.zeros((nprobe, nt), dtype='float32')\n",
" if bias is not None: \n",
" dis[:] = bias.reshape(-1, 1)\n",
" \n",
" if LUT.ndim == 2: \n",
" LUTp = LUT\n",
" \n",
" for p in range(nprobe): \n",
" if LUT.ndim == 3: \n",
" LUTp = LUT[p]\n",
" \n",
" for i in range(nt): \n",
" dis[p, i] += LUTp[np.arange(M), codes[p, i]].sum()\n",
" \n",
" return dis\n",
" \n",
"\n",
"def compute_dis_quant(codes, LUT, bias, a, b):\n",
" nprobe, nt, M = codes.shape\n",
" dis = np.zeros((nprobe, nt), dtype='uint16')\n",
" if bias is not None: \n",
" dis[:] = bias.reshape(-1, 1)\n",
" \n",
" if LUT.ndim == 2: \n",
" LUTp = LUT\n",
" \n",
" for p in range(nprobe): \n",
" if LUT.ndim == 3: \n",
" LUTp = LUT[p]\n",
" \n",
" for i in range(nt): \n",
" dis[p, i] += LUTp[np.arange(M), codes[p, i]].astype('uint16').sum()\n",
" \n",
" return dis / a + b"
]
},
{
"cell_type": "code",
"execution_count": 147,
"metadata": {},
"outputs": [],
"source": [
"# testbed\n",
"\n",
"ksub = 16\n",
"M = 20\n",
"nprobe = 10\n",
"nt = 200\n",
"\n",
"rs = np.random.RandomState(123)\n",
"codes = rs.randint(ksub, size=(nprobe, nt, M)).astype('uint8')"
]
},
{
"cell_type": "code",
"execution_count": 148,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.0003627200721823546, 923, 1077)"
]
},
"execution_count": 148,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Metric IP, no residual \n",
"LUT = rs.rand(M, ksub).astype('float32')\n",
"bias = None\n",
"\n",
"dis_ref = compute_dis_float(codes, LUT, bias)\n",
"LUTq, biasq, a, b = quantize_LUT(LUT, bias)\n",
"dis_new = compute_dis_quant(codes, LUTq, biasq, a, b)\n",
"\n",
"# error measure \n",
"(np.abs(dis_new - dis_ref).sum() / dis_ref.sum()).mean(), (dis_new > dis_ref).sum(), (dis_new < dis_ref).sum()"
]
},
{
"cell_type": "code",
"execution_count": 149,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.00026479566539748734, 1121, 879)"
]
},
"execution_count": 149,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Metric IP, residual \n",
"LUT = rs.rand(M, ksub).astype('float32')\n",
"bias = rs.rand(nprobe).astype('float32') \n",
"bias *= 10 # bias will typically have larger magnitude\n",
"\n",
"dis_ref = compute_dis_float(codes, LUT, bias)\n",
"LUTq, biasq, a, b = quantize_LUT(LUT, bias)\n",
"dis_new = compute_dis_quant(codes, LUTq, biasq, a, b)\n",
"\n",
"# error measure\n",
"(np.abs(dis_new - dis_ref).sum() / dis_ref.sum()).mean(), (dis_new > dis_ref).sum(), (dis_new < dis_ref).sum()"
]
},
{
"cell_type": "code",
"execution_count": 150,
"metadata": {},
"outputs": [],
"source": [
"# Metric L2, no residual \n",
"# same as IP no residual \n"
]
},
{
"cell_type": "code",
"execution_count": 151,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.00023117881439414453, 1032, 968)"
]
},
"execution_count": 151,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Metric L2, residual\n",
"LUT = rs.rand(nprobe, M, ksub).astype('float32')\n",
"bias = rs.rand(nprobe).astype('float32')\n",
"bias *= 10\n",
"\n",
"dis_ref = compute_dis_float(codes, LUT, bias)\n",
"LUTq, biasq, a, b = quantize_LUT(LUT, bias)\n",
"dis_new = compute_dis_quant(codes, LUTq, biasq, a, b)\n",
"\n",
"# error measiure\n",
"(np.abs(dis_new - dis_ref).sum() / dis_ref.sum()).mean(), (dis_new > dis_ref).sum(), (dis_new < dis_ref).sum()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment