#include <stdint.h>
#include <immintrin.h>
// credit: YumiYumiYumi
// (fixed by aqrit)
__m128i _mm_tzcnt_epi32(__m128i v) {
    __m128i mask = _mm_set1_epi32(0xffffff81);
    v = _mm_and_si128(v, _mm_sign_epi32(v, mask)); // isolate lowest set bit (mask is negative, so sign == negate)
    v = _mm_castps_si128(_mm_cvtepi32_ps(v)); // convert int to float
    v = _mm_srli_epi32(v, 23); // shift down the exponent (and sign bit)
    v = _mm_add_epi32(v, mask); // undo bias (0xffffff81 == -127)
    v = _mm_min_epu8(v, _mm_set1_epi32(32)); // clamp: zero input becomes 32, stray high bytes are cleared
    return v;
}
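
// A scalar model of the float-exponent trick above (my addition, not from the
// gist): an isolated low bit is a power of two, which converts to float with
// no rounding, so the biased exponent field is exactly 127 + bit_index.
static inline uint32_t scalar_tzcnt32_model(uint32_t x) {
    if (x == 0) return 32; // the vector code gets this from the epu8 min instead
    uint32_t lsb = x & (0u - x); // isolate lowest set bit
    union { float f; uint32_t u; } cvt;
    cvt.f = (float)(int32_t)lsb; // signed convert, as cvtepi32_ps does (wraps for 0x80000000)
    return ((cvt.u >> 23) - 127u) & 0xff; // undo bias; the mask drops the sign bit for 0x80000000
}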
// returns 0 for input of 0x00000000
// returns 0 for input of 0x00000001
// returns 31 for input of 0x80000000
__m128i _mm_bsf_epi32(__m128i v) {
    // isolate lowest set bit
    v = _mm_and_si128(v, _mm_sub_epi32(_mm_setzero_si128(), v));
    v = _mm_castps_si128(_mm_cvtepi32_ps(v)); // convert int to float
    v = _mm_srli_epi32(v, 23); // shift down the exponent and sign bit
    v = _mm_subs_epu8(v, _mm_set1_epi32(0x017F)); // undo bias and sign bit
    return v;
}
// https://stackoverflow.com/a/58827596
__m256i avx2_lzcnt_epi32(__m256i v) {
    // prevent value from being rounded up to the next power of two
    v = _mm256_andnot_si256(_mm256_srli_epi32(v, 8), v); // keep 8 MSB
    v = _mm256_castps_si256(_mm256_cvtepi32_ps(v)); // convert int to float
    v = _mm256_srli_epi32(v, 23); // shift down the exponent
    v = _mm256_subs_epu16(_mm256_set1_epi32(158), v); // undo bias
    v = _mm256_min_epi16(v, _mm256_set1_epi32(32)); // clamp at 32
    return v;
}
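
// The same recipe in scalar form (my addition, for reference). The andnot
// clears any bit sitting 8 below a set bit, so rounding in the int-to-float
// conversion can never carry past the next power of two and the biased
// exponent stays exactly 127 + (31 - lzcnt).
static inline uint32_t scalar_lzcnt32_model(uint32_t x) {
    x &= ~(x >> 8); // keep 8 MSB, prevents round-up past the next power of two
    union { float f; uint32_t u; } cvt;
    cvt.f = (float)(int32_t)x; // signed, as cvtepi32_ps: bit 31 lands in the float's sign
    uint32_t e = cvt.u >> 23; // 127 + (31 - lzcnt); >= 0x100 whenever bit 31 was set
    uint32_t r = (e > 158) ? 0 : 158 - e; // saturating subtract, as subs_epu16 does
    return (r > 32) ? 32 : r; // clamp the zero-input case (e == 0) to 32
}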
// Credit: YumiYumiYumi
// https://old.reddit.com/r/simd/comments/b3k1oa/looking_for_sseavx_bitscan_discussions/
__m256i avx2_lzcnt2_epi32(__m256i v) {
    const __m256i lut_lo = _mm256_set_epi8(
        4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 32,
        4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 32
    );
    const __m256i lut_hi = _mm256_set_epi8(
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 32,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 32
    );
    const __m256i nibble_mask = _mm256_set1_epi8(0x0F);
    const __m256i byte_offset = _mm256_set1_epi32(0x00081018);
    __m256i t;

    /* find lzcnt for each byte */
    t = _mm256_and_si256(nibble_mask, v);
    v = _mm256_and_si256(_mm256_srli_epi16(v, 4), nibble_mask);
    t = _mm256_shuffle_epi8(lut_lo, t);
    v = _mm256_shuffle_epi8(lut_hi, v);
    v = _mm256_min_epu8(v, t);

    /* find lzcnt for each dword */
    v = _mm256_or_si256(v, byte_offset);
    v = _mm256_min_epu8(v, _mm256_srli_epi16(v, 8));
    v = _mm256_min_epu8(v, _mm256_srli_epi32(v, 16));
    return v;
}
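
// What the LUTs and offset constant above are doing, restated in scalar C (my
// sketch): per-byte lzcnt (32 for a zero byte) is the min of two nibble
// lookups; OR-ing in 0/8/16/24 is carry-free because nonzero byte counts are
// 0..7, and a min across the four bytes then yields the dword lzcnt.
static inline uint32_t scalar_lzcnt32_lut_model(uint32_t x) {
    static const uint8_t lut_lo[16] = { 32, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4 };
    static const uint8_t lut_hi[16] = { 32, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 };
    uint32_t r = 32;
    for (int i = 0; i < 4; i++) { // byte 3 (MSB) gets offset 0, byte 0 gets offset 24
        uint8_t b = (uint8_t)(x >> (8 * i));
        uint8_t lo = lut_lo[b & 0x0F], hi = lut_hi[b >> 4];
        uint32_t cnt = (uint32_t)(lo < hi ? lo : hi) | (uint32_t)(8 * (3 - i));
        if (cnt < r) r = cnt;
    }
    return r;
}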
// Credit: YumiYumiYumi
// https://old.reddit.com/r/simd/comments/b3k1oa/looking_for_sseavx_bitscan_discussions/
__m256i avx2_tzcnt_epi32(__m256i v) {
    const __m256i lut_lo = _mm256_set_epi8(
        0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 32,
        0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 32
    );
    const __m256i lut_hi = _mm256_set_epi8(
        4, 5, 4, 6, 4, 5, 4, 7, 4, 5, 4, 6, 4, 5, 4, 32,
        4, 5, 4, 6, 4, 5, 4, 7, 4, 5, 4, 6, 4, 5, 4, 32
    );
    const __m256i nibble_mask = _mm256_set1_epi8(0x0F);
    const __m256i byte_offset = _mm256_set1_epi32(0x18100800);
    __m256i t;

    /* find tzcnt for each byte */
    t = _mm256_and_si256(nibble_mask, v);
    v = _mm256_and_si256(_mm256_srli_epi16(v, 4), nibble_mask);
    t = _mm256_shuffle_epi8(lut_lo, t);
    v = _mm256_shuffle_epi8(lut_hi, v);
    v = _mm256_min_epu8(v, t);

    /* find tzcnt for each dword */
    v = _mm256_or_si256(v, byte_offset);
    v = _mm256_min_epu8(v, _mm256_srli_epi16(v, 8));
    v = _mm256_min_epu8(v, _mm256_srli_epi32(v, 16));
    return v;
}
// 16 - lzcnt_u16(subwords)
// (i.e. the 1-based index of the highest set bit in each 16-bit subword; 0 if zero)
__m256i avx2_ms1b_epi16(__m256i v) {
    const __m256i lut_lo = _mm256_set_epi8(
        12, 12, 12, 12, 12, 12, 12, 12, 11, 11, 11, 11, 10, 10, 9, 0,
        12, 12, 12, 12, 12, 12, 12, 12, 11, 11, 11, 11, 10, 10, 9, 0
    );
    const __m256i lut_hi = _mm256_set_epi8(
        16, 16, 16, 16, 16, 16, 16, 16, 15, 15, 15, 15, 14, 14, 13, 0,
        16, 16, 16, 16, 16, 16, 16, 16, 15, 15, 15, 15, 14, 14, 13, 0
    );
    const __m256i nibble_mask = _mm256_set1_epi8(0x0F);
    const __m256i adj = _mm256_set1_epi16(0x1F08);
    __m256i t;

    /* ms1b of each byte, biased by +8 (computed as if each byte were the high byte) */
    t = _mm256_and_si256(nibble_mask, v);
    v = _mm256_and_si256(_mm256_srli_epi16(v, 4), nibble_mask);
    t = _mm256_shuffle_epi8(lut_lo, t);
    v = _mm256_shuffle_epi8(lut_hi, v);
    v = _mm256_max_epu8(v, t);

    /* combine bytes: high byte's count, or the low byte's count minus the bias */
    t = _mm256_srli_epi16(v, 8); /* high byte's count, moved down */
    v = _mm256_sub_epi8(v, adj); /* low byte -8; high byte pushed negative so it zeroes out */
    v = _mm256_max_epi8(v, t); /* signed max picks the answer, clears the high byte */
    return v;
}
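
// For reference, the scalar definition being emulated above (my addition):
// the 1-based index of the highest set bit in each 16-bit subword, 0 if zero,
// i.e. ms1b(x) - 1 == floor(log2(x)) for nonzero x.
static inline int scalar_ms1b16_model(uint16_t x) {
    int r = 0;
    while (x) { r++; x >>= 1; }
    return r;
}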
// 32 - _lzcnt_u32(subwords)
// returned values are packed into 8-bit subwords
//
// if msb is set then return 32
// if input is zero then return 0
//
// Author: aqrit
#define AVX2_MS1B_PACKED(in0, in1, in2, in3, out) do { \
    __m256i r4, r5, r6, r7, r8, r9; \
    const __m256i mask_7F = _mm256_set1_epi8(0x7F); \
    const __m256i lut_lo = _mm256_set_epi8( \
        4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 1, 0, \
        4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 1, 0 \
    ); \
    const __m256i lut_hi = _mm256_set_epi8( \
        8, 8, 8, 8, 8, 8, 8, 8, 7, 7, 7, 7, 6, 6, 5, 0, \
        8, 8, 8, 8, 8, 8, 8, 8, 7, 7, 7, 7, 6, 6, 5, 0 \
    ); \
    \
    /* detect zero bytes */ \
    r4 = _mm256_cmpeq_epi8(_mm256_setzero_si256(), in0); \
    r5 = _mm256_cmpeq_epi8(_mm256_setzero_si256(), in1); \
    r6 = _mm256_cmpeq_epi8(_mm256_setzero_si256(), in2); \
    r7 = _mm256_cmpeq_epi8(_mm256_setzero_si256(), in3); \
    \
    /* reduce 32-bit subwords to 8 bits */ \
    r4 = _mm256_packs_epi16(r4, r5); \
    r6 = _mm256_packs_epi16(r6, r7); \
    r4 = _mm256_srai_epi16(r4, 7); /* discard low byte */ \
    r6 = _mm256_srai_epi16(r6, 7); \
    r4 = _mm256_packs_epi16(r4, r6); \
    \
    /* calc index of which byte holds the highest set bit */ \
    /* 0x00 -> 0 -> 3, 0x01 -> 0 -> 3, 0x7F -> 0 -> 3, */ \
    /* 0x80 -> 1 -> 2, 0xFE -> 7F -> 1, 0xFF -> 80 -> 0 */ \
    r4 = _mm256_subs_epu8(r4, mask_7F); \
    r5 = _mm256_shuffle_epi8(_mm256_set1_epi32(0x01000203), r4); \
    \
    /* gather bytes */ \
    r4 = _mm256_or_si256(r5, _mm256_set1_epi32(0x0C080400)); \
    r6 = _mm256_shuffle_epi8(in0, r4); \
    r7 = _mm256_shuffle_epi8(in1, r4); \
    r6 = _mm256_blend_epi32(r6, r7, 0x22); \
    r8 = _mm256_shuffle_epi8(in2, r4); \
    r6 = _mm256_blend_epi32(r6, r8, 0x44); \
    r9 = _mm256_shuffle_epi8(in3, r4); \
    r6 = _mm256_blend_epi32(r6, r9, 0x88); \
    \
    /* find the highest set bit within each byte */ \
    r7 = _mm256_and_si256(r6, mask_7F); \
    r6 = _mm256_srli_epi32(r6, 4); \
    r6 = _mm256_and_si256(r6, mask_7F); \
    r8 = _mm256_shuffle_epi8(lut_lo, r7); \
    r9 = _mm256_shuffle_epi8(lut_hi, r6); \
    r8 = _mm256_max_epu8(r8, r9); \
    \
    /* out = (byte_index * 8) + bit_index */ \
    r5 = _mm256_slli_epi32(r5, 3); \
    out = _mm256_add_epi8(r5, r8); \
    \
    /* optional: return results in order */ \
    const __m256i interleave = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); \
    out = _mm256_permutevar8x32_epi32(out, interleave); \
} while(0)
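
// Example invocation (my sketch; the wrapper name is mine): the macro is a
// statement, so wrap it to use it as an expression. With the final permute,
// byte i of the result corresponds to dword i of the four inputs in order.
static inline __m256i avx2_ms1b_packed_epu32(__m256i a, __m256i b, __m256i c, __m256i d) {
    __m256i out;
    AVX2_MS1B_PACKED(a, b, c, d, out);
    return out; // 32 packed bytes, each holding 32 - _lzcnt_u32 of one input dword
}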
// Emulate 32-bit BitScanForward using vpmulld
//
// a subword of 0x00000000 returns as 0
// a subword of 0x00000001 returns as 0
// a subword of 0x80000000 returns as 31
//
// based on: https://graphics.stanford.edu/%7Eseander/bithacks.html#ZerosOnRightMultLookup
// *somewhere* I've seen the translation to SSE before...
#define AVX2_BSF32(v) do { \
    const __m256i lut_lo = _mm256_setr_epi8( \
        0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8, \
        0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8 \
    ); \
    const __m256i lut_hi = _mm256_setr_epi8( \
        31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9, \
        31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9 \
    ); \
    const __m256i magic = _mm256_set1_epi32(0x077CB531); \
    const __m256i flip = _mm256_set1_epi32(0xFFFFFF80); \
    __m256i t; \
    \
    /* isolate lowest set bit */ \
    v = _mm256_and_si256(v, _mm256_sign_epi32(v, flip)); \
    \
    /* min perfect hash */ \
    v = _mm256_mullo_epi32(v, magic); \
    v = _mm256_srai_epi32(v, 27); \
    \
    /* table lookup: hash values 16..31 come out of the srai negative, which */ \
    /* zeroes the lut_lo shuffle; adding 0x80 flips the sign bit for lut_hi */ \
    t = _mm256_shuffle_epi8(lut_lo, v); \
    v = _mm256_add_epi8(v, flip); \
    v = _mm256_shuffle_epi8(lut_hi, v); \
    v = _mm256_or_si256(v, t); \
} while(0)
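
// The macro above rewrites v in place, so a thin wrapper (my addition) gives
// it the same shape as the SSE functions at the top of this gist:
static inline __m256i avx2_bsf_epi32(__m256i v) {
    AVX2_BSF32(v);
    return v; // 0 for subwords 0x00000000 and 0x00000001, 31 for 0x80000000
}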
// Emulate 32-bit BitScanForward using xor-folding
// returned values are packed into 8-bit subwords
//
// I'm not sure if I've seen this exact thing somewhere before or not...
#define AVX2_BSF32_PACKED(in0, in1, in2, in3, out) do { \
    __m256i r4, r5, r6, r7; \
    const __m256i mask_FF = _mm256_set1_epi8(0xFF); \
    const __m256i mask_80 = _mm256_set1_epi8(0x80); \
    const __m256i lut = _mm256_set_epi8( \
        4, 5, 0, 6, 0, 0, 15, 7, 3, 0, 0, 0, 2, 0, 1, 0, \
        4, 5, 0, 6, 0, 0, 15, 7, 3, 0, 0, 0, 2, 0, 1, 0 \
    ); \
    const __m256i shuf = _mm256_set_epi8( \
        15, 15, 13, 13, 11, 11, 9, 9, 7, 7, 5, 5, 3, 3, 1, 1, \
        15, 15, 13, 13, 11, 11, 9, 9, 7, 7, 5, 5, 3, 3, 1, 1 \
    ); \
    \
    /* get mask of trailing zeros (emulate TZMSK) */ \
    r4 = _mm256_andnot_si256(in0, _mm256_add_epi32(in0, mask_FF)); \
    r5 = _mm256_andnot_si256(in1, _mm256_add_epi32(in1, mask_FF)); \
    r6 = _mm256_andnot_si256(in2, _mm256_add_epi32(in2, mask_FF)); \
    r7 = _mm256_andnot_si256(in3, _mm256_add_epi32(in3, mask_FF)); \
    \
    /* XOR-fold 32-bit words into 16-bit words */ \
    r4 = _mm256_hsub_epi16(r4, r5); \
    r6 = _mm256_hsub_epi16(r6, r7); \
    \
    /* XOR-fold 16-bit words into 8-bit words */ \
    r5 = _mm256_xor_si256(r4, _mm256_slli_epi16(r4, 8)); \
    r5 = _mm256_srli_epi16(r5, 8); \
    r7 = _mm256_xor_si256(r6, _mm256_shuffle_epi8(r6, shuf)); \
    r5 = _mm256_packus_epi16(r5, r7); \
    \
    /* get bits 3 & 4 of the result */ \
    r4 = _mm256_and_si256(_mm256_packs_epi16(r4, r6), mask_80); \
    r6 = _mm256_srli_epi16(_mm256_and_si256(r5, mask_80), 1); \
    r4 = _mm256_srli_epi16(_mm256_or_si256(r4, r6), 3); \
    \
    /* XOR-fold 8-bit words into 4 bits (held in an 8-bit subword) */ \
    r5 = _mm256_xor_si256(_mm256_srli_epi16(r5, 4), r5); \
    \
    /* use a lookup table to skip the remaining XOR-folds */ \
    r5 = _mm256_shuffle_epi8(lut, _mm256_andnot_si256(mask_80, r5)); \
    \
    /* merge the high and low parts of the result */ \
    out = _mm256_or_si256(r4, r5); \
    \
    /* optional: return results in order */ \
    const __m256i interleave = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); \
    out = _mm256_permutevar8x32_epi32(out, interleave); \
} while(0)
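
// Same calling convention as the other packed macro (wrapper is my sketch):
// four vectors of dwords in, one vector of packed 8-bit results out.
static inline __m256i avx2_bsf32_packed(__m256i a, __m256i b, __m256i c, __m256i d) {
    __m256i out;
    AVX2_BSF32_PACKED(a, b, c, d, out);
    return out;
}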
// **hasty** conversion of the clz method to ctz... probably not optimal
//
// _tzcnt_u32(subwords)
//
// if lsb is set then return 0
// if input is zero then return 32
//
// Author: aqrit
#define AVX2_TZCNT(in0, in1, in2, in3, out) do { \
    __m256i r4, r5, r6, r7, r8, r9; \
    const __m256i mask_7F = _mm256_set1_epi8(0x7F); \
    const __m256i lut_lo = _mm256_set_epi8( \
        0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 8, \
        0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 8 \
    ); \
    const __m256i lut_hi = _mm256_set_epi8( \
        4, 5, 4, 6, 4, 5, 4, 7, 4, 5, 4, 6, 4, 5, 4, 8, \
        4, 5, 4, 6, 4, 5, 4, 7, 4, 5, 4, 6, 4, 5, 4, 8 \
    ); \
    \
    /* subtract 1 ... */ \
    const __m256i mask_FF = _mm256_set1_epi8(0xFF); \
    r4 = _mm256_add_epi32(in0, mask_FF); \
    r5 = _mm256_add_epi32(in1, mask_FF); \
    r6 = _mm256_add_epi32(in2, mask_FF); \
    r7 = _mm256_add_epi32(in3, mask_FF); \
    \
    /* detect unchanged bytes */ \
    r4 = _mm256_cmpeq_epi8(r4, in0); \
    r5 = _mm256_cmpeq_epi8(r5, in1); \
    r6 = _mm256_cmpeq_epi8(r6, in2); \
    r7 = _mm256_cmpeq_epi8(r7, in3); \
    \
    /* reduce 32-bit subwords to 8 bits */ \
    r4 = _mm256_packs_epi16(r4, r5); \
    r6 = _mm256_packs_epi16(r6, r7); \
    r4 = _mm256_srai_epi16(r4, 7); /* discard low byte */ \
    r6 = _mm256_srai_epi16(r6, 7); \
    r4 = _mm256_packs_epi16(r4, r6); \
    \
    /* calc index of which byte holds the lowest set bit */ \
    /* 0x00 -> 0 -> 3, 0x01 -> 0 -> 3, 0x7F -> 0 -> 3, */ \
    /* 0x80 -> 1 -> 2, 0xFE -> 7F -> 1, 0xFF -> 80 -> 0 */ \
    r4 = _mm256_subs_epu8(r4, mask_7F); \
    r5 = _mm256_shuffle_epi8(_mm256_set1_epi32(0x01000203), r4); \
    \
    /* gather bytes */ \
    r4 = _mm256_or_si256(r5, _mm256_set1_epi32(0x0C080400)); \
    r6 = _mm256_shuffle_epi8(in0, r4); \
    r7 = _mm256_shuffle_epi8(in1, r4); \
    r6 = _mm256_blend_epi32(r6, r7, 0x22); \
    r8 = _mm256_shuffle_epi8(in2, r4); \
    r6 = _mm256_blend_epi32(r6, r8, 0x44); \
    r9 = _mm256_shuffle_epi8(in3, r4); \
    r6 = _mm256_blend_epi32(r6, r9, 0x88); \
    \
    /* find the lowest set bit within each byte */ \
    r7 = _mm256_and_si256(r6, mask_7F); \
    r6 = _mm256_srli_epi32(r6, 4); \
    r6 = _mm256_and_si256(r6, mask_7F); \
    r8 = _mm256_shuffle_epi8(lut_lo, r7); \
    r9 = _mm256_shuffle_epi8(lut_hi, r6); \
    r8 = _mm256_min_epu8(r8, r9); \
    \
    /* out = (byte_index * 8) + bit_index */ \
    r5 = _mm256_slli_epi32(r5, 3); \
    out = _mm256_add_epi8(r5, r8); \
    \
    /* optional: return results in order */ \
    const __m256i interleave = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); \
    out = _mm256_permutevar8x32_epi32(out, interleave); \
} while(0)
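
// A minimal cross-check sketch for the macro above (my addition, not part of
// the gist; assumes BMI1 for the scalar _tzcnt_u32 reference and relies on
// the "results in order" permute, so byte i of out should match dword i of
// the concatenated inputs):
#include <stdio.h>
static int test_avx2_tzcnt(void) {
    uint32_t in[32];
    uint8_t res[32];
    for (int i = 0; i < 32; i++) in[i] = (i < 31) ? (1u << i) : 0; // powers of two, plus one zero
    __m256i v0 = _mm256_loadu_si256((const __m256i*)&in[0]);
    __m256i v1 = _mm256_loadu_si256((const __m256i*)&in[8]);
    __m256i v2 = _mm256_loadu_si256((const __m256i*)&in[16]);
    __m256i v3 = _mm256_loadu_si256((const __m256i*)&in[24]);
    __m256i out;
    AVX2_TZCNT(v0, v1, v2, v3, out);
    _mm256_storeu_si256((__m256i*)res, out);
    int fails = 0;
    for (int i = 0; i < 32; i++) {
        if (res[i] != _tzcnt_u32(in[i])) { printf("mismatch at %d\n", i); fails++; }
    }
    return fails; // 0 on success
}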