Skip to content

Instantly share code, notes, and snippets.

@primenumber
Last active June 14, 2023 18:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save primenumber/a3508fd08db5c8cc73302699762e835f to your computer and use it in GitHub Desktop.
Parallel bit scan reverse
// author: prime (prime@kmc.gr.jp)
// License: MIT License
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <vector>
#include <bitset>
#include <x86intrin.h>
#include <boost/timer/timer.hpp>
// Forward declarations: parallel bit-scan-reverse (BSR) over every lane of a
// 256-bit AVX2 register, one function per (lane width, implementation) pair.
// Strategies: naive = scalar BSR per lane, cvtfloat = IEEE-754 exponent
// extraction, popcnt = smear-then-popcount, table = pshufb/gather lookup.
inline __m256i bsr_256_8_naive(__m256i x);
inline __m256i bsr_256_8_cvtfloat(__m256i x);
inline __m256i bsr_256_8_popcnt(__m256i x);
inline __m256i bsr_256_8_table(__m256i x);
inline __m256i bsr_256_16_naive(__m256i x);
inline __m256i bsr_256_16_cvtfloat(__m256i x);
inline __m256i bsr_256_16_popcnt(__m256i x);
inline __m256i bsr_256_16_table(__m256i x);
inline __m256i bsr_256_32_naive(__m256i x);
inline __m256i bsr_256_32_popcnt(__m256i x);
inline __m256i bsr_256_32_cvtfloat(__m256i x);
inline __m256i bsr_256_32_table_gather(__m256i x);
inline __m256i bsr_256_64_naive(__m256i x);
inline __m256i bsr_256_64_cvtfloat(__m256i x);
inline __m256i bsr_256_64_cvtfloat2(__m256i x);
// Consistency fix: bsr_256_64_popcnt is defined and benchmarked below but was
// missing from this declaration list.
inline __m256i bsr_256_64_popcnt(__m256i x);
inline __m256i bsr_256_64_rev_popcnt(__m256i x);
inline __m256i bsr_256_32_cvtfloat_impl(__m256i x, int32_t sub) {
  // Exponent trick: converting x to float puts floor(log2(x)) + 127 into the
  // IEEE-754 exponent field (bits 30..23). Shifting it down and subtracting
  // the caller-supplied bias `sub` recovers the bit index. Only valid when
  // the conversion cannot round the mantissa up into the next exponent.
  const __m256i float_bits = _mm256_castps_si256(_mm256_cvtepi32_ps(x));
  const __m256i exponent = _mm256_srli_epi32(float_bits, 23);
  return _mm256_sub_epi32(exponent, _mm256_set1_epi32(sub));
}
inline __m256i bsr_256_8_naive(__m256i x) {
  // Baseline: spill the vector to memory, run the scalar BSR instruction on
  // each of the 32 bytes in place, and reload the results.
  alignas(32) uint8_t lanes[32];
  _mm256_store_si256(reinterpret_cast<__m256i*>(lanes), x);
  for (int i = 0; i < 32; ++i)
    lanes[i] = static_cast<uint8_t>(__bsrd(lanes[i]));
  return _mm256_load_si256(reinterpret_cast<const __m256i*>(lanes));
}
inline __m256i bsr_256_8_cvtfloat(__m256i x) {
// Per-byte BSR via the float-exponent trick: byte k of every dword is
// processed in its own pass (r0..r3), each pass leaving that byte's BSR in
// the low byte of the dword; the four results are then re-interleaved.
__m256i r0 = bsr_256_32_cvtfloat_impl(_mm256_and_si256(x, _mm256_set1_epi32(0x000000FF)), 127);
__m256i r1 = bsr_256_32_cvtfloat_impl(_mm256_and_si256(_mm256_srli_epi32(x, 8), _mm256_set1_epi32(0x000000FF)), 127);
__m256i r2 = bsr_256_32_cvtfloat_impl(_mm256_and_si256(_mm256_srli_epi32(x, 16), _mm256_set1_epi32(0x000000FF)), 127);
__m256i r3 = bsr_256_32_cvtfloat_impl(_mm256_srli_epi32(x, 24), 127);
// Pack the even-byte results (r0/r2) and odd-byte results (r1/r3) into
// alternating 16-bit positions of each dword.
__m256i r02 = _mm256_blend_epi16(r0, _mm256_slli_epi32(r2, 16), 0xAA);
__m256i r13 = _mm256_blend_epi16(r1, _mm256_slli_epi32(r3, 16), 0xAA);
// blendv selects per byte by the mask's sign bit: the 0xFF bytes of 0xFF00
// pick the odd-byte results into the odd byte positions.
return _mm256_blendv_epi8(r02, _mm256_slli_epi16(r13, 8), _mm256_set1_epi16(0xFF00));
}
// Population count of each 8-bit lane via the classic SWAR bit-slice method.
// AVX2 has no 8-bit shift, so 16-bit shifts are used with masks that confine
// each step's partial sums to their own byte.
inline __m256i popcount_256_8(__m256i x) {
x = _mm256_sub_epi8(x, _mm256_and_si256(_mm256_srli_epi16(x, 1), _mm256_set1_epi8(0x55)));
x = _mm256_add_epi8(_mm256_and_si256(x, _mm256_set1_epi8(0x33)), _mm256_and_si256(_mm256_srli_epi16(x, 2), _mm256_set1_epi8(0x33)));
return _mm256_and_si256(_mm256_add_epi8(x, _mm256_srli_epi16(x, 4)), _mm256_set1_epi8(0x0F));
}
// BSR of each byte via smear + popcount: OR the highest set bit into every
// lower position, drop the top bit, then bsr(x) == popcount(smear(x) >> 1).
// 16-bit shifts plus byte masks emulate the missing 8-bit vector shift.
inline __m256i bsr_256_8_popcnt(__m256i x) {
x = _mm256_or_si256(x, _mm256_and_si256(_mm256_srli_epi16(x, 1), _mm256_set1_epi8(0x7F)));
x = _mm256_or_si256(x, _mm256_and_si256(_mm256_srli_epi16(x, 2), _mm256_set1_epi8(0x3F)));
x = _mm256_and_si256(_mm256_srli_epi16(_mm256_or_si256(x, _mm256_and_si256(_mm256_srli_epi16(x, 4), _mm256_set1_epi8(0x0F))), 1), _mm256_set1_epi8(0x7F));
return popcount_256_8(x);
}
// BSR of each byte via two 16-entry nibble lookups (pshufb).
// table_lo holds bsr of the low nibble (0..3); table_hi holds bsr of the high
// nibble already shifted into place (4..7), with entry 0 = 0 so that a zero
// high nibble loses to the low-nibble result in the max below.
inline __m256i bsr_256_8_table(__m256i x) {
__m128i table128_lo = _mm_setr_epi8(0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3);
__m128i table128_hi = _mm_setr_epi8(0, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7);
__m256i table_lo = _mm256_broadcastsi128_si256(table128_lo);
__m256i table_hi = _mm256_broadcastsi128_si256(table128_hi);
return _mm256_max_epi8(
_mm256_shuffle_epi8(table_lo, _mm256_and_si256(x, _mm256_set1_epi8(0x0F))),
_mm256_shuffle_epi8(table_hi, _mm256_and_si256(_mm256_srli_epi16(x, 4), _mm256_set1_epi8(0x0F))));
}
inline __m256i bsr_256_16_naive(__m256i x) {
  // Baseline: spill the vector to memory, run the scalar BSR instruction on
  // each of the 16 words in place, and reload the results.
  alignas(32) uint16_t lanes[16];
  _mm256_store_si256(reinterpret_cast<__m256i*>(lanes), x);
  for (int i = 0; i < 16; ++i)
    lanes[i] = static_cast<uint16_t>(__bsrd(lanes[i]));
  return _mm256_load_si256(reinterpret_cast<const __m256i*>(lanes));
}
// Population count of each 16-bit lane: SWAR partial sums at 1-, 2- and 4-bit
// granularity, then maddubs with 0x0101 adds the two byte counts of each lane
// into a single 16-bit result.
inline __m256i popcount_256_16(__m256i x) {
x = _mm256_add_epi16(_mm256_and_si256(x, _mm256_set1_epi16(0x5555)), _mm256_and_si256(_mm256_srli_epi16(x, 1), _mm256_set1_epi16(0x5555)));
x = _mm256_add_epi16(_mm256_and_si256(x, _mm256_set1_epi16(0x3333)), _mm256_and_si256(_mm256_srli_epi16(x, 2), _mm256_set1_epi16(0x3333)));
x = _mm256_add_epi16(_mm256_and_si256(x, _mm256_set1_epi16(0x0F0F)), _mm256_and_si256(_mm256_srli_epi16(x, 4), _mm256_set1_epi16(0x0F0F)));
return _mm256_maddubs_epi16(x, _mm256_set1_epi16(0x0101));
}
inline __m256i bsr_256_16_popcnt(__m256i x) {
x = _mm256_or_si256(x, _mm256_srli_epi16(x, 1));
x = _mm256_or_si256(x, _mm256_srli_epi16(x, 2));
x = _mm256_or_si256(x, _mm256_srli_epi16(x, 4));
x = _mm256_srli_epi16(_mm256_or_si256(x, _mm256_srli_epi16(x, 8)), 1);
return popcount_256_16(x);
}
inline __m256i bsr_256_16_cvtfloat(__m256i x) {
// Per-word BSR via the float exponent: the low and high 16 bits of each
// dword are converted in separate passes (values < 2^16 are exactly
// representable in a float), then blended back into 16-bit lanes.
__m256i lo = bsr_256_32_cvtfloat_impl(_mm256_and_si256(x, _mm256_set1_epi32(0x0000FFFF)), 127);
__m256i hi = bsr_256_32_cvtfloat_impl(_mm256_srli_epi32(x, 16), 127);
return _mm256_blend_epi16(lo, _mm256_slli_epi32(hi, 16), 0xAA);
}
inline __m256i bsr_256_16_table(__m256i x) {
  // Per-byte BSR via the nibble tables, then combine the two bytes of each
  // 16-bit lane: a nonzero high byte wins and its result is offset by +8.
  __m256i half = bsr_256_8_table(x);
  // BUG FIX: the original used _mm256_cmpgt_epi8(x, 0) — a SIGNED byte
  // compare — so bytes >= 0x80 were treated as negative and never received
  // the +8 offset (e.g. bsr(0x8000) came out as 7 instead of 15). The offset
  // actually depends on "byte != 0", so test that instead.
  __m256i is_zero = _mm256_cmpeq_epi8(x, _mm256_setzero_si256());
  // 0x0800 places the +8 only in the high byte of each 16-bit lane.
  __m256i res8 = _mm256_add_epi8(half, _mm256_andnot_si256(is_zero, _mm256_set1_epi16(0x0800)));
  // Result is the larger of (low-byte bsr) and (high-byte bsr + 8).
  return _mm256_max_epi16(_mm256_and_si256(res8, _mm256_set1_epi16(0x00FF)), _mm256_srli_epi16(res8, 8));
}
inline __m256i bsr_256_32_naive(__m256i x) {
  // Baseline: spill the vector to memory, run the scalar BSR instruction on
  // each of the 8 dwords in place, and reload the results.
  alignas(32) uint32_t lanes[8];
  _mm256_store_si256(reinterpret_cast<__m256i*>(lanes), x);
  for (int i = 0; i < 8; ++i)
    lanes[i] = static_cast<uint32_t>(__bsrd(lanes[i]));
  return _mm256_load_si256(reinterpret_cast<const __m256i*>(lanes));
}
inline __m256i bsr_256_32_cvtfloat(__m256i x) {
x = _mm256_andnot_si256(_mm256_srli_epi32(x, 1), x); // clear each bit that has a set bit directly above it, so float conversion cannot round the mantissa up into the next exponent (e.g. 0x00FFFFFF -> 2^24)
__m256i result = bsr_256_32_cvtfloat_impl(x, 127);
// cvtepi32_ps is a signed conversion, so inputs with bit 31 set would yield
// a wrong exponent; srai propagates that sign bit into all-ones, which the
// final AND with 0x1F turns into the correct answer 31.
result = _mm256_or_si256(result, _mm256_srai_epi32(x, 31));
return _mm256_and_si256(result, _mm256_set1_epi32(0x0000001F));
}
// Population count of each 32-bit lane: SWAR partial sums down to per-byte
// counts, then madd with 0x0101 folds the four byte counts of each dword
// together; the final shift extracts the total (sums are small enough that
// no carries cross byte boundaries).
inline __m256i popcount_256_32(__m256i x) {
x = _mm256_sub_epi32(x, _mm256_and_si256(_mm256_srli_epi32(x, 1), _mm256_set1_epi8(0x55)));
x = _mm256_add_epi32(_mm256_and_si256(x, _mm256_set1_epi8(0x33)), _mm256_and_si256(_mm256_srli_epi32(x, 2), _mm256_set1_epi8(0x33)));
x = _mm256_and_si256(_mm256_add_epi8(x, _mm256_srli_epi32(x, 4)), _mm256_set1_epi8(0x0F));
return _mm256_srli_epi16(_mm256_madd_epi16(x, _mm256_set1_epi16(0x0101)), 8);
}
inline __m256i bsr_256_32_popcnt(__m256i x) {
x = _mm256_or_si256(x, _mm256_srli_epi32(x, 1));
x = _mm256_or_si256(x, _mm256_srli_epi32(x, 2));
x = _mm256_or_si256(x, _mm256_srli_epi32(x, 4));
x = _mm256_or_si256(x, _mm256_srli_epi32(x, 8));
x = _mm256_srli_epi32(_mm256_or_si256(x, _mm256_srli_epi32(x, 16)), 1);
return popcount_256_32(x);
}
int32_t *table_32;
void init_table_32() {
table_32 = (int32_t*)malloc(sizeof(int32_t)*((size_t)1 << 31));
table_32[0] = 0;
int64_t j = 1;
for (int i = 1; i < 32; ++i) {
table_32[j] = i;
j = 2*j + 1;
}
}
inline __m256i bsr_256_32_table_gather(__m256i x) {
x = _mm256_or_si256(x, _mm256_srli_epi32(x, 1));
x = _mm256_or_si256(x, _mm256_srli_epi32(x, 2));
x = _mm256_or_si256(x, _mm256_srli_epi32(x, 4));
x = _mm256_or_si256(x, _mm256_srli_epi32(x, 8));
x = _mm256_srli_epi32(_mm256_or_si256(x, _mm256_srli_epi32(x, 16)), 1);
return _mm256_i32gather_epi32(table_32, x, 4);
}
inline __m256i bsr_256_64_naive(__m256i x) {
  // Baseline: spill the vector to memory, run the scalar 64-bit BSR on each
  // of the 4 qwords in place, and reload the results.
  alignas(32) uint64_t lanes[4];
  _mm256_store_si256(reinterpret_cast<__m256i*>(lanes), x);
  for (int i = 0; i < 4; ++i)
    lanes[i] = static_cast<uint64_t>(__bsrq(lanes[i]));
  return _mm256_load_si256(reinterpret_cast<const __m256i*>(lanes));
}
inline __m256i bsr_256_64_cvtfloat(__m256i x) {
// 32-bit BSR on all eight dwords, then per 64-bit lane: if the high dword is
// zero take the low dword's result, otherwise take the high result + 32.
__m256i bsr32 = bsr_256_32_cvtfloat(x);
__m256i higher = _mm256_add_epi32(_mm256_srli_epi64(bsr32, 32), _mm256_set1_epi64x(0x0000000000000020));
// shuffle 0xF5 duplicates each lane's high-dword zero-compare into both
// dword positions so blendv can select whole 64-bit lanes.
__m256i mask = _mm256_shuffle_epi32(_mm256_cmpeq_epi32(x, _mm256_setzero_si256()), 0xF5);
return _mm256_blendv_epi8(higher, _mm256_and_si256(bsr32, _mm256_set1_epi64x(0xFFFFFFFF)), mask);
}
inline __m256i bsr_256_64_cvtfloat2(__m256i x) {
// Split each 64-bit lane into three 22-bit chunks, each exactly
// representable in a float's 24-bit mantissa so the conversion never rounds.
// Each pass uses a bias (127, 127-22=105, 127-44=83) that directly yields
// the absolute bit index; zero chunks come out negative, so a signed max
// picks the chunk containing the highest set bit.
__m256i bsr0 = bsr_256_32_cvtfloat_impl(_mm256_and_si256(x, _mm256_set1_epi64x(0x3FFFFF)), 127);
__m256i bsr1 = bsr_256_32_cvtfloat_impl(_mm256_and_si256(_mm256_srli_epi64(x, 22), _mm256_set1_epi64x(0x3FFFFF)), 105);
__m256i bsr2 = bsr_256_32_cvtfloat_impl(_mm256_srli_epi64(x, 44), 83);
__m256i result = _mm256_max_epi32(_mm256_max_epi32(bsr0, bsr1), bsr2);
return _mm256_and_si256(result, _mm256_set1_epi64x(0xFFFFFFFF));
}
// Population count of each 64-bit lane: SWAR partial sums down to per-byte
// counts, then _mm256_sad_epu8 against zero sums the eight byte counts of
// each qword in a single instruction.
inline __m256i popcount_256_64(__m256i x) {
x = _mm256_sub_epi64(x, _mm256_and_si256(_mm256_srli_epi64(x, 1), _mm256_set1_epi8(0x55)));
x = _mm256_add_epi64(_mm256_and_si256(x, _mm256_set1_epi8(0x33)),
_mm256_and_si256(_mm256_srli_epi64(x, 2), _mm256_set1_epi8(0x33)));
x = _mm256_and_si256(_mm256_add_epi64(x, _mm256_srli_epi64(x, 4)), _mm256_set1_epi8(0x0F));
return _mm256_sad_epu8(x, _mm256_setzero_si256());
}
inline __m256i bsr_256_64_popcnt(__m256i x) {
x = _mm256_or_si256(x, _mm256_srli_epi64(x, 1));
x = _mm256_or_si256(x, _mm256_srli_epi64(x, 2));
x = _mm256_or_si256(x, _mm256_srli_epi64(x, 4));
x = _mm256_or_si256(x, _mm256_srli_epi64(x, 8));
x = _mm256_or_si256(x, _mm256_srli_epi64(x, 16));
x = _mm256_srli_epi64(_mm256_or_si256(x, _mm256_srli_epi64(x, 32)), 1);
return popcount_256_64(x);
}
inline __m256i bswap_256_64(__m256i x) {
  // Reverse the byte order inside each 64-bit lane with a byte shuffle; the
  // same 16-byte pattern serves both 128-bit halves.
  const __m256i reverse_pattern = _mm256_broadcastsi128_si256(
      _mm_setr_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8));
  return _mm256_shuffle_epi8(x, reverse_pattern);
}
// BSR per 64-bit lane without a full smear: isolate the most significant set
// bit, then popcount(msb - 1) gives its index.
inline __m256i bsr_256_64_rev_popcnt(__m256i x) {
// Smear only 8 bits: the byte containing the MSB is now saturated below it.
x = _mm256_or_si256(x, _mm256_srli_epi64(x, 1));
x = _mm256_or_si256(x, _mm256_srli_epi64(x, 2));
x = _mm256_or_si256(x, _mm256_srli_epi64(x, 4));
// Keep only "leader" bits (bits whose upper neighbour is 0); because the
// MSB's byte is saturated, it contains exactly one leader — the MSB itself.
x = _mm256_andnot_si256(_mm256_srli_epi64(x, 1), x);
x = bswap_256_64(x);
// x & -x isolates the lowest set bit, which after the byte reversal lies in
// the original MSB's byte; reversing back leaves just the MSB.
x = _mm256_and_si256(x, _mm256_sub_epi64(_mm256_setzero_si256(), x));
x = bswap_256_64(x);
// popcount(2^k - 1) == k.
return popcount_256_64(_mm256_sub_epi64(x, _mm256_set1_epi64x(1)));
}
// Benchmark driver for an 8-bit BSR variant: 2^24 outer iterations, each
// applying the function to 4 vectors of incrementing byte values and
// accumulating the results so the work cannot be optimized away. The byte
// pattern is rotated every outer iteration to vary the inputs.
#define DEF_BENCH_BSR8(name) \
void bench_bsr_256_8_##name() { \
std::cout << "bsr 8bit "#name << std::endl; \
boost::timer::cpu_timer timer; \
__m256i add_val = _mm256_set1_epi8(32); \
__m256i res = _mm256_setzero_si256(); \
__m256i valinit = _mm256_setr_epi8(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32); \
for (int i = 0; i < (1 << 24); ++i) { \
__m256i vals = valinit; \
__m256i res_part = _mm256_setzero_si256(); \
for (int j = 0; j < (1 << 2); ++j) { \
__m256i bsr_res = bsr_256_8_##name(vals); \
res_part = _mm256_add_epi8(res_part, bsr_res); \
vals = _mm256_add_epi8(vals, add_val); \
} \
/* widen the 8-bit partial sums to 64-bit so the accumulator cannot overflow */ \
__m256i add_64 = _mm256_sad_epu8(res_part, _mm256_setzero_si256()); \
res = _mm256_add_epi64(res, add_64); \
valinit = _mm256_alignr_epi8(valinit, valinit, 1); \
} \
uint64_t sum = _mm256_extract_epi64(res, 0) + _mm256_extract_epi64(res, 1) + _mm256_extract_epi64(res, 2) + _mm256_extract_epi64(res, 3); \
std::cout << "result: " << sum << ", elapsed: " << timer.format(3, "%ws") << std::endl; \
}
// Benchmark driver for a 16-bit BSR variant: 2^17 outer x 2^10 inner
// iterations over incrementing word values, accumulating results so the work
// cannot be optimized away.
#define DEF_BENCH_BSR16(name) \
void bench_bsr_256_16_##name() { \
std::cout << "bsr 16bit "#name << std::endl; \
boost::timer::cpu_timer timer; \
__m256i add_val = _mm256_set1_epi16(16); \
__m256i res = _mm256_setzero_si256(); \
for (int i = 0; i < (1 << 17); ++i) { \
__m256i vals = _mm256_setr_epi16(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16); \
__m256i res_part = _mm256_setzero_si256(); \
for (int j = 0; j < (1 << 10); ++j) { \
__m256i bsr_res = bsr_256_16_##name(vals); \
res_part = _mm256_add_epi16(res_part, bsr_res); \
vals = _mm256_add_epi16(vals, add_val); \
} \
/* widen the 16-bit partial sums to 64-bit so the accumulator cannot overflow */ \
__m256i add_32 = _mm256_madd_epi16(res_part, _mm256_set1_epi32(0x00010001)); \
__m256i add_64 = _mm256_add_epi32(_mm256_and_si256(add_32, _mm256_set1_epi64x(0x00000000FFFFFFFF)), _mm256_srli_epi64(add_32, 32)); \
res = _mm256_add_epi64(res, add_64); \
} \
uint64_t sum = _mm256_extract_epi64(res, 0) + _mm256_extract_epi64(res, 1) + _mm256_extract_epi64(res, 2) + _mm256_extract_epi64(res, 3); \
std::cout << "result: " << sum << ", elapsed: " << timer.format(3, "%ws") << std::endl; \
}
// Benchmark driver for a 32-bit BSR variant: 2^14 outer x 2^14 inner
// iterations over incrementing dword values, accumulating results so the
// work cannot be optimized away.
#define DEF_BENCH_BSR32(name) \
void bench_bsr_256_32_##name() { \
std::cout << "bsr 32bit "#name << std::endl; \
boost::timer::cpu_timer timer; \
__m256i vals = _mm256_setr_epi32(1, 2, 3, 4, 5, 6, 7, 8); \
__m256i add_val = _mm256_set1_epi32(8); \
__m256i res = _mm256_setzero_si256(); \
for (int i = 0; i < (1 << 14); ++i) { \
__m256i res_part = _mm256_setzero_si256(); \
for (int j = 0; j < (1 << 14); ++j) { \
__m256i bsr_res = bsr_256_32_##name(vals); \
res_part = _mm256_add_epi32(res_part, bsr_res); \
vals = _mm256_add_epi32(vals, add_val); \
} \
/* widen the 32-bit partial sums to 64-bit so the accumulator cannot overflow */ \
__m256i add_64 = _mm256_add_epi32(_mm256_and_si256(res_part, _mm256_set1_epi64x(0x00000000FFFFFFFF)), _mm256_srli_epi64(res_part, 32)); \
res = _mm256_add_epi64(res, add_64); \
} \
uint64_t sum = _mm256_extract_epi64(res, 0) + _mm256_extract_epi64(res, 1) + _mm256_extract_epi64(res, 2) + _mm256_extract_epi64(res, 3); \
std::cout << "result: " << sum << ", elapsed: " << timer.format(3, "%ws") << std::endl; \
}
// Benchmark driver for a 64-bit BSR variant: 2^29 iterations over
// incrementing qword values; 64-bit lanes cannot overflow, so the results
// are accumulated directly.
#define DEF_BENCH_BSR64(name) \
void bench_bsr_256_64_##name() { \
std::cout << "bsr 64bit "#name << std::endl; \
boost::timer::cpu_timer timer; \
__m256i vals = _mm256_setr_epi64x(1, 2, 3, 4); \
__m256i add_val = _mm256_set1_epi64x(4); \
__m256i res = _mm256_setzero_si256(); \
for (int i = 0; i < (1 << 29); ++i) { \
__m256i bsr_res = bsr_256_64_##name(vals); \
res = _mm256_add_epi64(res, bsr_res); \
vals = _mm256_add_epi64(vals, add_val); \
} \
uint64_t sum = _mm256_extract_epi64(res, 0) + _mm256_extract_epi64(res, 1) + _mm256_extract_epi64(res, 2) + _mm256_extract_epi64(res, 3); \
std::cout << "result: " << sum << ", elapsed: " << timer.format(3, "%ws") << std::endl; \
}
// Instantiate one benchmark function per lane width and implementation.
DEF_BENCH_BSR8(naive);
DEF_BENCH_BSR8(cvtfloat);
DEF_BENCH_BSR8(popcnt);
DEF_BENCH_BSR8(table);
DEF_BENCH_BSR16(naive);
DEF_BENCH_BSR16(cvtfloat);
DEF_BENCH_BSR16(popcnt);
DEF_BENCH_BSR16(table);
DEF_BENCH_BSR32(naive);
DEF_BENCH_BSR32(cvtfloat);
DEF_BENCH_BSR32(popcnt);
DEF_BENCH_BSR32(table_gather);
DEF_BENCH_BSR64(naive);
DEF_BENCH_BSR64(cvtfloat);
DEF_BENCH_BSR64(cvtfloat2);
DEF_BENCH_BSR64(popcnt);
DEF_BENCH_BSR64(rev_popcnt);
// Run every benchmark, grouped by lane width.
int main() {
// The gather-based 32-bit variant needs its lookup table built first.
init_table_32();
std::cout << "Benchmark: parallel 8bit BSR" << std::endl;
bench_bsr_256_8_naive();
bench_bsr_256_8_cvtfloat();
bench_bsr_256_8_popcnt();
bench_bsr_256_8_table();
std::cout << std::endl;
std::cout << "Benchmark: parallel 16bit BSR" << std::endl;
bench_bsr_256_16_naive();
bench_bsr_256_16_cvtfloat();
bench_bsr_256_16_popcnt();
bench_bsr_256_16_table();
std::cout << std::endl;
std::cout << "Benchmark: parallel 32bit BSR" << std::endl;
bench_bsr_256_32_naive();
bench_bsr_256_32_cvtfloat();
bench_bsr_256_32_popcnt();
bench_bsr_256_32_table_gather();
std::cout << std::endl;
std::cout << "Benchmark: parallel 64bit BSR" << std::endl;
bench_bsr_256_64_naive();
bench_bsr_256_64_cvtfloat();
bench_bsr_256_64_cvtfloat2();
bench_bsr_256_64_popcnt();
bench_bsr_256_64_rev_popcnt();
// Release the 8 GiB gather table allocated by init_table_32().
free(table_32);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment