-
-
Save mdouze/5c32300cf3bd20946a7762f6c016e823 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def reference_accu(codes, LUT): \n", | |
" \"\"\"\n", | |
" Reference scalar accumulation function for PQ4 data. \n", | |
" The codes and LUT are laid out in a \"natural\" way.\n", | |
" nq = nb of queries\n", | |
" nb = nb of database vectors\n", | |
" nsq = nb of sub-quantizers (each size 4 bits)\n", | |
" \"\"\"\n", | |
" \n", | |
" nq, nsq, is_16 = LUT.shape\n", | |
" nb, nsq_2 = codes.shape\n", | |
" assert is_16 == 16\n", | |
" assert nsq_2 == nsq // 2 \n", | |
" accu = np.zeros((nq, nb), 'uint16')\n", | |
" for i in range(nq): \n", | |
" for j in range(nb): \n", | |
" a = np.uint16(0)\n", | |
" for sq in range(0, nsq, 2): \n", | |
" c = codes[j, sq // 2]\n", | |
" a += LUT[i, sq , c & 15].astype('uint16')\n", | |
" a += LUT[i, sq + 1, c >> 4].astype('uint16')\n", | |
" accu[i, j] = a\n", | |
" return accu" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\"\"\" \n", | |
"Generate random data for testing\n", | |
"\"\"\"\n", | |
"\n", | |
"nq = 96 # number of queries\n", | |
"nb = 384 # databse size \n", | |
"nsq = 16 # number of sub-quantizers\n", | |
"\n", | |
"rs = np.random.RandomState(123)\n", | |
"codes = rs.randint(256, size=(nb, nsq // 2)).astype('uint8')\n", | |
"LUT = rs.randint(256, size=(nq, nsq, 16)).astype('uint8')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"accu_ref = reference_accu(codes, LUT)" | |
] | |
}, | |
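{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# First attempt -- not very efficient (loop1)\n", | |
"\n", | |
"Note: `loop1_kernel` reads `bbs`, `qbs` and `nsq` from the notebook globals instead of deriving them from the shapes of its `codes` and `LUT` arguments, so those globals must match the block sizes used to prepare the arrays." | |
] | |
}, | |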
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def lookup_2_lanes(a, b):\n", | |
" \"\"\" The AVX2 lookup instruction, it is very crippled (because it operates on 2 * 16 bytes rather than on 32)\n", | |
" but it's the only efficient way to do lookups \"\"\"\n", | |
" blane0 = b[:16] & 15\n", | |
" blane1 = b[16:] & 15 \n", | |
" return np.hstack((a[blane0], a[blane1 + 16]))\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15])" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.arange(16).reshape(2, 8).T.ravel()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"\n", | |
"def loop1_prepare_codes(codes, bbs): \n", | |
" \"\"\" \n", | |
" re-organize the code matrix so that the loop1 kernel reads all data sequentially\n", | |
" bbs = database side block size (must be a multiple of 16)\n", | |
" \"\"\"\n", | |
" nb, nsq_2 = codes.shape\n", | |
" nsq = nsq_2 * 2\n", | |
" assert nb % bbs == 0\n", | |
" assert nsq % 4 == 0\n", | |
" assert bbs % 16 == 0\n", | |
" \n", | |
" codes2 = np.zeros((nb // bbs, nsq // 4, bbs // 16, 32), 'uint8')\n", | |
" for i in range(0, nb, 16): \n", | |
" for sq in range(0, nsq, 4): \n", | |
" c01 = codes[i : i + 16, sq // 2]\n", | |
" c23 = codes[i : i + 16, sq // 2 + 1]\n", | |
" c02 = c01 & 15 | c23 << 4\n", | |
" c13 = c01 >> 4 | c23 & 0xF0 \n", | |
" cblock = codes2[i // bbs, sq // 4, (i % bbs) // 16, :] \n", | |
" cblock[:16] = c02\n", | |
" cblock[16:] = c13\n", | |
" return codes2\n", | |
" \n", | |
" \n", | |
"def loop1_prepare_LUT(LUT, qbs): \n", | |
" \"\"\" \n", | |
" Re-organize the LUT matrix so that the loop1 kernel reads data sequentially\n", | |
" \"\"\"\n", | |
" nq, nsq, is_16 = LUT.shape\n", | |
" assert is_16 == 16\n", | |
" assert nsq % 4 == 0\n", | |
" assert nq % qbs == 0\n", | |
" \n", | |
" LUT2 = np.zeros((nq // qbs, nsq // 4, qbs, 2, 32), 'uint8')\n", | |
" \n", | |
" for i in range(nq): \n", | |
" for sq in range(0, nsq, 4): \n", | |
" LUT2[i // qbs, sq // 4, i % qbs, 0, :16] = LUT[i, sq , :]\n", | |
" LUT2[i // qbs, sq // 4, i % qbs, 0, 16:] = LUT[i, sq + 1, :]\n", | |
" LUT2[i // qbs, sq // 4, i % qbs, 1, :16] = LUT[i, sq + 2, :]\n", | |
" LUT2[i // qbs, sq // 4, i % qbs, 1, 16:] = LUT[i, sq + 3, :]\n", | |
" \n", | |
" return LUT2\n", | |
"\n", | |
"loop1_nread = 0\n", | |
"\n", | |
"\n", | |
"def combine2x2(a, b): \n", | |
" return np.hstack((a[:8] + a[8:], b[:8] + b[8:]))\n", | |
" \n", | |
"def loop1_kernel(codes, LUT): \n", | |
" \"\"\" computation kernel for 1 block of the output distance matrix. \n", | |
" The block is of size bbs * qbs (determined from the sizes of the codes and LUT table)\n", | |
" \"\"\"\n", | |
" global loop1_nread\n", | |
" bb = bbs // 16\n", | |
" distances = np.zeros((qbs, bb, 16), 'uint16')\n", | |
" \n", | |
" for sq in range(0, nsq, 4): \n", | |
" ctab = codes[sq // 4, :, :] # read size (bb, 32)\n", | |
" loop1_nread += bb\n", | |
" for q in range(qbs): \n", | |
" lut0 = LUT[sq // 4, q, 0, :] # read size 32\n", | |
" lut1 = LUT[sq // 4, q, 1, :] # read size 32\n", | |
" loop1_nread += 2\n", | |
" for b in range(bb): \n", | |
" c = ctab[b]\n", | |
" chi = c >> 4\n", | |
" clo = c & 15\n", | |
" q02 = lookup_2_lanes(lut0, clo)\n", | |
" q13 = lookup_2_lanes(lut1, chi)\n", | |
" \n", | |
" q02 = q02.view('uint16')\n", | |
" q13 = q13.view('uint16')\n", | |
" \n", | |
" sev = q02 + q13 \n", | |
" sod = (q02 >> 8) + (q13 >> 8)\n", | |
" \n", | |
" distances[q, b, :] += combine2x2(sev, sod)\n", | |
" \n", | |
" distances[:, :, :8] -= distances[:, :, 8:] << 8\n", | |
" \n", | |
" perm = np.arange(16).reshape(2, 8).T.ravel() \n", | |
" \n", | |
" distances = distances[:, :, perm]\n", | |
" \n", | |
" return distances.reshape(qbs, bbs)\n", | |
"\n", | |
"\n", | |
"def loop1_accu(codes, LUT):\n", | |
" \"\"\" runs the kernel to produce the complete distance matrix \"\"\"\n", | |
" q_blocks, nsq_4, qbs, is_2, is_32 = LUT.shape \n", | |
" nsq = nsq_4 * 4\n", | |
" nq = q_blocks * qbs\n", | |
" assert is_2 == 2\n", | |
" assert is_32 == 32\n", | |
" \n", | |
" b_blocks, nsq_4, bb, is_32 = codes.shape\n", | |
" assert nsq == nsq_4 * 4\n", | |
" bbs = bb * 16\n", | |
" assert is_32 == 32\n", | |
" nb = b_blocks * bbs\n", | |
" \n", | |
" accu = np.zeros((nq, nb), 'uint16')\n", | |
" for i in range(0, nq, qbs): \n", | |
" for j in range(0, nb, bbs): \n", | |
" block = loop1_kernel(codes[j // bbs], LUT[i // qbs])\n", | |
" accu[i : i + qbs, j : j + bbs] = block\n", | |
" return accu\n", | |
" \n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"8448" | |
] | |
}, | |
"execution_count": 34, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# set some block sizes\n", | |
"bbs = 48\n", | |
"qbs = 4\n", | |
"codes2 = loop1_prepare_codes(codes, bbs)\n", | |
"LUT2 = loop1_prepare_LUT(LUT, qbs)\n", | |
"\n", | |
"loop1_nread = 0\n", | |
"accu_loop1 = loop1_accu(codes2, LUT2)\n", | |
"loop1_nread " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 35, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.all(accu_loop1 == accu_ref)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# ScaNN implementation (loop3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"loop3_nread = 0\n", | |
"\n", | |
"def loop3_kernel(codes, LUT): \n", | |
" \"\"\" \n", | |
" computes the distance matrix for nq queries and 2 * 16 database vectors \n", | |
" \"\"\"\n", | |
" nsq_2, nq, is_32 = LUT.shape\n", | |
" assert is_32 == 32\n", | |
" nsq = nsq_2 * 2\n", | |
" \n", | |
" nsq_2, is_32 = codes.shape\n", | |
" assert nsq_2 == nsq // 2\n", | |
" \n", | |
" global loop3_nread \n", | |
" \n", | |
" accu = np.zeros((nq, 4, 16), 'uint16')\n", | |
" for sq in range(0, nsq, 2): \n", | |
" c = codes[sq // 2]\n", | |
" mask0 = c & 15\n", | |
" mask1 = c >> 4\n", | |
" loop3_nread += 1\n", | |
" for q in range(nq): \n", | |
" dic = LUT[sq // 2, q]\n", | |
" loop3_nread += 1\n", | |
" res0 = lookup_2_lanes(dic, mask0)\n", | |
" res1 = lookup_2_lanes(dic, mask1)\n", | |
" \n", | |
" # this is really clever: there are no cross-lane operations that \n", | |
" # take many cycles and they avoid masking out the upper bits \n", | |
" # by subtracting them outside the loop. \n", | |
" accu[q, 0] += res0.view('uint16')\n", | |
" accu[q, 1] += res0.view('uint16') >> 8\n", | |
" accu[q, 2] += res1.view('uint16')\n", | |
" accu[q, 3] += res1.view('uint16') >> 8\n", | |
" \n", | |
" dis = np.zeros((nq, 32), 'uint16')\n", | |
" \n", | |
" for q in range(nq): \n", | |
" accu[q, 0] -= accu[q, 1] << 8\n", | |
" accu[q, 2] -= accu[q, 3] << 8\n", | |
" \n", | |
" dis[q, :8] = accu[q, 0, :8] + accu[q, 0, 8:]\n", | |
" dis[q, 8:16] = accu[q, 1, :8] + accu[q, 1, 8:]\n", | |
" \n", | |
" dis[q, 16:24] = accu[q, 2, :8] + accu[q, 2, 8:]\n", | |
" dis[q, 24:32] = accu[q, 3, :8] + accu[q, 3, 8:]\n", | |
" \n", | |
" return dis.reshape(nq, 32)\n", | |
"\n", | |
"def loop3_accu(codes, LUT):\n", | |
" \"\"\" runs the kernel to produce the complete distance matrix \"\"\"\n", | |
"\n", | |
" nb_32, nsq_2, is_32 = codes.shape\n", | |
" assert is_32 == 32\n", | |
" nb = nb_32 * 32\n", | |
" nsq = nsq_2 * 2\n", | |
" \n", | |
" q_blocks, nsq_2, qbs, is_32 = LUT.shape\n", | |
" assert nsq_2 * 2 == nsq\n", | |
" assert is_32 == 32\n", | |
" nq = q_blocks * qbs\n", | |
" \n", | |
" accu = np.zeros((nq, nb), 'uint16')\n", | |
" for i in range(0, nq, qbs): \n", | |
" for j in range(0, nb, 32): \n", | |
" block = loop3_kernel(codes[j // 32], LUT[i // qbs])\n", | |
" accu[i : i + qbs, j : j + 32] = block\n", | |
" return accu\n", | |
"\n", | |
"def loop3_prepare_LUT(LUT, qbs): \n", | |
" nq, nsq, is_16 = LUT.shape\n", | |
" assert is_16 == 16\n", | |
" assert nsq % 2 == 0\n", | |
" assert nq % qbs == 0\n", | |
" \n", | |
" LUT2 = np.zeros((nq // qbs, nsq // 2, qbs, 32), 'uint8')\n", | |
" \n", | |
" for q in range(nq): \n", | |
" for sq in range(0, nsq, 2): \n", | |
" c0 = LUT[q, sq]\n", | |
" c1 = LUT[q, sq + 1]\n", | |
" \n", | |
" LUT2[q // qbs, sq // 2, q % qbs, :16] = c0 \n", | |
" LUT2[q // qbs, sq // 2, q % qbs, 16:] = c1\n", | |
" \n", | |
" return LUT2\n", | |
"\n", | |
"\n", | |
"def loop3_prepare_codes(codes): \n", | |
" nb, nsq_2 = codes.shape\n", | |
" nsq = nsq_2 * 2\n", | |
" \n", | |
" assert nb % 32 == 0\n", | |
" \n", | |
" codes2 = np.zeros((nb // 32, nsq_2, 32), 'uint8')\n", | |
" \n", | |
" perm0 = np.arange(16).reshape(2, 8).transpose(1, 0).ravel()\n", | |
" perm1 = 16 + perm0\n", | |
" \n", | |
" for i in range(0, nb, 32): \n", | |
" for sq in range(0, nsq, 2): \n", | |
" c = codes[i : i + 32, sq // 2] \n", | |
" c0 = c & 15 # 32 codes for sq \n", | |
" c1 = c >> 4 # 32 codes for sq + 1 \n", | |
" codes2[i // 32, sq // 2, :16] = c0[perm0] | c0[perm1] << 4\n", | |
" codes2[i // 32, sq // 2, 16:] = c1[perm0] | c1[perm1] << 4\n", | |
" \n", | |
" return codes2\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15])" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.arange(16).reshape(2, 8).transpose(1, 0).ravel()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"12288" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# set some block sizes\n", | |
"qbs = 3\n", | |
"codes2 = loop3_prepare_codes(codes)\n", | |
"LUT2 = loop3_prepare_LUT(LUT, qbs)\n", | |
"\n", | |
"loop3_nread = 0\n", | |
"accu_loop3 = loop3_accu(codes2, LUT2)\n", | |
"loop3_nread" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.all(accu_loop3 == accu_ref)" | |
] | |
}, | |
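{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Larger kernel (loop4)\n", | |
"\n", | |
"Caveat: `loop4_prepare_LUT` stores query pair `q_2 = q // 2` in slot `q % qbs_2` of its block rather than `q_2 % qbs_2`. The two only agree for `qbs = 2` (the value used below); with a larger `qbs`, query rows are overwritten or permuted." | |
] | |
}, | |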
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def loop4_kernel(codes, LUT): \n", | |
" \"\"\" \n", | |
" computes the distance matrix for nq (even) queries and 4 * 16 database vectors \n", | |
" \"\"\"\n", | |
" nsq, nq_2, is_32 = LUT.shape\n", | |
" assert is_32 == 32\n", | |
" nq = nq_2 * 2\n", | |
" \n", | |
" nsqB, is_32 = codes.shape\n", | |
" assert is_32 == 32\n", | |
" assert nsqB == nsq \n", | |
" \n", | |
" global loop4_nread \n", | |
" \n", | |
" dis = np.zeros((nq_2, 8, 16), 'uint16')\n", | |
" for sq in range(nsq): \n", | |
" c = codes[sq]\n", | |
" loop4_nread += 1\n", | |
" \n", | |
" # 1 vperm2f128 and 2 blend\n", | |
" c0 = np.hstack((c[:16], c[:16])) \n", | |
" m0 = c0 & 15\n", | |
" m1 = c0 >> 4\n", | |
" \n", | |
" c1 = np.hstack((c[16:], c[16:]))\n", | |
" m2 = c1 & 15\n", | |
" m3 = c1 >> 4\n", | |
" \n", | |
" for q_2 in range(nq // 2): \n", | |
" \n", | |
" dic = LUT[sq, q_2]\n", | |
" loop4_nread += 1\n", | |
"\n", | |
" res0 = lookup_2_lanes(dic, m0) \n", | |
" dis[q_2, 0] += res0.view('uint16')\n", | |
" dis[q_2, 2] += res0.view('uint16') >> 8\n", | |
" \n", | |
" res1 = lookup_2_lanes(dic, m1)\n", | |
" dis[q_2, 1] += res1.view('uint16')\n", | |
" dis[q_2, 3] += res1.view('uint16') >> 8\n", | |
"\n", | |
" res2 = lookup_2_lanes(dic, m2)\n", | |
" dis[q_2, 4] += res2.view('uint16')\n", | |
" dis[q_2, 6] += res2.view('uint16') >> 8\n", | |
"\n", | |
" res3 = lookup_2_lanes(dic, m3)\n", | |
" dis[q_2, 5] += res3.view('uint16')\n", | |
" dis[q_2, 7] += res3.view('uint16') >> 8\n", | |
" \n", | |
" for q_2 in range(nq // 2): \n", | |
" dis[q_2, 0] -= dis[q_2, 2] << 8\n", | |
" dis[q_2, 1] -= dis[q_2, 3] << 8\n", | |
" dis[q_2, 4] -= dis[q_2, 6] << 8\n", | |
" dis[q_2, 5] -= dis[q_2, 7] << 8\n", | |
" \n", | |
" dis = dis.reshape((nq_2, 8, 2, 8)).transpose(0, 2, 1, 3)\n", | |
" return dis.reshape(nq, 64)\n", | |
"\n", | |
"\n", | |
"def loop4_accu(codes, LUT):\n", | |
" \"\"\" runs the kernel to produce the complete distance matrix \"\"\"\n", | |
"\n", | |
" nb_64, nsq, is_32 = codes.shape\n", | |
" assert is_32 == 32\n", | |
" nb = nb_64 * 64\n", | |
" \n", | |
" q_blocks, nsqB, qbs_2, is_32 = LUT.shape\n", | |
" assert nsqB == nsq\n", | |
" assert is_32 == 32\n", | |
" qbs = qbs_2 * 2\n", | |
" nq = q_blocks * qbs\n", | |
" \n", | |
" accu = np.zeros((nq, nb), 'uint16')\n", | |
" for i in range(0, nq, qbs): \n", | |
" for j in range(0, nb, 64): \n", | |
" block = loop4_kernel(codes[j // 64], LUT[i // qbs])\n", | |
" accu[i : i + qbs, j : j + 64] = block\n", | |
" return accu\n", | |
"\n", | |
"def loop4_prepare_LUT(LUT, qbs): \n", | |
" nq, nsq, is_16 = LUT.shape\n", | |
" assert is_16 == 16\n", | |
" assert nq % qbs == 0\n", | |
" assert qbs % 2 == 0\n", | |
" qbs_2 = qbs // 2\n", | |
" \n", | |
" LUT2 = np.zeros((nq // qbs, nsq, qbs_2, 32), 'uint8')\n", | |
" assert LUT2.size == LUT.size\n", | |
" \n", | |
" for q in range(0, nq, 2):\n", | |
" q_2 = q // 2\n", | |
" for sq in range(nsq): \n", | |
" LUT2[q_2 // qbs_2, sq, q % qbs_2, :16] = LUT[q, sq, :]\n", | |
" LUT2[q_2 // qbs_2, sq, q % qbs_2, 16:] = LUT[q + 1, sq, :]\n", | |
" \n", | |
" return LUT2\n", | |
"\n", | |
"def reshape_64(a): \n", | |
" ar = a.reshape(2, 4, 8)\n", | |
" return ar.transpose(0, 2, 1).ravel()\n", | |
"\n", | |
"def pack_64(a): \n", | |
" ar = reshape_64(a)\n", | |
" return ar[::2] | ar[1::2] << 4\n", | |
"\n", | |
"def loop4_prepare_codes(codes): \n", | |
" nb, nsq_2 = codes.shape\n", | |
" nsq = nsq_2 * 2\n", | |
" \n", | |
" assert nb % 64 == 0\n", | |
" \n", | |
" codes2 = np.zeros((nb // 64, nsq, 32), 'uint8')\n", | |
" \n", | |
" for i in range(0, nb, 64): \n", | |
" for sq in range(0, nsq, 2): \n", | |
" c = codes[i : i + 64, sq // 2] \n", | |
" c0 = c & 15 # 64 codes for sq \n", | |
" c1 = c >> 4 # 64 codes for sq + 1\n", | |
" codes2[i // 64, sq , :] = pack_64(c0)\n", | |
" codes2[i // 64, sq + 1, :] = pack_64(c1)\n", | |
" \n", | |
" return codes2\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"9216" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# set some block sizes\n", | |
"qbs = 2\n", | |
"codes2 = loop4_prepare_codes(codes)\n", | |
"LUT2 = loop4_prepare_LUT(LUT, qbs)\n", | |
"\n", | |
"loop4_nread = 0\n", | |
"accu_loop4 = loop4_accu(codes2, LUT2)\n", | |
"loop4_nread" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.all(accu_loop4 == accu_ref)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Loop5 (better for nq=1 and nq=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 59, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"loop5_nread = 0\n", | |
"\n", | |
"\n", | |
"def loop5_kernel(codes, LUT): \n", | |
"\n", | |
" nsq_2, nq, is_32 = LUT.shape\n", | |
" assert is_32 == 32\n", | |
" nsq = nsq_2 * 2\n", | |
" \n", | |
" nsq_2, nb_32, is_32 = codes.shape\n", | |
" assert nsq_2 == nsq // 2\n", | |
" nb = nb_32 * 32\n", | |
" \n", | |
" global loop5_nread \n", | |
" \n", | |
" accu = np.zeros((nq, nb // 32, 4, 16), 'uint16')\n", | |
" for sq in range(0, nsq, 2):\n", | |
" b_cache = codes[sq // 2, :]\n", | |
" loop5_nread += len(b_cache)\n", | |
" for q in range(nq): \n", | |
" dic = LUT[sq // 2, q]\n", | |
" loop5_nread += 1\n", | |
" for b in range(0, nb, 32): \n", | |
" c = b_cache[b // 32]\n", | |
" mask0 = c & 15\n", | |
" mask1 = c >> 4\n", | |
" loop5_nread += 1\n", | |
" res0 = lookup_2_lanes(dic, mask0)\n", | |
" res1 = lookup_2_lanes(dic, mask1)\n", | |
"\n", | |
" accu[q, b // 32, 0] += res0.view('uint16')\n", | |
" accu[q, b // 32, 1] += res0.view('uint16') >> 8\n", | |
" accu[q, b // 32, 2] += res1.view('uint16')\n", | |
" accu[q, b // 32, 3] += res1.view('uint16') >> 8\n", | |
" \n", | |
" dis = np.zeros((nq, nb), 'uint16')\n", | |
" \n", | |
" for q in range(nq): \n", | |
" for b in range(nb // 32):\n", | |
" b0 = b * 32\n", | |
" accu[q, b, 0] -= accu[q, b, 1] << 8\n", | |
" # combine2x2\n", | |
" dis[q, b0 : b0 + 8 ] = accu[q, b, 0, :8] + accu[q, b, 0, 8:]\n", | |
" dis[q, b0 + 8: b0 + 16] = accu[q, b, 1, :8] + accu[q, b, 1, 8:]\n", | |
"\n", | |
" b0 += 16\n", | |
" accu[q, b, 2] -= accu[q, b, 3] << 8\n", | |
" # combine2x2\n", | |
" dis[q, b0 : b0 + 8 ] = accu[q, b, 2, :8] + accu[q, b, 2, 8:]\n", | |
" dis[q, b0 + 8: b0 + 16] = accu[q, b, 3, :8] + accu[q, b, 3, 8:]\n", | |
" \n", | |
" return dis\n", | |
"\n", | |
"def loop5_accu(codes, LUT):\n", | |
" \"\"\" runs the kernel to produce the complete distance matrix \"\"\"\n", | |
"\n", | |
" b_blocks, nsq_2, bbs_32, is_32 = codes.shape\n", | |
" assert is_32 == 32\n", | |
" nsq = nsq_2 * 2\n", | |
" bbs = bbs_32 * 32\n", | |
" nb = b_blocks * bbs\n", | |
" \n", | |
" q_blocks, nsq_2, qbs, is_32 = LUT.shape\n", | |
" assert nsq_2 * 2 == nsq\n", | |
" assert is_32 == 32\n", | |
" nq = q_blocks * qbs\n", | |
" \n", | |
" accu = np.zeros((nq, nb), 'uint16')\n", | |
" for i in range(0, nq, qbs): \n", | |
" for j in range(0, nb, bbs): \n", | |
" block = loop5_kernel(codes[j // bbs], LUT[i // qbs])\n", | |
" accu[i : i + qbs, j : j + bbs] = block\n", | |
" return accu\n", | |
"\n", | |
"def loop5_prepare_LUT(LUT, qbs): \n", | |
" nq, nsq, is_16 = LUT.shape\n", | |
" assert is_16 == 16\n", | |
" assert nsq % 2 == 0\n", | |
" assert nq % qbs == 0\n", | |
" \n", | |
" LUT2 = np.zeros((nq // qbs, nsq // 2, qbs, 32), 'uint8')\n", | |
" \n", | |
" for q in range(nq): \n", | |
" for sq in range(0, nsq, 2): \n", | |
" c0 = LUT[q, sq]\n", | |
" c1 = LUT[q, sq + 1]\n", | |
" \n", | |
" LUT2[q // qbs, sq // 2, q % qbs, :16] = c0 \n", | |
" LUT2[q // qbs, sq // 2, q % qbs, 16:] = c1\n", | |
" \n", | |
" return LUT2\n", | |
"\n", | |
"\n", | |
"def loop5_prepare_codes(codes, bbs): \n", | |
" nb, nsq_2 = codes.shape\n", | |
" nsq = nsq_2 * 2\n", | |
" \n", | |
" assert nb % bbs == 0\n", | |
" assert bbs % 32 == 0\n", | |
" \n", | |
" codes2 = np.zeros((nb // bbs, nsq_2, bbs // 32, 32), 'uint8')\n", | |
" \n", | |
" perm0 = np.arange(16).reshape(2, 8).transpose(1, 0).ravel()\n", | |
" perm1 = 16 + perm0\n", | |
" \n", | |
" for i0 in range(0, nb, bbs): \n", | |
" for sq in range(0, nsq, 2):\n", | |
" for i in range(0, bbs, 32): \n", | |
" c = codes[i0 + i : i0 + i + 32, sq // 2] \n", | |
" c0 = c & 15 # 32 codes for sq \n", | |
" c1 = c >> 4 # 32 codes for sq + 1 \n", | |
" codes2[i0 // bbs, sq // 2, i // 32, :16] = c0[perm0] | c0[perm1] << 4\n", | |
" codes2[i0 // bbs, sq // 2, i // 32, 16:] = c1[perm0] | c1[perm1] << 4\n", | |
" \n", | |
" return codes2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"16896" | |
] | |
}, | |
"execution_count": 62, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# set some block sizes\n", | |
"qbs = 2\n", | |
"bbs = 3*32\n", | |
"codes2 = loop5_prepare_codes(codes, bbs)\n", | |
"LUT2 = loop5_prepare_LUT(LUT, qbs)\n", | |
"\n", | |
"loop5_nread = 0\n", | |
"accu_loop5 = loop5_accu(codes2, LUT2)\n", | |
"loop5_nread\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 63, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 63, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.all(accu_loop5 == accu_ref)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 103, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# element-wise access\n", | |
"\n", | |
"perm0 = np.arange(16).reshape(2, 8).transpose(1, 0).ravel()\n", | |
"iperm0 = np.zeros_like(perm0)\n", | |
"iperm0[perm0] = np.arange(16)\n", | |
"# perm0, iperm0, perm0[iperm0]\n", | |
"\n", | |
"def loop5_get_element(codes, i, sq): \n", | |
" b_blocks, nsq_2, bbs_32, is_32 = codes.shape\n", | |
" assert is_32 == 32\n", | |
" nsq = nsq_2 * 2\n", | |
" bbs = bbs_32 * 32\n", | |
" nb = b_blocks * bbs\n", | |
"\n", | |
" assert i < nb\n", | |
" \n", | |
" c = codes[i // bbs, sq // 2]\n", | |
" i = i % bbs\n", | |
" sq = sq % 2 \n", | |
" c = c[i // 32]\n", | |
" i = i % 32\n", | |
" if sq == 0: \n", | |
" c = c[:16]\n", | |
" else: \n", | |
" c = c[16:]\n", | |
" if i < 16: \n", | |
" return c[iperm0[i]] & 15\n", | |
" else: \n", | |
" return c[iperm0[i - 16]] >> 4\n", | |
" \n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 104, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"rs = np.random.RandomState(123)\n", | |
"for run in range(1000): \n", | |
" i = rs.randint(nb)\n", | |
" sq = rs.randint(nsq)\n", | |
" c = (codes[i, sq // 2] >> (4 * (sq %2))) & 15\n", | |
" c2 = loop5_get_element(codes2, i, sq)\n", | |
" assert c == c2\n", | |
" # print(f\"REF({i}, {sq}) = {c} ?= {c2}\")\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 105, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(3, 32)" | |
] | |
}, | |
"execution_count": 105, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"codes2[0, 0].shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 106, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15])" | |
] | |
}, | |
"execution_count": 106, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"iperm0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment