@mingfeima, last active July 2, 2019

topk_optimization_backups
backups for PR19736 of topk() performance optimization on CPU.


description

Suppose the input tensor has shape [N, C]; benchmark input.topk(K, sorted=Sorted) for the following scenarios:

  1. C = 10000, 40000, 320000
  2. K = 10, 50, 100, C/10, C/2, C-5
  3. Test with 20 threads and 1 thread
  4. Test with Sorted=True and Sorted=False

benchmark scripts

  1. run_topk_scale.sh
CORES=`lscpu | grep Core | awk '{print $4}'`
SOCKETS=`lscpu | grep Socket | awk '{print $2}'`
TOTAL_CORES=`expr $CORES \* $SOCKETS`
LAST_CORE=`expr $CORES - 1`

KMP_SETTING="KMP_AFFINITY=granularity=fine,compact,1,0"
KMP_BLOCKTIME=1

PREFIX="numactl --physcpubind=0-$LAST_CORE --membind=0"

export $KMP_SETTING
export KMP_BLOCKTIME=$KMP_BLOCKTIME

echo -e "\n### using $KMP_SETTING"
echo -e "### using KMP_BLOCKTIME=$KMP_BLOCKTIME"
echo -e "### using $PREFIX\n"

### single socket test
echo -e "\n### using OMP_NUM_THREADS=$CORES"
OMP_NUM_THREADS=$CORES $PREFIX python -u test_topk.py

### single thread test
echo -e "\n### using OMP_NUM_THREADS=1"
OMP_NUM_THREADS=1 $PREFIX python -u test_topk.py
  2. test_topk.py
import torch
from time import time


def bench_topk(N=8, C=168560, K=10, Sorted=True, num_iters=1000):
    a = torch.randn(N, C)
    # warm up before timing
    for i in range(int(num_iters / 10)):
        torch.topk(a, K)

    t = 0
    for i in range(num_iters):
        a = torch.randn(N, C)
        start = time()
        values, indices = torch.topk(a, K, sorted=Sorted)
        t += time() - start

    print("#[%d, %d], k=%d, sorted=%s times: %f ms" % (N, C, K,
         ('True' if Sorted else 'False'), t / num_iters * 1000))


def bench_topk_scale(Sorted):
    Ns = [10]
    Cs = [10000, 40000, 320000]

    for n in Ns:
        for c in Cs:
            for k in [10, 50, 100, int(c/10), int(c/2), int(c-5)]:
                iters = 500 if k > 5000 else 1000
                bench_topk(n, c, k, Sorted, iters)

bench_topk_scale(True)
bench_topk_scale(False)

performance results

Tested on an Intel Xeon 6148 (2 sockets x 20 cores @ 2.5 GHz). To reproduce the results:

./run_topk_scale.sh

All numbers are reported in ms; oob refers to the original topk() performance.

Table-1: OMP=20, Sorted=True

| Input Size | oob | this pr | speed up |
|---|---|---|---|
| #[10, 10000], k=10 | 0.745 | 0.061 | 12.13 |
| #[10, 10000], k=50 | 0.765 | 0.077 | 9.93 |
| #[10, 10000], k=100 | 0.789 | 0.096 | 8.25 |
| #[10, 10000], k=1000 | 1.252 | 0.219 | 5.72 |
| #[10, 10000], k=5000 | 3.552 | 0.504 | 7.05 |
| #[10, 10000], k=9995 | 6.185 | 0.814 | 7.60 |
| #[10, 40000], k=10 | 2.882 | 0.176 | 16.37 |
| #[10, 40000], k=50 | 2.895 | 0.193 | 14.98 |
| #[10, 40000], k=100 | 2.914 | 0.222 | 13.10 |
| #[10, 40000], k=4000 | 5.213 | 0.835 | 6.24 |
| #[10, 40000], k=20000 | 15.811 | 2.019 | 7.83 |
| #[10, 40000], k=39995 | 27.980 | 3.231 | 8.66 |
| #[10, 320000], k=10 | 22.928 | 2.492 | 9.20 |
| #[10, 320000], k=50 | 22.835 | 2.498 | 9.14 |
| #[10, 320000], k=100 | 22.859 | 2.508 | 9.11 |
| #[10, 320000], k=32000 | 45.197 | 7.523 | 6.01 |
| #[10, 320000], k=160000 | 146.211 | 17.432 | 8.39 |
| #[10, 320000], k=319995 | 263.868 | 29.179 | 9.04 |

Table-2: OMP=20, Sorted=False

| Input Size | oob | this pr | speed up |
|---|---|---|---|
| #[10, 10000], k=10 | 0.746 | 0.061 | 12.20 |
| #[10, 10000], k=50 | 0.752 | 0.077 | 9.74 |
| #[10, 10000], k=100 | 0.756 | 0.096 | 7.88 |
| #[10, 10000], k=1000 | 0.847 | 0.172 | 4.93 |
| #[10, 10000], k=5000 | 1.038 | 0.186 | 5.58 |
| #[10, 10000], k=9995 | 0.848 | 0.171 | 4.95 |
| #[10, 40000], k=10 | 2.841 | 0.177 | 16.06 |
| #[10, 40000], k=50 | 2.866 | 0.189 | 15.18 |
| #[10, 40000], k=100 | 2.857 | 0.222 | 12.87 |
| #[10, 40000], k=4000 | 3.227 | 0.609 | 5.30 |
| #[10, 40000], k=20000 | 3.970 | 0.668 | 5.95 |
| #[10, 40000], k=39995 | 3.255 | 0.609 | 5.35 |
| #[10, 320000], k=10 | 22.597 | 2.487 | 9.09 |
| #[10, 320000], k=50 | 22.468 | 2.499 | 8.99 |
| #[10, 320000], k=100 | 22.553 | 2.517 | 8.96 |
| #[10, 320000], k=32000 | 25.606 | 5.480 | 4.67 |
| #[10, 320000], k=160000 | 32.419 | 6.124 | 5.29 |
| #[10, 320000], k=319995 | 28.623 | 6.005 | 4.77 |

Table-3: OMP=1, Sorted=True

| Input Size | oob | this pr | speed up |
|---|---|---|---|
| #[10, 10000], k=10 | 0.748 | 0.261 | 2.87 |
| #[10, 10000], k=50 | 0.766 | 0.391 | 1.96 |
| #[10, 10000], k=100 | 0.788 | 0.550 | 1.43 |
| #[10, 10000], k=1000 | 1.255 | 1.296 | 0.97 |
| #[10, 10000], k=5000 | 3.554 | 3.441 | 1.03 |
| #[10, 10000], k=9995 | 6.185 | 5.710 | 1.08 |
| #[10, 40000], k=10 | 2.877 | 0.933 | 3.08 |
| #[10, 40000], k=50 | 2.875 | 1.112 | 2.59 |
| #[10, 40000], k=100 | 2.895 | 1.282 | 2.26 |
| #[10, 40000], k=4000 | 5.184 | 5.304 | 0.98 |
| #[10, 40000], k=20000 | 15.905 | 15.106 | 1.05 |
| #[10, 40000], k=39995 | 27.970 | 25.741 | 1.09 |
| #[10, 320000], k=10 | 23.036 | 7.914 | 2.91 |
| #[10, 320000], k=50 | 22.857 | 8.181 | 2.79 |
| #[10, 320000], k=100 | 23.075 | 8.404 | 2.75 |
| #[10, 320000], k=32000 | 45.292 | 46.478 | 0.97 |
| #[10, 320000], k=160000 | 146.232 | 140.205 | 1.04 |
| #[10, 320000], k=319995 | 263.640 | 244.572 | 1.08 |

Table-4: OMP=1, Sorted=False

| Input Size | oob | this pr | speed up |
|---|---|---|---|
| #[10, 10000], k=10 | 0.747 | 0.260 | 2.87 |
| #[10, 10000], k=50 | 0.749 | 0.389 | 1.92 |
| #[10, 10000], k=100 | 0.758 | 0.548 | 1.38 |
| #[10, 10000], k=1000 | 0.845 | 0.933 | 0.91 |
| #[10, 10000], k=5000 | 1.037 | 1.132 | 0.92 |
| #[10, 10000], k=9995 | 0.848 | 0.951 | 0.89 |
| #[10, 40000], k=10 | 2.856 | 0.935 | 3.06 |
| #[10, 40000], k=50 | 2.863 | 1.090 | 2.63 |
| #[10, 40000], k=100 | 2.860 | 1.284 | 2.23 |
| #[10, 40000], k=4000 | 3.231 | 3.556 | 0.91 |
| #[10, 40000], k=20000 | 3.975 | 4.330 | 0.92 |
| #[10, 40000], k=39995 | 3.247 | 3.590 | 0.90 |
| #[10, 320000], k=10 | 22.570 | 7.943 | 2.84 |
| #[10, 320000], k=50 | 22.504 | 8.143 | 2.76 |
| #[10, 320000], k=100 | 22.489 | 8.407 | 2.68 |
| #[10, 320000], k=32000 | 25.558 | 29.042 | 0.88 |
| #[10, 320000], k=160000 | 32.357 | 36.160 | 0.89 |
| #[10, 320000], k=319995 | 28.541 | 31.501 | 0.91 |

optimization strategy

  1. std::partial_sort (heap sort) is quite fast when K is small; std::sort (introsort) is faster when K is large:
  • use std::partial_sort when K is in the small range, no matter whether sorted is True or False
  • use std::nth_element + std::sort when K is in the large range and sorted is True
  • use std::nth_element alone when K is in the large range and sorted is False
  2. remove emplace_back and pre-allocate the std::vector instead
  3. inline the comparator lambda: gcc has trouble properly inlining a lambda chosen via a conditional expression (even with -O3), e.g. auto comp = cond ? lambda1 : lambda2 is marginally slower than writing the lambda directly inside std::sort.
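The dispatch above can be sketched in pure Python (a hypothetical analogy of the PR's C++ code, not the actual implementation: heapq.nlargest stands in for std::partial_sort, and a full sort stands in for std::nth_element, which has no Python stdlib equivalent; the threshold small_k_ratio is made up for illustration):

```python
import heapq

def topk_dispatch(xs, k, want_sorted=True, small_k_ratio=0.25):
    """Pure-Python sketch of the small-K / large-K dispatch."""
    if k <= len(xs) * small_k_ratio:
        # small K: heap-based partial sort (analogous to std::partial_sort);
        # heapq.nlargest returns the k values already in descending order,
        # so it serves both sorted=True and sorted=False
        return heapq.nlargest(k, xs)
    # large K: select the top-k set first (std::nth_element in the C++ code);
    # Python's stdlib has no linear-time select, so a full sort stands in here.
    # For want_sorted=False the C++ path stops after nth_element and skips the
    # final O(K log K) sort, which is why Sorted=False is cheaper in the tables.
    return sorted(xs, reverse=True)[:k]
```

The key point is that only the `sorted=True` large-K path pays for ordering the full top-K set; the other paths either get ordering for free (heap) or skip it entirely (select only).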

reference designs of topk from other frameworks

  1. caffe: uses std::partial_sort, no parallelization.
  2. mxnet: uses std::partial_sort when K < C/8, otherwise std::sort; parallelized with OpenMP.
  3. tensorflow: maintains a min-heap, no parallelization within a single topk.
  4. cntk: uses std::partial_sort, only parallelized when K=1.
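The tensorflow-style min-heap approach can be sketched as follows (a hypothetical pure-Python illustration, not code taken from any of the frameworks above): keep a min-heap of size K whose root is the smallest of the current top-K, and evict the root whenever a larger element arrives.

```python
import heapq

def topk_streaming(xs, k):
    """One-pass top-k via a bounded min-heap (tensorflow-style sketch)."""
    heap = []  # min-heap: heap[0] is the smallest of the current top-k
    for x in xs:
        if len(heap) < k:
            heapq.heappush(heap, x)
        elif x > heap[0]:
            heapq.heapreplace(heap, x)  # pop the current minimum, push x
    return sorted(heap, reverse=True)
```

Each element costs at most O(log K), so a single pass is O(C log K); with no parallelism inside the op, throughput then depends entirely on single-thread speed.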

TODOs

  1. when K is in the small range and N is smaller than the number of physical cores, parallelizing only over the N dimension won't utilize all cores. For example, in transformer beam search the typical input size [N, C] has N < 10. Additional performance gain is possible:
  • Step 1: reorder the input to [N, S, C/S] and perform parallel topk over the N * S dimension.
  • Step 2: the output of step 1 is [N, S*K]; sort each row's S*K candidates to find the final topk values.

NB: try this in mlperf transformer training...
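The two steps above can be sketched for a single row in pure Python (a hypothetical illustration of the chunking idea; chunked_topk and its parameters are made-up names, and heapq.nlargest stands in for the per-chunk topk that would actually run in parallel across the N * S pieces):

```python
import heapq

def chunked_topk(row, k, num_chunks):
    """Two-step top-k: per-chunk top-k, then a reduction over candidates."""
    # step 1: view the row of length C as num_chunks pieces of size ~C/S and
    # take top-k in each piece (parallelizable over all N * S pieces)
    size = (len(row) + num_chunks - 1) // num_chunks
    candidates = []
    for s in range(num_chunks):
        part = row[s * size:(s + 1) * size]
        candidates.extend(heapq.nlargest(min(k, len(part)), part))
    # step 2: the global top-k must be among the S*k per-chunk candidates
    return heapq.nlargest(k, candidates)
```

This is correct because every element of the global top-k is necessarily in the top-k of its own chunk, so the reduction over S*K candidates cannot miss a winner.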

  2. SIMD sort: some reference designs: avx2-sort, avx512-sort.

NB: try avx512 quick select...

attach raw logs: original topk()

### using KMP_AFFINITY=granularity=fine,compact,1,0
### using KMP_BLOCKTIME=1
### using numactl --physcpubind=0-19 --membind=0


### using OMP_NUM_THREADS=20
#[10, 10000], k=10, sorted=True times: 0.745392 ms
#[10, 10000], k=50, sorted=True times: 0.765218 ms
#[10, 10000], k=100, sorted=True times: 0.788699 ms
#[10, 10000], k=1000, sorted=True times: 1.252482 ms
#[10, 10000], k=5000, sorted=True times: 3.551789 ms
#[10, 10000], k=9995, sorted=True times: 6.185414 ms
#[10, 40000], k=10, sorted=True times: 2.881847 ms
#[10, 40000], k=50, sorted=True times: 2.895264 ms
#[10, 40000], k=100, sorted=True times: 2.913851 ms
#[10, 40000], k=4000, sorted=True times: 5.213053 ms
#[10, 40000], k=20000, sorted=True times: 15.811110 ms
#[10, 40000], k=39995, sorted=True times: 27.980268 ms
#[10, 320000], k=10, sorted=True times: 22.928153 ms
#[10, 320000], k=50, sorted=True times: 22.835157 ms
#[10, 320000], k=100, sorted=True times: 22.858513 ms
#[10, 320000], k=32000, sorted=True times: 45.196541 ms
#[10, 320000], k=160000, sorted=True times: 146.211212 ms
#[10, 320000], k=319995, sorted=True times: 263.868076 ms
#[10, 10000], k=10, sorted=False times: 0.746027 ms
#[10, 10000], k=50, sorted=False times: 0.751978 ms
#[10, 10000], k=100, sorted=False times: 0.755853 ms
#[10, 10000], k=1000, sorted=False times: 0.846578 ms
#[10, 10000], k=5000, sorted=False times: 1.037599 ms
#[10, 10000], k=9995, sorted=False times: 0.847579 ms
#[10, 40000], k=10, sorted=False times: 2.841000 ms
#[10, 40000], k=50, sorted=False times: 2.865758 ms
#[10, 40000], k=100, sorted=False times: 2.856599 ms
#[10, 40000], k=4000, sorted=False times: 3.226810 ms
#[10, 40000], k=20000, sorted=False times: 3.970210 ms
#[10, 40000], k=39995, sorted=False times: 3.254811 ms
#[10, 320000], k=10, sorted=False times: 22.596778 ms
#[10, 320000], k=50, sorted=False times: 22.468020 ms
#[10, 320000], k=100, sorted=False times: 22.552507 ms
#[10, 320000], k=32000, sorted=False times: 25.605933 ms
#[10, 320000], k=160000, sorted=False times: 32.418515 ms
#[10, 320000], k=319995, sorted=False times: 28.622736 ms

### using OMP_NUM_THREADS=1
#[10, 10000], k=10, sorted=True times: 0.748208 ms
#[10, 10000], k=50, sorted=True times: 0.765639 ms
#[10, 10000], k=100, sorted=True times: 0.787945 ms
#[10, 10000], k=1000, sorted=True times: 1.255485 ms
#[10, 10000], k=5000, sorted=True times: 3.554268 ms
#[10, 10000], k=9995, sorted=True times: 6.184577 ms
#[10, 40000], k=10, sorted=True times: 2.876680 ms
#[10, 40000], k=50, sorted=True times: 2.874824 ms
#[10, 40000], k=100, sorted=True times: 2.895125 ms
#[10, 40000], k=4000, sorted=True times: 5.184333 ms
#[10, 40000], k=20000, sorted=True times: 15.905116 ms
#[10, 40000], k=39995, sorted=True times: 27.970144 ms
#[10, 320000], k=10, sorted=True times: 23.035637 ms
#[10, 320000], k=50, sorted=True times: 22.857202 ms
#[10, 320000], k=100, sorted=True times: 23.075466 ms
#[10, 320000], k=32000, sorted=True times: 45.292016 ms
#[10, 320000], k=160000, sorted=True times: 146.232030 ms
#[10, 320000], k=319995, sorted=True times: 263.640265 ms
#[10, 10000], k=10, sorted=False times: 0.747023 ms
#[10, 10000], k=50, sorted=False times: 0.749144 ms
#[10, 10000], k=100, sorted=False times: 0.758173 ms
#[10, 10000], k=1000, sorted=False times: 0.845423 ms
#[10, 10000], k=5000, sorted=False times: 1.036735 ms
#[10, 10000], k=9995, sorted=False times: 0.848307 ms
#[10, 40000], k=10, sorted=False times: 2.855686 ms
#[10, 40000], k=50, sorted=False times: 2.862642 ms
#[10, 40000], k=100, sorted=False times: 2.859756 ms
#[10, 40000], k=4000, sorted=False times: 3.231069 ms
#[10, 40000], k=20000, sorted=False times: 3.975193 ms
#[10, 40000], k=39995, sorted=False times: 3.246888 ms
#[10, 320000], k=10, sorted=False times: 22.570211 ms
#[10, 320000], k=50, sorted=False times: 22.503910 ms
#[10, 320000], k=100, sorted=False times: 22.488813 ms
#[10, 320000], k=32000, sorted=False times: 25.557902 ms
#[10, 320000], k=160000, sorted=False times: 32.357234 ms
#[10, 320000], k=319995, sorted=False times: 28.540786 ms

mingfeima commented May 23, 2019

attach raw logs: this pr

(pytorch-mingfei) [mingfeim@mlt-skx091 unit_tests]$ ./run_topk_scale.sh

### using KMP_AFFINITY=granularity=fine,compact,1,0
### using KMP_BLOCKTIME=1
### using numactl --physcpubind=0-19 --membind=0


### using OMP_NUM_THREADS=20
#[10, 10000], k=10, sorted=True times: 0.061469 ms
#[10, 10000], k=50, sorted=True times: 0.077071 ms
#[10, 10000], k=100, sorted=True times: 0.095619 ms
#[10, 10000], k=1000, sorted=True times: 0.218930 ms
#[10, 10000], k=5000, sorted=True times: 0.504025 ms
#[10, 10000], k=9995, sorted=True times: 0.813704 ms
#[10, 40000], k=10, sorted=True times: 0.176027 ms
#[10, 40000], k=50, sorted=True times: 0.193223 ms
#[10, 40000], k=100, sorted=True times: 0.222411 ms
#[10, 40000], k=4000, sorted=True times: 0.835276 ms
#[10, 40000], k=20000, sorted=True times: 2.019146 ms
#[10, 40000], k=39995, sorted=True times: 3.230850 ms
#[10, 320000], k=10, sorted=True times: 2.492267 ms
#[10, 320000], k=50, sorted=True times: 2.497881 ms
#[10, 320000], k=100, sorted=True times: 2.508057 ms
#[10, 320000], k=32000, sorted=True times: 7.522835 ms
#[10, 320000], k=160000, sorted=True times: 17.431903 ms
#[10, 320000], k=319995, sorted=True times: 29.179085 ms
#[10, 10000], k=10, sorted=False times: 0.061143 ms
#[10, 10000], k=50, sorted=False times: 0.077233 ms
#[10, 10000], k=100, sorted=False times: 0.095950 ms
#[10, 10000], k=1000, sorted=False times: 0.171860 ms
#[10, 10000], k=5000, sorted=False times: 0.186076 ms
#[10, 10000], k=9995, sorted=False times: 0.171076 ms
#[10, 40000], k=10, sorted=False times: 0.176849 ms
#[10, 40000], k=50, sorted=False times: 0.188809 ms
#[10, 40000], k=100, sorted=False times: 0.222004 ms
#[10, 40000], k=4000, sorted=False times: 0.609249 ms
#[10, 40000], k=20000, sorted=False times: 0.667706 ms
#[10, 40000], k=39995, sorted=False times: 0.608668 ms
#[10, 320000], k=10, sorted=False times: 2.486966 ms
#[10, 320000], k=50, sorted=False times: 2.499162 ms
#[10, 320000], k=100, sorted=False times: 2.517440 ms
#[10, 320000], k=32000, sorted=False times: 5.480134 ms
#[10, 320000], k=160000, sorted=False times: 6.123717 ms
#[10, 320000], k=319995, sorted=False times: 6.004757 ms

### using OMP_NUM_THREADS=1
#[10, 10000], k=10, sorted=True times: 0.261058 ms
#[10, 10000], k=50, sorted=True times: 0.391242 ms
#[10, 10000], k=100, sorted=True times: 0.549815 ms
#[10, 10000], k=1000, sorted=True times: 1.295901 ms
#[10, 10000], k=5000, sorted=True times: 3.440945 ms
#[10, 10000], k=9995, sorted=True times: 5.709789 ms
#[10, 40000], k=10, sorted=True times: 0.932881 ms
#[10, 40000], k=50, sorted=True times: 1.111758 ms
#[10, 40000], k=100, sorted=True times: 1.281902 ms
#[10, 40000], k=4000, sorted=True times: 5.304386 ms
#[10, 40000], k=20000, sorted=True times: 15.106279 ms
#[10, 40000], k=39995, sorted=True times: 25.740827 ms
#[10, 320000], k=10, sorted=True times: 7.913611 ms
#[10, 320000], k=50, sorted=True times: 8.181231 ms
#[10, 320000], k=100, sorted=True times: 8.403773 ms
#[10, 320000], k=32000, sorted=True times: 46.477951 ms
#[10, 320000], k=160000, sorted=True times: 140.205234 ms
#[10, 320000], k=319995, sorted=True times: 244.571767 ms
#[10, 10000], k=10, sorted=False times: 0.259998 ms
#[10, 10000], k=50, sorted=False times: 0.389469 ms
#[10, 10000], k=100, sorted=False times: 0.548086 ms
#[10, 10000], k=1000, sorted=False times: 0.932514 ms
#[10, 10000], k=5000, sorted=False times: 1.132186 ms
#[10, 10000], k=9995, sorted=False times: 0.950595 ms
#[10, 40000], k=10, sorted=False times: 0.934675 ms
#[10, 40000], k=50, sorted=False times: 1.089834 ms
#[10, 40000], k=100, sorted=False times: 1.284119 ms
#[10, 40000], k=4000, sorted=False times: 3.556359 ms
#[10, 40000], k=20000, sorted=False times: 4.330447 ms
#[10, 40000], k=39995, sorted=False times: 3.590395 ms
#[10, 320000], k=10, sorted=False times: 7.943152 ms
#[10, 320000], k=50, sorted=False times: 8.142707 ms
#[10, 320000], k=100, sorted=False times: 8.406508 ms
#[10, 320000], k=32000, sorted=False times: 29.042225 ms
#[10, 320000], k=160000, sorted=False times: 36.159599 ms
#[10, 320000], k=319995, sorted=False times: 31.500615 ms
