Benchmark backup for PR19736, topk() performance optimization on CPU.
Suppose the input tensor has shape `[N, C]`; benchmark `input.topk(K, sorted=Sorted)` for the following scenarios:
- C = 10000, 40000, 320000
- K = 10, 50, 100, C/10, C/2, C-5
- Test with 20 threads and 1 thread
- Test with Sorted=True and Sorted=False
- run_topk_scale.sh

```bash
CORES=`lscpu | grep Core | awk '{print $4}'`
SOCKETS=`lscpu | grep Socket | awk '{print $2}'`
TOTAL_CORES=`expr $CORES \* $SOCKETS`
LAST_CORE=`expr $CORES - 1`
KMP_SETTING="KMP_AFFINITY=granularity=fine,compact,1,0"
KMP_BLOCKTIME=1
PREFIX="numactl --physcpubind=0-$LAST_CORE --membind=0"

export $KMP_SETTING
export KMP_BLOCKTIME=$KMP_BLOCKTIME

echo -e "\n### using $KMP_SETTING"
echo -e "### using KMP_BLOCKTIME=$KMP_BLOCKTIME"
echo -e "### using $PREFIX\n"

### single socket test
echo -e "\n### using OMP_NUM_THREADS=$CORES"
OMP_NUM_THREADS=$CORES $PREFIX python -u test_topk.py

### single thread test
echo -e "\n### using OMP_NUM_THREADS=1"
OMP_NUM_THREADS=1 $PREFIX python -u test_topk.py
```
- test_topk.py

```python
import torch
from time import time

def bench_topk(N=8, C=168560, K=10, Sorted=True, num_iters=1000):
    a = torch.randn(N, C)
    # warm up
    for i in range(int(num_iters / 10)):
        torch.topk(a, K)
    t = 0
    for i in range(num_iters):
        a = torch.randn(N, C)
        start = time()
        value, indice = torch.topk(a, K, sorted=Sorted)
        t += time() - start
    print("#[%d, %d], k=%d, sorted=%s times: %f ms" % (N, C, K,
          ('True' if Sorted else 'False'), t / num_iters * 1000))

def bench_topk_scale(Sorted):
    Ns = [10]
    Cs = [10000, 40000, 320000]
    for n in Ns:
        for c in Cs:
            for k in [10, 50, 100, int(c/10), int(c/2), int(c-5)]:
                iters = 500 if k > 5000 else 1000
                bench_topk(n, c, k, Sorted, iters)

bench_topk_scale(True)
bench_topk_scale(False)
```
Tested on Intel Xeon 6148, 2 sockets x 20 cores @ 2.5GHz. To reproduce the results:
`./run_topk_scale.sh`
All numbers are reported in ms; `oob` refers to the original topk() performance.
Table-1: OMP=20, Sorted=True
Input Size | oob | this pr | speed up |
---|---|---|---|
#[10, 10000], k=10 | 0.745 | 0.061 | 12.13 |
#[10, 10000], k=50 | 0.765 | 0.077 | 9.93 |
#[10, 10000], k=100 | 0.789 | 0.096 | 8.25 |
#[10, 10000], k=1000 | 1.252 | 0.219 | 5.72 |
#[10, 10000], k=5000 | 3.552 | 0.504 | 7.05 |
#[10, 10000], k=9995 | 6.185 | 0.814 | 7.60 |
#[10, 40000], k=10 | 2.882 | 0.176 | 16.37 |
#[10, 40000], k=50 | 2.895 | 0.193 | 14.98 |
#[10, 40000], k=100 | 2.914 | 0.222 | 13.10 |
#[10, 40000], k=4000 | 5.213 | 0.835 | 6.24 |
#[10, 40000], k=20000 | 15.811 | 2.019 | 7.83 |
#[10, 40000], k=39995 | 27.980 | 3.231 | 8.66 |
#[10, 320000], k=10 | 22.928 | 2.492 | 9.20 |
#[10, 320000], k=50 | 22.835 | 2.498 | 9.14 |
#[10, 320000], k=100 | 22.859 | 2.508 | 9.11 |
#[10, 320000], k=32000 | 45.197 | 7.523 | 6.01 |
#[10, 320000], k=160000 | 146.211 | 17.432 | 8.39 |
#[10, 320000], k=319995 | 263.868 | 29.179 | 9.04 |
Table-2: OMP=20, Sorted=False
Input Size | oob | this pr | speed up |
---|---|---|---|
#[10, 10000], k=10 | 0.746 | 0.061 | 12.20 |
#[10, 10000], k=50 | 0.752 | 0.077 | 9.74 |
#[10, 10000], k=100 | 0.756 | 0.096 | 7.88 |
#[10, 10000], k=1000 | 0.847 | 0.172 | 4.93 |
#[10, 10000], k=5000 | 1.038 | 0.186 | 5.58 |
#[10, 10000], k=9995 | 0.848 | 0.171 | 4.95 |
#[10, 40000], k=10 | 2.841 | 0.177 | 16.06 |
#[10, 40000], k=50 | 2.866 | 0.189 | 15.18 |
#[10, 40000], k=100 | 2.857 | 0.222 | 12.87 |
#[10, 40000], k=4000 | 3.227 | 0.609 | 5.30 |
#[10, 40000], k=20000 | 3.970 | 0.668 | 5.95 |
#[10, 40000], k=39995 | 3.255 | 0.609 | 5.35 |
#[10, 320000], k=10 | 22.597 | 2.487 | 9.09 |
#[10, 320000], k=50 | 22.468 | 2.499 | 8.99 |
#[10, 320000], k=100 | 22.553 | 2.517 | 8.96 |
#[10, 320000], k=32000 | 25.606 | 5.480 | 4.67 |
#[10, 320000], k=160000 | 32.419 | 6.124 | 5.29 |
#[10, 320000], k=319995 | 28.623 | 6.005 | 4.77 |
Table-3: OMP=1, Sorted=True
Input Size | oob | this pr | speed up |
---|---|---|---|
#[10, 10000], k=10 | 0.748 | 0.261 | 2.87 |
#[10, 10000], k=50 | 0.766 | 0.391 | 1.96 |
#[10, 10000], k=100 | 0.788 | 0.550 | 1.43 |
#[10, 10000], k=1000 | 1.255 | 1.296 | 0.97 |
#[10, 10000], k=5000 | 3.554 | 3.441 | 1.03 |
#[10, 10000], k=9995 | 6.185 | 5.710 | 1.08 |
#[10, 40000], k=10 | 2.877 | 0.933 | 3.08 |
#[10, 40000], k=50 | 2.875 | 1.112 | 2.59 |
#[10, 40000], k=100 | 2.895 | 1.282 | 2.26 |
#[10, 40000], k=4000 | 5.184 | 5.304 | 0.98 |
#[10, 40000], k=20000 | 15.905 | 15.106 | 1.05 |
#[10, 40000], k=39995 | 27.970 | 25.741 | 1.09 |
#[10, 320000], k=10 | 23.036 | 7.914 | 2.91 |
#[10, 320000], k=50 | 22.857 | 8.181 | 2.79 |
#[10, 320000], k=100 | 23.075 | 8.404 | 2.75 |
#[10, 320000], k=32000 | 45.292 | 46.478 | 0.97 |
#[10, 320000], k=160000 | 146.232 | 140.205 | 1.04 |
#[10, 320000], k=319995 | 263.640 | 244.572 | 1.08 |
Table-4: OMP=1, Sorted=False
Input Size | oob | this pr | speed up |
---|---|---|---|
#[10, 10000], k=10 | 0.747 | 0.260 | 2.87 |
#[10, 10000], k=50 | 0.749 | 0.389 | 1.92 |
#[10, 10000], k=100 | 0.758 | 0.548 | 1.38 |
#[10, 10000], k=1000 | 0.845 | 0.933 | 0.91 |
#[10, 10000], k=5000 | 1.037 | 1.132 | 0.92 |
#[10, 10000], k=9995 | 0.848 | 0.951 | 0.89 |
#[10, 40000], k=10 | 2.856 | 0.935 | 3.06 |
#[10, 40000], k=50 | 2.863 | 1.090 | 2.63 |
#[10, 40000], k=100 | 2.860 | 1.284 | 2.23 |
#[10, 40000], k=4000 | 3.231 | 3.556 | 0.91 |
#[10, 40000], k=20000 | 3.975 | 4.330 | 0.92 |
#[10, 40000], k=39995 | 3.247 | 3.590 | 0.90 |
#[10, 320000], k=10 | 22.570 | 7.943 | 2.84 |
#[10, 320000], k=50 | 22.504 | 8.143 | 2.76 |
#[10, 320000], k=100 | 22.489 | 8.407 | 2.68 |
#[10, 320000], k=32000 | 25.558 | 29.042 | 0.88 |
#[10, 320000], k=160000 | 32.357 | 36.160 | 0.89 |
#[10, 320000], k=319995 | 28.541 | 31.501 | 0.91 |
`std::partial_sort` (heap sort) is quite fast when `K` is small; `std::sort` (introsort) is faster when `K` is large.
- Use `std::partial_sort` when `K` is of small range, whether `sorted` is `True` or `False`
- Use `std::nth_element` + `std::sort` when `K` is of large range and `sorted` is `True`
- Use `std::nth_element` when `K` is of large range and `sorted` is `False`
- Remove `emplace_back` and use pre-allocation for `std::vector`
- Inline the comparator lambda: gcc has trouble properly inlining the lambda when it is selected by a conditional expression (even with `-O3`), e.g. `auto comp = cond ? lambda1 : lambda2` is marginally slower than writing the lambda inline inside `std::sort`.
- caffe: uses `std::partial_sort`, no parallelization.
- mxnet: uses `std::partial_sort` when K < C/8, otherwise `std::sort`; parallelized with OpenMP.
- tensorflow: min heap, no inter parallelization.
- cntk: uses `std::partial_sort`, only parallelized when K=1
- When `K` is of small range and `N` is smaller than the number of physical cores, parallelizing only on the `N` dimension won't utilize all cores. For example, in transformer beam search, a typical input size is `[N, C]` with `N < 10`. Additional performance gain is possible:
  - Step 1: reorder `input` to `[N, S, C/S]` and perform parallel topk on the `N * S` dimension.
  - Step 2: the output from step 1 is `[N, S*K]`; sort on `S*K` to find the topk values on each channel.
NB: try this in mlperf transformer training...
- SIMD sort: some reference designs: avx2-sort, avx512-sort.
NB: try avx512 quick select...
attach raw logs: original topk()