Skip to content

Instantly share code, notes, and snippets.

@mdouze
Created January 11, 2022 18:49
Show Gist options
  • Save mdouze/5c32300cf3bd20946a7762f6c016e823 to your computer and use it in GitHub Desktop.
Save mdouze/5c32300cf3bd20946a7762f6c016e823 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def reference_accu(codes, LUT): \n",
" \"\"\"\n",
" Reference scalar accumulation function for PQ4 data. \n",
" The codes and LUT are laid out in a \"natural\" way.\n",
" nq = nb of queries\n",
" nb = nb of database vectors\n",
" nsq = nb of sub-quantizers (each size 4 bits)\n",
" \"\"\"\n",
" \n",
" nq, nsq, is_16 = LUT.shape\n",
" nb, nsq_2 = codes.shape\n",
" assert is_16 == 16\n",
" assert nsq_2 == nsq // 2 \n",
" accu = np.zeros((nq, nb), 'uint16')\n",
" for i in range(nq): \n",
" for j in range(nb): \n",
" a = np.uint16(0)\n",
" for sq in range(0, nsq, 2): \n",
" c = codes[j, sq // 2]\n",
" a += LUT[i, sq , c & 15].astype('uint16')\n",
" a += LUT[i, sq + 1, c >> 4].astype('uint16')\n",
" accu[i, j] = a\n",
" return accu"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"\"\"\" \n",
"Generate random data for testing\n",
"\"\"\"\n",
"\n",
"nq = 96 # number of queries\n",
"nb = 384 # database size \n",
"nsq = 16 # number of sub-quantizers\n",
"\n",
"rs = np.random.RandomState(123)\n",
"codes = rs.randint(256, size=(nb, nsq // 2)).astype('uint8')\n",
"LUT = rs.randint(256, size=(nq, nsq, 16)).astype('uint8')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"accu_ref = reference_accu(codes, LUT)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# First attempt -- not very efficient (loop1)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def lookup_2_lanes(a, b):\n",
" \"\"\" Emulates the AVX2 lookup instruction. It is rather limited: it operates on 2 independent lanes of 16 bytes\n",
" (rather than on all 32 bytes at once), but it is the only efficient way to do lookups \"\"\"\n",
" blane0 = b[:16] & 15\n",
" blane1 = b[16:] & 15 \n",
" return np.hstack((a[blane0], a[blane1 + 16]))\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.arange(16).reshape(2, 8).T.ravel()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"def loop1_prepare_codes(codes, bbs): \n",
" \"\"\" \n",
" re-organize the code matrix so that the loop1 kernel reads all data sequentially\n",
" bbs = database side block size (must be a multiple of 16)\n",
" \"\"\"\n",
" nb, nsq_2 = codes.shape\n",
" nsq = nsq_2 * 2\n",
" assert nb % bbs == 0\n",
" assert nsq % 4 == 0\n",
" assert bbs % 16 == 0\n",
" \n",
" codes2 = np.zeros((nb // bbs, nsq // 4, bbs // 16, 32), 'uint8')\n",
" for i in range(0, nb, 16): \n",
" for sq in range(0, nsq, 4): \n",
" c01 = codes[i : i + 16, sq // 2]\n",
" c23 = codes[i : i + 16, sq // 2 + 1]\n",
" c02 = c01 & 15 | c23 << 4\n",
" c13 = c01 >> 4 | c23 & 0xF0 \n",
" cblock = codes2[i // bbs, sq // 4, (i % bbs) // 16, :] \n",
" cblock[:16] = c02\n",
" cblock[16:] = c13\n",
" return codes2\n",
" \n",
" \n",
"def loop1_prepare_LUT(LUT, qbs): \n",
" \"\"\" \n",
" Re-organize the LUT matrix so that the loop1 kernel reads data sequentially\n",
" \"\"\"\n",
" nq, nsq, is_16 = LUT.shape\n",
" assert is_16 == 16\n",
" assert nsq % 4 == 0\n",
" assert nq % qbs == 0\n",
" \n",
" LUT2 = np.zeros((nq // qbs, nsq // 4, qbs, 2, 32), 'uint8')\n",
" \n",
" for i in range(nq): \n",
" for sq in range(0, nsq, 4): \n",
" LUT2[i // qbs, sq // 4, i % qbs, 0, :16] = LUT[i, sq , :]\n",
" LUT2[i // qbs, sq // 4, i % qbs, 0, 16:] = LUT[i, sq + 1, :]\n",
" LUT2[i // qbs, sq // 4, i % qbs, 1, :16] = LUT[i, sq + 2, :]\n",
" LUT2[i // qbs, sq // 4, i % qbs, 1, 16:] = LUT[i, sq + 3, :]\n",
" \n",
" return LUT2\n",
"\n",
"loop1_nread = 0\n",
"\n",
"\n",
"def combine2x2(a, b): \n",
" return np.hstack((a[:8] + a[8:], b[:8] + b[8:]))\n",
" \n",
"def loop1_kernel(codes, LUT): \n",
" \"\"\" computation kernel for 1 block of the output distance matrix. \n",
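" The block is of size qbs * bbs. NB: bbs, qbs and nsq are read from global variables\n",
" (they are not derived from the shapes of the codes and LUT arguments), so they must match the prepared inputs.\n",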
" \"\"\"\n",
" global loop1_nread\n",
" bb = bbs // 16\n",
" distances = np.zeros((qbs, bb, 16), 'uint16')\n",
" \n",
" for sq in range(0, nsq, 4): \n",
" ctab = codes[sq // 4, :, :] # read size (bb, 32)\n",
" loop1_nread += bb\n",
" for q in range(qbs): \n",
" lut0 = LUT[sq // 4, q, 0, :] # read size 32\n",
" lut1 = LUT[sq // 4, q, 1, :] # read size 32\n",
" loop1_nread += 2\n",
" for b in range(bb): \n",
" c = ctab[b]\n",
" chi = c >> 4\n",
" clo = c & 15\n",
" q02 = lookup_2_lanes(lut0, clo)\n",
" q13 = lookup_2_lanes(lut1, chi)\n",
" \n",
" q02 = q02.view('uint16')\n",
" q13 = q13.view('uint16')\n",
" \n",
" sev = q02 + q13 \n",
" sod = (q02 >> 8) + (q13 >> 8)\n",
" \n",
" distances[q, b, :] += combine2x2(sev, sod)\n",
" \n",
" distances[:, :, :8] -= distances[:, :, 8:] << 8\n",
" \n",
" perm = np.arange(16).reshape(2, 8).T.ravel() \n",
" \n",
" distances = distances[:, :, perm]\n",
" \n",
" return distances.reshape(qbs, bbs)\n",
"\n",
"\n",
"def loop1_accu(codes, LUT):\n",
" \"\"\" runs the kernel to produce the complete distance matrix \"\"\"\n",
" q_blocks, nsq_4, qbs, is_2, is_32 = LUT.shape \n",
" nsq = nsq_4 * 4\n",
" nq = q_blocks * qbs\n",
" assert is_2 == 2\n",
" assert is_32 == 32\n",
" \n",
" b_blocks, nsq_4, bb, is_32 = codes.shape\n",
" assert nsq == nsq_4 * 4\n",
" bbs = bb * 16\n",
" assert is_32 == 32\n",
" nb = b_blocks * bbs\n",
" \n",
" accu = np.zeros((nq, nb), 'uint16')\n",
" for i in range(0, nq, qbs): \n",
" for j in range(0, nb, bbs): \n",
" block = loop1_kernel(codes[j // bbs], LUT[i // qbs])\n",
" accu[i : i + qbs, j : j + bbs] = block\n",
" return accu\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8448"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# set some block sizes\n",
"bbs = 48\n",
"qbs = 4\n",
"codes2 = loop1_prepare_codes(codes, bbs)\n",
"LUT2 = loop1_prepare_LUT(LUT, qbs)\n",
"\n",
"loop1_nread = 0\n",
"accu_loop1 = loop1_accu(codes2, LUT2)\n",
"loop1_nread "
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.all(accu_loop1 == accu_ref)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SCANN implementation (loop3)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"loop3_nread = 0\n",
"\n",
"def loop3_kernel(codes, LUT): \n",
" \"\"\" \n",
" computes the distance matrix for nq queries and 2 * 16 database vectors \n",
" \"\"\"\n",
" nsq_2, nq, is_32 = LUT.shape\n",
" assert is_32 == 32\n",
" nsq = nsq_2 * 2\n",
" \n",
" nsq_2, is_32 = codes.shape\n",
" assert nsq_2 == nsq // 2\n",
" \n",
" global loop3_nread \n",
" \n",
" accu = np.zeros((nq, 4, 16), 'uint16')\n",
" for sq in range(0, nsq, 2): \n",
" c = codes[sq // 2]\n",
" mask0 = c & 15\n",
" mask1 = c >> 4\n",
" loop3_nread += 1\n",
" for q in range(nq): \n",
" dic = LUT[sq // 2, q]\n",
" loop3_nread += 1\n",
" res0 = lookup_2_lanes(dic, mask0)\n",
" res1 = lookup_2_lanes(dic, mask1)\n",
" \n",
" # this is really clever: there are no cross-lane operations (which \n",
" # take many cycles), and masking out the upper byte of each uint16 \n",
" # is avoided by subtracting its contribution once, outside the loop. \n",
" accu[q, 0] += res0.view('uint16')\n",
" accu[q, 1] += res0.view('uint16') >> 8\n",
" accu[q, 2] += res1.view('uint16')\n",
" accu[q, 3] += res1.view('uint16') >> 8\n",
" \n",
" dis = np.zeros((nq, 32), 'uint16')\n",
" \n",
" for q in range(nq): \n",
" accu[q, 0] -= accu[q, 1] << 8\n",
" accu[q, 2] -= accu[q, 3] << 8\n",
" \n",
" dis[q, :8] = accu[q, 0, :8] + accu[q, 0, 8:]\n",
" dis[q, 8:16] = accu[q, 1, :8] + accu[q, 1, 8:]\n",
" \n",
" dis[q, 16:24] = accu[q, 2, :8] + accu[q, 2, 8:]\n",
" dis[q, 24:32] = accu[q, 3, :8] + accu[q, 3, 8:]\n",
" \n",
" return dis.reshape(nq, 32)\n",
"\n",
"def loop3_accu(codes, LUT):\n",
" \"\"\" runs the kernel to produce the complete distance matrix \"\"\"\n",
"\n",
" nb_32, nsq_2, is_32 = codes.shape\n",
" assert is_32 == 32\n",
" nb = nb_32 * 32\n",
" nsq = nsq_2 * 2\n",
" \n",
" q_blocks, nsq_2, qbs, is_32 = LUT.shape\n",
" assert nsq_2 * 2 == nsq\n",
" assert is_32 == 32\n",
" nq = q_blocks * qbs\n",
" \n",
" accu = np.zeros((nq, nb), 'uint16')\n",
" for i in range(0, nq, qbs): \n",
" for j in range(0, nb, 32): \n",
" block = loop3_kernel(codes[j // 32], LUT[i // qbs])\n",
" accu[i : i + qbs, j : j + 32] = block\n",
" return accu\n",
"\n",
"def loop3_prepare_LUT(LUT, qbs): \n",
" nq, nsq, is_16 = LUT.shape\n",
" assert is_16 == 16\n",
" assert nsq % 2 == 0\n",
" assert nq % qbs == 0\n",
" \n",
" LUT2 = np.zeros((nq // qbs, nsq // 2, qbs, 32), 'uint8')\n",
" \n",
" for q in range(nq): \n",
" for sq in range(0, nsq, 2): \n",
" c0 = LUT[q, sq]\n",
" c1 = LUT[q, sq + 1]\n",
" \n",
" LUT2[q // qbs, sq // 2, q % qbs, :16] = c0 \n",
" LUT2[q // qbs, sq // 2, q % qbs, 16:] = c1\n",
" \n",
" return LUT2\n",
"\n",
"\n",
"def loop3_prepare_codes(codes): \n",
" nb, nsq_2 = codes.shape\n",
" nsq = nsq_2 * 2\n",
" \n",
" assert nb % 32 == 0\n",
" \n",
" codes2 = np.zeros((nb // 32, nsq_2, 32), 'uint8')\n",
" \n",
" perm0 = np.arange(16).reshape(2, 8).transpose(1, 0).ravel()\n",
" perm1 = 16 + perm0\n",
" \n",
" for i in range(0, nb, 32): \n",
" for sq in range(0, nsq, 2): \n",
" c = codes[i : i + 32, sq // 2] \n",
" c0 = c & 15 # 32 codes for sq \n",
" c1 = c >> 4 # 32 codes for sq + 1 \n",
" codes2[i // 32, sq // 2, :16] = c0[perm0] | c0[perm1] << 4\n",
" codes2[i // 32, sq // 2, 16:] = c1[perm0] | c1[perm1] << 4\n",
" \n",
" return codes2\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.arange(16).reshape(2, 8).transpose(1, 0).ravel()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"12288"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# set some block sizes\n",
"qbs = 3\n",
"codes2 = loop3_prepare_codes(codes)\n",
"LUT2 = loop3_prepare_LUT(LUT, qbs)\n",
"\n",
"loop3_nread = 0\n",
"accu_loop3 = loop3_accu(codes2, LUT2)\n",
"loop3_nread"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.all(accu_loop3 == accu_ref)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Larger kernel (loop4)\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def loop4_kernel(codes, LUT): \n",
" \"\"\" \n",
" computes the distance matrix for nq (even) queries and 4 * 16 database vectors \n",
" \"\"\"\n",
" nsq, nq_2, is_32 = LUT.shape\n",
" assert is_32 == 32\n",
" nq = nq_2 * 2\n",
" \n",
" nsqB, is_32 = codes.shape\n",
" assert is_32 == 32\n",
" assert nsqB == nsq \n",
" \n",
" global loop4_nread \n",
" \n",
" dis = np.zeros((nq_2, 8, 16), 'uint16')\n",
" for sq in range(nsq): \n",
" c = codes[sq]\n",
" loop4_nread += 1\n",
" \n",
" # 1 vperm2f128 and 2 blend\n",
" c0 = np.hstack((c[:16], c[:16])) \n",
" m0 = c0 & 15\n",
" m1 = c0 >> 4\n",
" \n",
" c1 = np.hstack((c[16:], c[16:]))\n",
" m2 = c1 & 15\n",
" m3 = c1 >> 4\n",
" \n",
" for q_2 in range(nq // 2): \n",
" \n",
" dic = LUT[sq, q_2]\n",
" loop4_nread += 1\n",
"\n",
" res0 = lookup_2_lanes(dic, m0) \n",
" dis[q_2, 0] += res0.view('uint16')\n",
" dis[q_2, 2] += res0.view('uint16') >> 8\n",
" \n",
" res1 = lookup_2_lanes(dic, m1)\n",
" dis[q_2, 1] += res1.view('uint16')\n",
" dis[q_2, 3] += res1.view('uint16') >> 8\n",
"\n",
" res2 = lookup_2_lanes(dic, m2)\n",
" dis[q_2, 4] += res2.view('uint16')\n",
" dis[q_2, 6] += res2.view('uint16') >> 8\n",
"\n",
" res3 = lookup_2_lanes(dic, m3)\n",
" dis[q_2, 5] += res3.view('uint16')\n",
" dis[q_2, 7] += res3.view('uint16') >> 8\n",
" \n",
" for q_2 in range(nq // 2): \n",
" dis[q_2, 0] -= dis[q_2, 2] << 8\n",
" dis[q_2, 1] -= dis[q_2, 3] << 8\n",
" dis[q_2, 4] -= dis[q_2, 6] << 8\n",
" dis[q_2, 5] -= dis[q_2, 7] << 8\n",
" \n",
" dis = dis.reshape((nq_2, 8, 2, 8)).transpose(0, 2, 1, 3)\n",
" return dis.reshape(nq, 64)\n",
"\n",
"\n",
"def loop4_accu(codes, LUT):\n",
" \"\"\" runs the kernel to produce the complete distance matrix \"\"\"\n",
"\n",
" nb_64, nsq, is_32 = codes.shape\n",
" assert is_32 == 32\n",
" nb = nb_64 * 64\n",
" \n",
" q_blocks, nsqB, qbs_2, is_32 = LUT.shape\n",
" assert nsqB == nsq\n",
" assert is_32 == 32\n",
" qbs = qbs_2 * 2\n",
" nq = q_blocks * qbs\n",
" \n",
" accu = np.zeros((nq, nb), 'uint16')\n",
" for i in range(0, nq, qbs): \n",
" for j in range(0, nb, 64): \n",
" block = loop4_kernel(codes[j // 64], LUT[i // qbs])\n",
" accu[i : i + qbs, j : j + 64] = block\n",
" return accu\n",
"\n",
"def loop4_prepare_LUT(LUT, qbs): \n",
" nq, nsq, is_16 = LUT.shape\n",
" assert is_16 == 16\n",
" assert nq % qbs == 0\n",
" assert qbs % 2 == 0\n",
" qbs_2 = qbs // 2\n",
" \n",
" LUT2 = np.zeros((nq // qbs, nsq, qbs_2, 32), 'uint8')\n",
" assert LUT2.size == LUT.size\n",
" \n",
" for q in range(0, nq, 2):\n",
" q_2 = q // 2\n",
" for sq in range(nsq): \n",
" LUT2[q_2 // qbs_2, sq, q % qbs_2, :16] = LUT[q, sq, :]\n",
" LUT2[q_2 // qbs_2, sq, q % qbs_2, 16:] = LUT[q + 1, sq, :]\n",
" \n",
" return LUT2\n",
"\n",
"def reshape_64(a): \n",
" ar = a.reshape(2, 4, 8)\n",
" return ar.transpose(0, 2, 1).ravel()\n",
"\n",
"def pack_64(a): \n",
" ar = reshape_64(a)\n",
" return ar[::2] | ar[1::2] << 4\n",
"\n",
"def loop4_prepare_codes(codes): \n",
" nb, nsq_2 = codes.shape\n",
" nsq = nsq_2 * 2\n",
" \n",
" assert nb % 64 == 0\n",
" \n",
" codes2 = np.zeros((nb // 64, nsq, 32), 'uint8')\n",
" \n",
" for i in range(0, nb, 64): \n",
" for sq in range(0, nsq, 2): \n",
" c = codes[i : i + 64, sq // 2] \n",
" c0 = c & 15 # 64 codes for sq \n",
" c1 = c >> 4 # 64 codes for sq + 1\n",
" codes2[i // 64, sq , :] = pack_64(c0)\n",
" codes2[i // 64, sq + 1, :] = pack_64(c1)\n",
" \n",
" return codes2\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9216"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# set some block sizes\n",
"qbs = 2\n",
"codes2 = loop4_prepare_codes(codes)\n",
"LUT2 = loop4_prepare_LUT(LUT, qbs)\n",
"\n",
"loop4_nread = 0\n",
"accu_loop4 = loop4_accu(codes2, LUT2)\n",
"loop4_nread"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.all(accu_loop4 == accu_ref)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loop5 (better for nq=1 and nq=2)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"loop5_nread = 0\n",
"\n",
"\n",
"def loop5_kernel(codes, LUT): \n",
"\n",
" nsq_2, nq, is_32 = LUT.shape\n",
" assert is_32 == 32\n",
" nsq = nsq_2 * 2\n",
" \n",
" nsq_2, nb_32, is_32 = codes.shape\n",
" assert nsq_2 == nsq // 2\n",
" nb = nb_32 * 32\n",
" \n",
" global loop5_nread \n",
" \n",
" accu = np.zeros((nq, nb // 32, 4, 16), 'uint16')\n",
" for sq in range(0, nsq, 2):\n",
" b_cache = codes[sq // 2, :]\n",
" loop5_nread += len(b_cache)\n",
" for q in range(nq): \n",
" dic = LUT[sq // 2, q]\n",
" loop5_nread += 1\n",
" for b in range(0, nb, 32): \n",
" c = b_cache[b // 32]\n",
" mask0 = c & 15\n",
" mask1 = c >> 4\n",
" loop5_nread += 1\n",
" res0 = lookup_2_lanes(dic, mask0)\n",
" res1 = lookup_2_lanes(dic, mask1)\n",
"\n",
" accu[q, b // 32, 0] += res0.view('uint16')\n",
" accu[q, b // 32, 1] += res0.view('uint16') >> 8\n",
" accu[q, b // 32, 2] += res1.view('uint16')\n",
" accu[q, b // 32, 3] += res1.view('uint16') >> 8\n",
" \n",
" dis = np.zeros((nq, nb), 'uint16')\n",
" \n",
" for q in range(nq): \n",
" for b in range(nb // 32):\n",
" b0 = b * 32\n",
" accu[q, b, 0] -= accu[q, b, 1] << 8\n",
" # combine2x2\n",
" dis[q, b0 : b0 + 8 ] = accu[q, b, 0, :8] + accu[q, b, 0, 8:]\n",
" dis[q, b0 + 8: b0 + 16] = accu[q, b, 1, :8] + accu[q, b, 1, 8:]\n",
"\n",
" b0 += 16\n",
" accu[q, b, 2] -= accu[q, b, 3] << 8\n",
" # combine2x2\n",
" dis[q, b0 : b0 + 8 ] = accu[q, b, 2, :8] + accu[q, b, 2, 8:]\n",
" dis[q, b0 + 8: b0 + 16] = accu[q, b, 3, :8] + accu[q, b, 3, 8:]\n",
" \n",
" return dis\n",
"\n",
"def loop5_accu(codes, LUT):\n",
" \"\"\" runs the kernel to produce the complete distance matrix \"\"\"\n",
"\n",
" b_blocks, nsq_2, bbs_32, is_32 = codes.shape\n",
" assert is_32 == 32\n",
" nsq = nsq_2 * 2\n",
" bbs = bbs_32 * 32\n",
" nb = b_blocks * bbs\n",
" \n",
" q_blocks, nsq_2, qbs, is_32 = LUT.shape\n",
" assert nsq_2 * 2 == nsq\n",
" assert is_32 == 32\n",
" nq = q_blocks * qbs\n",
" \n",
" accu = np.zeros((nq, nb), 'uint16')\n",
" for i in range(0, nq, qbs): \n",
" for j in range(0, nb, bbs): \n",
" block = loop5_kernel(codes[j // bbs], LUT[i // qbs])\n",
" accu[i : i + qbs, j : j + bbs] = block\n",
" return accu\n",
"\n",
"def loop5_prepare_LUT(LUT, qbs): \n",
" nq, nsq, is_16 = LUT.shape\n",
" assert is_16 == 16\n",
" assert nsq % 2 == 0\n",
" assert nq % qbs == 0\n",
" \n",
" LUT2 = np.zeros((nq // qbs, nsq // 2, qbs, 32), 'uint8')\n",
" \n",
" for q in range(nq): \n",
" for sq in range(0, nsq, 2): \n",
" c0 = LUT[q, sq]\n",
" c1 = LUT[q, sq + 1]\n",
" \n",
" LUT2[q // qbs, sq // 2, q % qbs, :16] = c0 \n",
" LUT2[q // qbs, sq // 2, q % qbs, 16:] = c1\n",
" \n",
" return LUT2\n",
"\n",
"\n",
"def loop5_prepare_codes(codes, bbs): \n",
" nb, nsq_2 = codes.shape\n",
" nsq = nsq_2 * 2\n",
" \n",
" assert nb % bbs == 0\n",
" assert bbs % 32 == 0\n",
" \n",
" codes2 = np.zeros((nb // bbs, nsq_2, bbs // 32, 32), 'uint8')\n",
" \n",
" perm0 = np.arange(16).reshape(2, 8).transpose(1, 0).ravel()\n",
" perm1 = 16 + perm0\n",
" \n",
" for i0 in range(0, nb, bbs): \n",
" for sq in range(0, nsq, 2):\n",
" for i in range(0, bbs, 32): \n",
" c = codes[i0 + i : i0 + i + 32, sq // 2] \n",
" c0 = c & 15 # 32 codes for sq \n",
" c1 = c >> 4 # 32 codes for sq + 1 \n",
" codes2[i0 // bbs, sq // 2, i // 32, :16] = c0[perm0] | c0[perm1] << 4\n",
" codes2[i0 // bbs, sq // 2, i // 32, 16:] = c1[perm0] | c1[perm1] << 4\n",
" \n",
" return codes2"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"16896"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# set some block sizes\n",
"qbs = 2\n",
"bbs = 3*32\n",
"codes2 = loop5_prepare_codes(codes, bbs)\n",
"LUT2 = loop5_prepare_LUT(LUT, qbs)\n",
"\n",
"loop5_nread = 0\n",
"accu_loop5 = loop5_accu(codes2, LUT2)\n",
"loop5_nread\n"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.all(accu_loop5 == accu_ref)"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [],
"source": [
"# element-wise access\n",
"\n",
"perm0 = np.arange(16).reshape(2, 8).transpose(1, 0).ravel()\n",
"iperm0 = np.zeros_like(perm0)\n",
"iperm0[perm0] = np.arange(16)\n",
"# perm0, iperm0, perm0[iperm0]\n",
"\n",
"def loop5_get_element(codes, i, sq): \n",
" b_blocks, nsq_2, bbs_32, is_32 = codes.shape\n",
" assert is_32 == 32\n",
" nsq = nsq_2 * 2\n",
" bbs = bbs_32 * 32\n",
" nb = b_blocks * bbs\n",
"\n",
" assert i < nb\n",
" \n",
" c = codes[i // bbs, sq // 2]\n",
" i = i % bbs\n",
" sq = sq % 2 \n",
" c = c[i // 32]\n",
" i = i % 32\n",
" if sq == 0: \n",
" c = c[:16]\n",
" else: \n",
" c = c[16:]\n",
" if i < 16: \n",
" return c[iperm0[i]] & 15\n",
" else: \n",
" return c[iperm0[i - 16]] >> 4\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
"rs = np.random.RandomState(123)\n",
"for run in range(1000): \n",
" i = rs.randint(nb)\n",
" sq = rs.randint(nsq)\n",
" c = (codes[i, sq // 2] >> (4 * (sq %2))) & 15\n",
" c2 = loop5_get_element(codes2, i, sq)\n",
" assert c == c2\n",
" # print(f\"REF({i}, {sq}) = {c} ?= {c2}\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3, 32)"
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"codes2[0, 0].shape"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15])"
]
},
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"iperm0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment