@rygorous
Created February 3, 2023 08:37
Multigetbits, the second
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <smmintrin.h>
#ifdef __RADAVX__
#include <immintrin.h>
#endif
#if !defined(__clang__) && defined(_MSC_VER)
#include <intrin.h>
static inline uint16_t bswap16(uint16_t x) { return _byteswap_ushort(x); }
static inline uint32_t bswap32(uint32_t x) { return _byteswap_ulong(x); }
static inline uint64_t bswap64(uint64_t x) { return _byteswap_uint64(x); }
static inline uint64_t rotl64(uint64_t x, uint32_t k) { return _rotl64(x, k); }
#else
static inline uint16_t bswap16(uint16_t x) { return __builtin_bswap16(x); }
static inline uint32_t bswap32(uint32_t x) { return __builtin_bswap32(x); }
static inline uint64_t bswap64(uint64_t x) { return __builtin_bswap64(x); }
static inline uint64_t rrRotlVar64(uint64_t x, uint32_t k) { __asm__("rolq %%cl, %0" : "+r"(x) : "c"(k)); return x; }
#define rotl64(u64,num) (__builtin_constant_p((num)) ? ( ( (u64) << (num) ) | ( (u64) >> (64 - (num))) ) : rrRotlVar64((u64),(num)))
#endif
static inline __m128i prefix_sum_u8(__m128i x)
{
#if 1
// alternative form that uses shifts, not the general shuffle network on port 5 (which is a bottleneck
// for us)
x = _mm_add_epi8(x, _mm_slli_epi64(x, 8));
x = _mm_add_epi8(x, _mm_slli_epi64(x, 16));
x = _mm_add_epi8(x, _mm_slli_epi64(x, 32));
x = _mm_add_epi8(x, _mm_shuffle_epi8(x, _mm_setr_epi8(-1,-1,-1,-1,-1,-1,-1,-1, 7,7,7,7,7,7,7,7)));
#else
// x[0], x[1], x[2], x[3], ...
x = _mm_add_epi8(x, _mm_slli_si128(x, 1));
// x[0], sum(x[0:2]), sum(x[1:3]), sum(x[2:4]), ...
x = _mm_add_epi8(x, _mm_slli_si128(x, 2));
// x[0], sum(x[0:2]), sum(x[0:3]), sum(x[0:4]), sum(x[1:5]), sum(x[2:6]), ...
x = _mm_add_epi8(x, _mm_slli_si128(x, 4));
// longest group now sums over 8 elems
x = _mm_add_epi8(x, _mm_slli_si128(x, 8));
#endif
// and now we're done
return x;
}
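// Minimal scalar sketch (illustrative only, not used by the kernels below): the
// inclusive prefix sum over bytes that prefix_sum_u8 computes, one lane at a time.
static inline void prefix_sum_u8_scalar_ref(uint8_t out[16], const uint8_t in[16])
{
uint8_t sum = 0;
for (int i = 0; i < 16; i++)
{
sum = (uint8_t) (sum + in[i]); // wraps mod 256, same as PADDB
out[i] = sum;
}
}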
static inline __m128i prefix_sum_u16(__m128i x)
{
#if 1
x = _mm_add_epi16(x, _mm_slli_epi64(x, 16));
x = _mm_add_epi16(x, _mm_slli_epi64(x, 32));
x = _mm_add_epi16(x, _mm_shuffle_epi8(x, _mm_setr_epi8(-1,-1,-1,-1,-1,-1,-1,-1, 6,7,6,7,6,7,6,7)));
#else
x = _mm_add_epi16(x, _mm_slli_si128(x, 2));
x = _mm_add_epi16(x, _mm_slli_si128(x, 4));
x = _mm_add_epi16(x, _mm_slli_si128(x, 8));
#endif
return x;
}
static inline __m128i prefix_sum_u32(__m128i x)
{
#if 1
x = _mm_add_epi32(x, _mm_slli_epi64(x, 32));
x = _mm_add_epi32(x, _mm_shuffle_epi8(x, _mm_setr_epi8(-1,-1,-1,-1,-1,-1,-1,-1, 4,5,6,7,4,5,6,7)));
#else
// x[0], x[1], x[2], x[3]
x = _mm_add_epi32(x, _mm_slli_si128(x, 4));
// x[0], sum(x[0:2]), sum(x[1:3]), sum(x[2:4])
x = _mm_add_epi32(x, _mm_slli_si128(x, 8));
// x[0], sum(x[0:2]), sum(x[0:3]), sum(x[0:4])
#endif
return x;
}
// individual field_widths in [0,8]
// MSB-first bit packing convention, SSSE3+
//
// compiled with /arch:AVX (to get rid of reg-reg moves Jaguar code wouldn't have):
// ballpark is ~42 ops for the 16 getbits, so ~2.63 ops/getbits.
//
// so expect maybe 1 cycle/lane on the big cores, 1.7 cycles/lane on Jaguar. (!)
static inline __m128i multigetbits8(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
uint32_t bit_basepos = *pbit_basepos;
// prefix-sum the field widths and advance bit position pointer
__m128i summed_widths = prefix_sum_u8(field_widths);
uint32_t total_width = (uint32_t)_mm_extract_epi16(summed_widths, 7) >> 8; // no PEXTRB before SSE4.1, and this is the only place where SSE4.1+ helps
*pbit_basepos = bit_basepos + total_width;
// NOTE once this is done (which is something like 1/4 into the whole thing by op count),
// OoO cores can start working on next iter
// -> this will get good core utilization
// determine starting bit position for every lane
// and split into bit-within-byte and byte indices
__m128i basepos_u8 = _mm_shuffle_epi8(_mm_cvtsi32_si128(bit_basepos & 7), _mm_setzero_si128());
__m128i first_bit_index = _mm_add_epi8(basepos_u8, _mm_slli_si128(summed_widths, 1));
__m128i first_byte_index = _mm_and_si128(_mm_srli_epi16(first_bit_index, 3), _mm_set1_epi8(0x1f)); // no "shift bytes", sigh.
// source bytes
__m128i src_byte0 = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3) + 0));
__m128i src_byte1 = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3) + 1));
// first/second bytes for every lane
__m128i byte0 = _mm_shuffle_epi8(src_byte0, first_byte_index);
__m128i byte1 = _mm_shuffle_epi8(src_byte1, first_byte_index);
// assemble words
__m128i words0 = _mm_unpacklo_epi8(byte1, byte0);
__m128i words1 = _mm_unpackhi_epi8(byte1, byte0);
// now, need to shift
// ((byte0<<8) | byte1) >> (16 - width - (first_bit_index & 7))
// we don't have per-lane variable shifts in SSSE3, but we do have PMULHUW,
// and we can do the multiplier table lookup via PSHUFB.
__m128i shift_amt = _mm_add_epi8(_mm_and_si128(first_bit_index, _mm_set1_epi8(7)), field_widths);
__m128i shiftm0_lut = _mm_setr_epi8(0x01,0x02,0x04,0x08, 0x10,0x20,0x40,0x80, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00);
__m128i shiftm1_lut = _mm_setr_epi8(0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x01,0x02,0x04,0x08, 0x10,0x20,0x40,0x80);
__m128i shiftm0 = _mm_shuffle_epi8(shiftm0_lut, shift_amt);
__m128i shiftm1 = _mm_shuffle_epi8(shiftm1_lut, shift_amt);
__m128i shift_mul0 = _mm_unpacklo_epi8(shiftm0, shiftm1);
__m128i shift_mul1 = _mm_unpackhi_epi8(shiftm0, shiftm1);
__m128i shifted0 = _mm_mulhi_epu16(words0, shift_mul0);
__m128i shifted1 = _mm_mulhi_epu16(words1, shift_mul1);
// pack the results back into bytes
__m128i byte_mask = _mm_set1_epi16(0xff);
__m128i shifted_bytes = _mm_packus_epi16(_mm_and_si128(shifted0, byte_mask), _mm_and_si128(shifted1, byte_mask));
// mask by field width, again using a PSHUFB LUT
__m128i width_mask_lut = _mm_setr_epi8(0,1,3,7, 15,31,63,127, -1,-1,-1,-1, -1,-1,-1,-1);
__m128i width_mask = _mm_shuffle_epi8(width_mask_lut, field_widths);
__m128i result = _mm_and_si128(shifted_bytes, width_mask);
return result;
}
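// Minimal scalar sketch (illustrative only; getbits8_scalar_ref is not part of the
// kernels above): what multigetbits8 computes per lane, assuming the same MSB-first
// packing. "base" is the absolute starting bit position of the lane's field.
static inline uint8_t getbits8_scalar_ref(const uint8_t *in_ptr, uint32_t base, uint32_t width)
{
uint32_t byte0 = in_ptr[base >> 3];
uint32_t byte1 = in_ptr[(base >> 3) + 1];
uint32_t word = (byte0 << 8) | byte1;
uint32_t shifted = word >> (16 - width - (base & 7)); // the PMULHUW performs this shift
return (uint8_t) (shifted & ((1u << width) - 1)); // the PSHUFB width-mask LUT applies this mask
}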
static inline __m128i big_endian_load_shift128(const uint8_t *in_ptr, uint32_t bit_basepos)
{
// Grab 128 source bits starting "bit_basepos" bits into the bit stream
__m128i src_byte0 = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3) + 0));
__m128i src_byte1 = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3) + 1));
// We need to consume the first (bit_basepos & 7) bits with a big-endian 128-bit
// funnel shift, which we don't have ready at hand, so we need to get creative;
// specifically, use 16-bit shifts by a single distance for all lanes (which we have)
// and make sure to not grab any bits that crossed byte boundaries (which would be
// taken from the wrong byte due to the endianness difference)
uint32_t basepos7 = bit_basepos & 7;
__m128i basepos_shiftamt = _mm_cvtsi32_si128(basepos7);
// Combine to big-endian 16-bit words and shift those since we don't have 8-bit shifts
// at hand
__m128i merged0 = _mm_unpacklo_epi8(src_byte1, src_byte0);
__m128i merged1 = _mm_unpackhi_epi8(src_byte1, src_byte0);
__m128i shifted0 = _mm_sll_epi16(merged0, basepos_shiftamt);
__m128i shifted1 = _mm_sll_epi16(merged1, basepos_shiftamt);
__m128i reduced0 = _mm_srli_epi16(shifted0, 8);
__m128i reduced1 = _mm_srli_epi16(shifted1, 8);
__m128i shifted_src_bytes = _mm_packus_epi16(reduced0, reduced1);
return shifted_src_bytes;
}
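// Minimal scalar sketch (illustrative only): byte i of big_endian_load_shift128's
// result is the source window shifted left by (bit_basepos & 7) bits in big-endian
// (MSB-first) order, i.e. a byte-wise funnel shift pulling bits in from the next byte.
static inline void big_endian_load_shift128_scalar_ref(uint8_t out[16], const uint8_t *in_ptr, uint32_t bit_basepos)
{
const uint8_t *base = in_ptr + (bit_basepos >> 3);
uint32_t k = bit_basepos & 7;
for (int i = 0; i < 16; i++)
out[i] = (uint8_t) (((((uint32_t)base[i] << 8) | base[i + 1]) << k) >> 8);
}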
// Once more, with feeling
// trying to figure out a niftier way to do this that'll also allow me to do full multigetbits16
// and multigetbits32 which don't suck
static inline __m128i multigetbits8b(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
uint32_t bit_basepos = *pbit_basepos;
// Prefix-sum the field widths and advance bit position pointer
__m128i end_bit_index = prefix_sum_u8(field_widths);
uint32_t total_width = (uint32_t)_mm_extract_epi16(end_bit_index, 7) >> 8; // no PEXTRB before SSE4.1, and this is the only place where SSE4.1+ helps
*pbit_basepos = bit_basepos + total_width;
// Doing this shift is a bit of a production, but it simplifies the rest.
__m128i shifted_src_bytes = big_endian_load_shift128(in_ptr, bit_basepos);
__m128i end_byte_index = _mm_and_si128(_mm_srli_epi16(end_bit_index, 3), _mm_set1_epi8(0x1f)); // no "shift bytes", sigh.
// Grab first/second bytes for every lane
__m128i byte1 = _mm_shuffle_epi8(shifted_src_bytes, end_byte_index);
__m128i byte0 = _mm_shuffle_epi8(shifted_src_bytes, _mm_sub_epi8(end_byte_index, _mm_set1_epi8(1)));
// Assemble words (byte1 << 8) | byte0
__m128i words0 = _mm_unpacklo_epi8(byte1, byte0);
__m128i words1 = _mm_unpackhi_epi8(byte1, byte0);
// Now do a left shift by 1 << (end_bit_index & 7) using a multiply,
// putting the end of the bit field at the boundary between the low and high byte
// in every word.
__m128i end_bit_index7 = _mm_and_si128(end_bit_index, _mm_set1_epi8(7));
__m128i left_shift_lut = _mm_setr_epi8(1,2,4,8, 16,32,64,-128, 1,2,4,8, 16,32,64,-128);
__m128i shiftm = _mm_shuffle_epi8(left_shift_lut, end_bit_index7);
__m128i shift_mul0 = _mm_unpacklo_epi8(shiftm, _mm_setzero_si128());
__m128i shift_mul1 = _mm_unpackhi_epi8(shiftm, _mm_setzero_si128());
__m128i shifted0 = _mm_mullo_epi16(words0, shift_mul0);
__m128i shifted1 = _mm_mullo_epi16(words1, shift_mul1);
// Grab the high byte of the results and pack back into bytes
__m128i shifted_bytes = _mm_packus_epi16(_mm_srli_epi16(shifted0, 8), _mm_srli_epi16(shifted1, 8));
// mask by field width, again using a PSHUFB LUT
__m128i width_mask_lut = _mm_setr_epi8(0,1,3,7, 15,31,63,127, -1,-1,-1,-1, -1,-1,-1,-1);
__m128i width_mask = _mm_shuffle_epi8(width_mask_lut, field_widths);
__m128i result = _mm_and_si128(shifted_bytes, width_mask);
return result;
}
static inline __m128i multigetbits_leftshift_mult(__m128i end_bit_index)
{
#if 1
// This requires 0 <= end_bit_index <= 127!
// We use that PSHUFB only looks at the bottom 4 bits for the index, plus bit 7 to decide whether to
// substitute in zero.
//
// Since end_bit_index < 128, we know bit 7 is clear, so we don't need to AND with 7. Just replicate
// the 8-entry table twice.
__m128i left_shift_lut = _mm_setr_epi8(1,2,4,8, 16,32,64,-128, 1,2,4,8, 16,32,64,-128);
__m128i left_shift_mult = _mm_and_si128(_mm_shuffle_epi8(left_shift_lut, end_bit_index), _mm_set1_epi32(0xff));
#else
__m128i left_shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
__m128i left_shift_magic = _mm_add_epi32(_mm_slli_epi32(left_shift_amt, 23), _mm_set1_epi32(0x3f800000));
__m128i left_shift_mult = _mm_cvttps_epi32(_mm_castsi128_ps(left_shift_magic));
#endif
return left_shift_mult;
}
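// Minimal scalar sketch (illustrative only) of the float-exponent trick in the #else
// branch above: adding n to the exponent field of 1.0f produces 2^n, so truncating
// back to integer yields 1 << n without needing a variable shift (fine for 0 <= n <= 7).
static inline uint32_t pow2_via_float_exponent(uint32_t n)
{
uint32_t bits = (n << 23) + 0x3f800000u; // 0x3f800000 = bit pattern of 1.0f
float f;
memcpy(&f, &bits, sizeof(f)); // reinterpret, like _mm_castsi128_ps
return (uint32_t) f; // truncate, like _mm_cvttps_epi32; equals 1u << n
}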
// field widths here are U32[4] in [0,24]
static inline __m128i multigetbits24a(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
uint32_t bit_basepos = *pbit_basepos;
__m128i summed_widths = prefix_sum_u32(field_widths);
uint32_t total_width = _mm_extract_epi16(summed_widths, 6); // using PEXTRW (SSE2) instead of PEXTRD (SSE4.1+)
*pbit_basepos = bit_basepos + total_width;
__m128i basepos_u32 = _mm_shuffle_epi32(_mm_cvtsi32_si128(bit_basepos & 7), 0x00);
__m128i end_bit_index = _mm_add_epi32(basepos_u32, summed_widths);
// say bit_basepos = 3 and field_widths[0] = 11
// then end_bit_index[0] = 3 + 11 = 14
//
// we want to shuffle the input bytes so the byte containing bit 14 (in bit stream order) ends up in the least significant
// byte position of lane 0
//
// this is byte 1, so we want shuffle[0] = 14>>3 = 1
// and then we need to shift left by another (14 & 7) = 6 bit positions to have the bottom of the bit field be
// flush with bit 8 of lane 0.
//
// note that this Just Works(tm) if end_bit_index[i] ends up a multiple of 8: we fetch one byte
// too far (since we use end_bit_index and not end_bit_index-1) but then shift by 0, so the field
// still ends up flush with bit 8 of the target lane, which is exactly what we want.
__m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
__m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
// grab source bytes and shuffle
__m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
__m128i dwords = _mm_shuffle_epi8(src_bytes, byte_shuffle);
// left shift the source dwords
__m128i left_shift_mult = multigetbits_leftshift_mult(end_bit_index);
__m128i shifted_dwords = _mm_mullo_epi32(dwords, left_shift_mult);
// right shift by 8 to align it to the bottom
__m128i finished_bit_grab = _mm_srli_epi32(shifted_dwords, 8);
// create width mask constants
__m128i width_magic = _mm_add_epi32(_mm_slli_epi32(field_widths, 23), _mm_set1_epi32(0xbf800000));
__m128i width_mask = _mm_cvttps_epi32(_mm_castsi128_ps(width_magic));
__m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
return masked_fields;
}
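// Minimal scalar sketch (illustrative only) of the width-mask magic above: 0xbf800000
// is the bit pattern of -1.0f, so adding w to its exponent gives -(2^w); truncation
// gives -(1 << w), and since ~(-(1 << w)) == (1 << w) - 1, the PANDN masks the field
// to w bits. w = 0 yields -1, so the field gets masked to 0, as desired. Valid for
// 0 <= w <= 31; width 32 needs the special case used in multigetbits32 below.
static inline uint32_t width_mask_via_float(uint32_t w)
{
uint32_t bits = (w << 23) + 0xbf800000u;
float f;
memcpy(&f, &bits, sizeof(f));
int32_t minus_pow2 = (int32_t) f; // -(1 << w)
return ~(uint32_t) minus_pow2; // (1 << w) - 1, the effective field mask
}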
// field widths here are given packed little-endian into an U32
static inline __m128i multigetbits24b(const uint8_t *in_ptr, uint32_t *pbit_basepos, uint32_t packed_widths)
{
uint32_t bit_basepos = *pbit_basepos;
// use a multiply to do the inclusive prefix sum
uint32_t field_end = (packed_widths + (bit_basepos & 7)) * 0x01010101u;
*pbit_basepos = (bit_basepos & ~7) + (field_end >> 24);
__m128i widths_vec = _mm_cvtepu8_epi32(_mm_cvtsi32_si128(packed_widths));
__m128i end_bit_index = _mm_cvtepu8_epi32(_mm_cvtsi32_si128(field_end));
// say bit_basepos = 3 and field_widths[0] = 11
// then end_bit_index[0] = 3 + 11 = 14
//
// we want to shuffle the input bytes so the byte containing bit 14 (in bit stream order) ends up in the least significant
// byte position of lane 0
//
// this is byte 1, so we want shuffle[0] = 14>>3 = 1
// and then we need to shift left by another (14 & 7) = 6 bit positions to have the bottom of the bit field be
// flush with bit 8 of lane 0.
//
// note that this Just Works(tm) if end_bit_index[i] ends up a multiple of 8: we fetch one byte
// too far (since we use end_bit_index and not end_bit_index-1) but then shift by 0, so the field
// still ends up flush with bit 8 of the target lane, which is exactly what we want.
__m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
__m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
// grab source bytes and shuffle
__m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
__m128i dwords = _mm_shuffle_epi8(src_bytes, byte_shuffle);
// left shift the source dwords
__m128i left_shift_mult = multigetbits_leftshift_mult(end_bit_index);
__m128i shifted_dwords = _mm_mullo_epi32(dwords, left_shift_mult);
// right shift by 8 to align it to the bottom
__m128i finished_bit_grab = _mm_srli_epi32(shifted_dwords, 8);
// create width mask constants
__m128i width_magic = _mm_add_epi32(_mm_slli_epi32(widths_vec, 23), _mm_set1_epi32(0xbf800000));
__m128i width_mask = _mm_cvttps_epi32(_mm_castsi128_ps(width_magic));
__m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
return masked_fields;
}
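// Minimal scalar sketch (illustrative only) of the multiply-based prefix sum above:
// with widths packed little-endian into a u32, multiplying by 0x01010101 makes byte k
// of the product the sum of bytes 0..k. No carries leak between bytes as long as the
// running total stays below 256, which holds for four widths <= 24 plus the <= 7 bits
// of initial alignment added to byte 0.
static inline void prefix_sum_via_multiply_demo(void)
{
uint32_t packed = 0x18051703u; // example widths w0=3, w1=23, w2=5, w3=24
uint32_t sums = packed * 0x01010101u; // bytes become 3, 26, 31, 55
printf("%u %u %u %u\n", sums & 0xff, (sums >> 8) & 0xff, (sums >> 16) & 0xff, sums >> 24);
}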
// Field widths are [0,30].
// Limit here is 30 so that we consume at most 30*4 + 7 (for the initial align) = 127 bits from the source
// any more turns out to get messy
static inline __m128i multigetbits30(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
uint32_t bit_basepos = *pbit_basepos;
__m128i summed_widths = prefix_sum_u32(field_widths);
uint32_t total_width = _mm_extract_epi16(summed_widths, 6); // using PEXTRW (SSE2) instead of PEXTRD (SSE4.1+)
*pbit_basepos = bit_basepos + total_width;
__m128i basepos_u32 = _mm_shuffle_epi32(_mm_cvtsi32_si128(bit_basepos & 7), 0x00);
__m128i end_bit_index = _mm_add_epi32(basepos_u32, summed_widths);
__m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
__m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
// grab source bytes and shuffle
__m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
__m128i dwords0 = _mm_shuffle_epi8(src_bytes, byte_shuffle);
__m128i dwords1 = _mm_shuffle_epi8(src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
// left shift the source dwords
// The high concept here is that the "l" values contain the low bits of the result, and
// the 'h' values contain the high bits of the result.
//
// The top approach computes this with 16-bit multiplies which are usually faster,
// but this requires a slightly more complicated setup for the multipliers.
//
// The bottom approach just uses 32-bit multiplies.
#if 1
__m128i left_shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
__m128i left_shift_magic = _mm_add_epi32(_mm_slli_epi32(left_shift_amt, 23), _mm_castps_si128(_mm_set1_ps((float)0x10001)));
__m128i left_shift_mult = _mm_cvttps_epi32(_mm_castsi128_ps(left_shift_magic));
__m128i shifted_dwordsl = _mm_mullo_epi16(dwords0, left_shift_mult);
__m128i shifted_dwordsh = _mm_mullo_epi16(dwords1, left_shift_mult);
#else
__m128i left_shift_mult = multigetbits_leftshift_mult(end_bit_index);
__m128i shifted_dwordsl = _mm_mullo_epi32(dwords0, left_shift_mult);
__m128i shifted_dwordsh = _mm_mullo_epi32(dwords1, left_shift_mult);
#endif
// combine the low and high parts
__m128i finished_bit_grab = _mm_or_si128(_mm_srli_epi32(shifted_dwordsl, 8), shifted_dwordsh);
// create width mask constants
__m128i width_magic = _mm_add_epi32(_mm_slli_epi32(field_widths, 23), _mm_set1_epi32(0xbf800000));
__m128i width_mask = _mm_cvttps_epi32(_mm_castsi128_ps(width_magic));
__m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
return masked_fields;
}
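// Minimal scalar sketch (illustrative only) of the 16-bit multiplier setup above:
// (float)0x10001 is exactly 65537.0f, so adding s to its exponent and truncating
// yields 65537 << s, i.e. (1 << s) replicated into both 16-bit halves of the dword.
// That is the per-lane multiplier PMULLW needs (fine for 0 <= s <= 7).
static inline uint32_t replicated_left_shift_mult(uint32_t s)
{
uint32_t bits;
float f = (float) 0x10001; // 65537.0f, bit pattern 0x47800080
memcpy(&bits, &f, sizeof(bits));
bits += s << 23; // scale by 2^s
memcpy(&f, &bits, sizeof(f));
return (uint32_t) f; // == ((1u << s) << 16) | (1u << s)
}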
// Field widths are [0,32].
// with the big_endian_load_shift128 primitive, we can support 32 bits in every lane
static inline __m128i multigetbits32(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
uint32_t bit_basepos = *pbit_basepos;
__m128i end_bit_index = prefix_sum_u32(field_widths);
uint32_t total_width = _mm_extract_epi16(end_bit_index, 6); // using PEXTRW (SSE2) instead of PEXTRD (SSE4.1+)
*pbit_basepos = bit_basepos + total_width;
__m128i shifted_src_bytes = big_endian_load_shift128(in_ptr, bit_basepos);
__m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
__m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
// grab source bytes and shuffle
__m128i dwords0 = _mm_shuffle_epi8(shifted_src_bytes, byte_shuffle);
__m128i dwords1 = _mm_shuffle_epi8(shifted_src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
// left shift the source dwords
// The high concept here is that the "l" values contain the low bits of the result, and
// the 'h' values contain the high bits of the result.
//
// The top approach computes this with 16-bit multiplies which are usually faster,
// but this requires a slightly more complicated setup for the multipliers.
//
// The bottom approach just uses 32-bit multiplies.
#if 1
__m128i left_shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
__m128i left_shift_magic = _mm_add_epi32(_mm_slli_epi32(left_shift_amt, 23), _mm_castps_si128(_mm_set1_ps((float)0x10001)));
__m128i left_shift_mult = _mm_cvttps_epi32(_mm_castsi128_ps(left_shift_magic));
__m128i shifted_dwordsl = _mm_mullo_epi16(dwords0, left_shift_mult);
__m128i shifted_dwordsh = _mm_mullo_epi16(dwords1, left_shift_mult);
#else
__m128i left_shift_mult = multigetbits_leftshift_mult(end_bit_index);
__m128i shifted_dwordsl = _mm_mullo_epi32(dwords0, left_shift_mult);
__m128i shifted_dwordsh = _mm_mullo_epi32(dwords1, left_shift_mult);
#endif
// combine the low and high parts
__m128i finished_bit_grab = _mm_or_si128(_mm_srli_epi32(shifted_dwordsl, 8), shifted_dwordsh);
// create width mask constants
// supporting width=32 here adds an extra wrinkle
__m128i width_magic = _mm_add_epi32(_mm_slli_epi32(field_widths, 23), _mm_set1_epi32(0xbf800000));
__m128i width_mask0 = _mm_cvttps_epi32(_mm_castsi128_ps(width_magic));
__m128i width_gt31 = _mm_cmpgt_epi32(field_widths, _mm_set1_epi32(31));
__m128i width_mask = _mm_andnot_si128(width_gt31, width_mask0);
__m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
return masked_fields;
}
static inline __m128i wideload_shuffle(__m128i src0_128, __m128i src1_32, __m128i shuf_index)
{
__m128i lower = _mm_shuffle_epi8(src0_128, shuf_index);
__m128i upper = _mm_andnot_si128(_mm_cmpgt_epi8(shuf_index, _mm_set1_epi8(-1)), src1_32);
return _mm_or_si128(lower, upper);
//__m128i upper = _mm_shuffle_epi8(src1_32, _mm_xor_si128(shuf_index, _mm_set1_epi8(0x83 - 0x100)));
//return _mm_or_si128(lower, upper);
}
// Field widths are [0,32].
// alternative approach without big_endian_load_shift128, instead using a different load strategy
//
// interesting but worse
static inline __m128i multigetbits32c(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
uint32_t bit_basepos = *pbit_basepos;
__m128i summed_field_widths = prefix_sum_u32(field_widths);
uint32_t total_width = _mm_extract_epi16(summed_field_widths, 6); // using PEXTRW (SSE2) instead of PEXTRD (SSE4.1+)
*pbit_basepos = bit_basepos + total_width;
__m128i end_bit_index = _mm_add_epi32(summed_field_widths, _mm_shuffle_epi32(_mm_cvtsi32_si128(bit_basepos & 7), 0));
__m128i src_bytes0 = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3) + 0));
__m128i src_bytes1 = _mm_cvtsi32_si128(*(int *) (in_ptr + (bit_basepos >> 3) + 13)); // MOVD
src_bytes1 = _mm_shuffle_epi8(src_bytes1, _mm_set1_epi8(3)); // broadcast final byte
__m128i end_byte_index = _mm_add_epi32(_mm_srli_epi32(end_bit_index, 3), _mm_set1_epi32(0x70));
__m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
// grab source bytes and shuffle
__m128i dwords0 = wideload_shuffle(src_bytes0, src_bytes1, byte_shuffle);
__m128i dwords1 = wideload_shuffle(src_bytes0, src_bytes1, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
// left shift the source dwords
// The high concept here is that the "l" values contain the low bits of the result, and
// the 'h' values contain the high bits of the result.
//
// The top approach computes this with 16-bit multiplies which are usually faster,
// but this requires a slightly more complicated setup for the multipliers.
//
// The bottom approach just uses 32-bit multiplies.
#if 1
__m128i left_shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
__m128i left_shift_magic = _mm_add_epi32(_mm_slli_epi32(left_shift_amt, 23), _mm_castps_si128(_mm_set1_ps((float)0x10001)));
__m128i left_shift_mult = _mm_cvttps_epi32(_mm_castsi128_ps(left_shift_magic));
__m128i shifted_dwordsl = _mm_mullo_epi16(dwords0, left_shift_mult);
__m128i shifted_dwordsh = _mm_mullo_epi16(dwords1, left_shift_mult);
#else
__m128i left_shift_mult = multigetbits_leftshift_mult(end_bit_index);
__m128i shifted_dwordsl = _mm_mullo_epi32(dwords0, left_shift_mult);
__m128i shifted_dwordsh = _mm_mullo_epi32(dwords1, left_shift_mult);
#endif
// combine the low and high parts
__m128i finished_bit_grab = _mm_or_si128(_mm_srli_epi32(shifted_dwordsl, 8), shifted_dwordsh);
// create width mask constants
// supporting width=32 here adds an extra wrinkle
__m128i width_magic = _mm_add_epi32(_mm_slli_epi32(field_widths, 23), _mm_set1_epi32(0xbf800000));
__m128i width_mask0 = _mm_cvttps_epi32(_mm_castsi128_ps(width_magic));
__m128i width_gt31 = _mm_cmpgt_epi32(field_widths, _mm_set1_epi32(31));
__m128i width_mask = _mm_andnot_si128(width_gt31, width_mask0);
__m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
return masked_fields;
}
// Field widths are [0,15].
// Limit is 15 so that we consume at most 15*8 + 7 (for the initial align) = 127 bits from the source
// any more gets messy
static inline __m128i multigetbits15(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
uint32_t bit_basepos = *pbit_basepos;
__m128i summed_widths = prefix_sum_u16(field_widths);
uint32_t total_width = _mm_extract_epi16(summed_widths, 7);
*pbit_basepos = bit_basepos + total_width;
__m128i basepos_u16 = _mm_set1_epi16(bit_basepos & 7);
__m128i end_bit_index = _mm_add_epi16(basepos_u16, summed_widths);
__m128i end_byte_index = _mm_srli_epi16(end_bit_index, 3);
__m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,2,2, 4,4,6,6, 8,8,10,10, 12,12,14,14));
byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_set1_epi16(0x0100));
// grab source bytes and shuffle
__m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
__m128i words0 = _mm_shuffle_epi8(src_bytes, byte_shuffle);
__m128i words1 = _mm_shuffle_epi8(src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
// left shift the source words
__m128i left_shift_lut = _mm_setr_epi8(1,2,4,8, 16,32,64,-128, 1,2,4,8, 16,32,64,-128);
__m128i left_shift_mult = _mm_and_si128(_mm_shuffle_epi8(left_shift_lut, end_bit_index), _mm_set1_epi16(0xff));
__m128i shifted_words0 = _mm_mullo_epi16(words0, left_shift_mult);
__m128i shifted_words1 = _mm_mullo_epi16(words1, left_shift_mult);
// combine the low and high parts
__m128i finished_bit_grab = _mm_or_si128(_mm_srli_epi16(shifted_words0, 8), shifted_words1);
// create width mask constants
__m128i width_exps = _mm_add_epi16(_mm_slli_epi16(field_widths, 7), _mm_set1_epi16(0xbf80));
__m128i zero = _mm_setzero_si128();
__m128i widthm0 = _mm_cvttps_epi32(_mm_castsi128_ps(_mm_unpacklo_epi16(zero, width_exps)));
__m128i widthm1 = _mm_cvttps_epi32(_mm_castsi128_ps(_mm_unpackhi_epi16(zero, width_exps)));
__m128i width_mask = _mm_packs_epi32(widthm0, widthm1);
__m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
return masked_fields;
}
// Field widths are [0,16].
// with the big_endian_load_shift128 primitive, we can support 16 bits in every lane
static inline __m128i multigetbits16(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
uint32_t bit_basepos = *pbit_basepos;
__m128i end_bit_index = prefix_sum_u16(field_widths);
uint32_t total_width = _mm_extract_epi16(end_bit_index, 7);
*pbit_basepos = bit_basepos + total_width;
__m128i shifted_src_bytes = big_endian_load_shift128(in_ptr, bit_basepos);
__m128i end_byte_index = _mm_srli_epi16(end_bit_index, 3);
__m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,2,2, 4,4,6,6, 8,8,10,10, 12,12,14,14));
byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_set1_epi16(0x0100));
// shuffle source bytes
__m128i words0 = _mm_shuffle_epi8(shifted_src_bytes, byte_shuffle);
__m128i words1 = _mm_shuffle_epi8(shifted_src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
// left shift the source words
__m128i left_shift_lut = _mm_setr_epi8(1,2,4,8, 16,32,64,-128, 1,2,4,8, 16,32,64,-128);
__m128i left_shift_mult = _mm_and_si128(_mm_shuffle_epi8(left_shift_lut, _mm_and_si128(end_bit_index, _mm_set1_epi8(7))), _mm_set1_epi16(0xff));
__m128i shifted_words0 = _mm_mullo_epi16(words0, left_shift_mult);
__m128i shifted_words1 = _mm_mullo_epi16(words1, left_shift_mult);
// combine the low and high parts
__m128i finished_bit_grab = _mm_or_si128(_mm_srli_epi16(shifted_words0, 8), shifted_words1);
// create width mask
// need to do this differently from the multigetbits15 logic to make width=16 work
__m128i base_mask_lut = _mm_setr_epi8(-1,-2,-4,-8, -16,-32,-64,-128, -1,-2,-4,-8, -16,-32,-64,-128);
__m128i width_mask0 = _mm_shuffle_epi8(base_mask_lut, field_widths); // gives (-1 << (field_widths & 7))
__m128i width_mask0s = _mm_slli_epi16(width_mask0, 8);
__m128i width_gt7 = _mm_cmpgt_epi16(field_widths, _mm_set1_epi16(7));
// conditionally shift by 8 where field_widths >= 8
__m128i width_mask1 = _mm_or_si128(_mm_and_si128(width_mask0s, width_gt7), _mm_andnot_si128(width_gt7, width_mask0));
// conditionally zero mask where field_widths >= 16
__m128i width_gt15 = _mm_cmpgt_epi16(field_widths, _mm_set1_epi16(15));
__m128i width_mask = _mm_andnot_si128(width_gt15, width_mask1);
__m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
return masked_fields;
}
#ifdef __RADAVX__
// field widths here are U32[4] in [0,24]
static inline __m128i multigetbits24a_avx2(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
uint32_t bit_basepos = *pbit_basepos;
__m128i summed_widths = prefix_sum_u32(field_widths);
uint32_t total_width = _mm_extract_epi32(summed_widths, 3);
*pbit_basepos = bit_basepos + total_width;
__m128i basepos_u32 = _mm_shuffle_epi32(_mm_cvtsi32_si128(bit_basepos & 7), 0x00);
__m128i end_bit_index = _mm_add_epi32(basepos_u32, summed_widths);
// say bit_basepos = 3 and field_widths[0] = 11
// then end_bit_index[0] = 3 + 11 = 14
//
// we want to shuffle the input bytes so the byte containing bit 14 (in bit stream order) ends up in the least significant
// byte position of lane 0
//
// this is byte 1, so we want shuffle[0] = 14>>3 = 1
// and then we need to shift left by another (14 & 7) = 6 bit positions to have the bottom of the bit field be
// flush with bit 8 of lane 0.
//
// note that this Just Works(tm) if end_bit_index[i] ends up a multiple of 8: we fetch one byte
// too far (since we use end_bit_index and not end_bit_index-1) but then shift by 0, so the field
// still ends up flush with bit 8 of the target lane, which is exactly what we want.
__m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
__m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
// grab source bytes and shuffle
__m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
__m128i dwords = _mm_shuffle_epi8(src_bytes, byte_shuffle);
// right shift the source dwords to align the correct bits at the bottom
__m128i shift_amt = _mm_sub_epi32(_mm_set1_epi32(8), _mm_and_si128(end_bit_index, _mm_set1_epi32(7)));
__m128i finished_bit_grab = _mm_srlv_epi32(dwords, shift_amt);
// mask to desired field widths
__m128i width_mask = _mm_sllv_epi32(_mm_set1_epi32(-1), field_widths);
__m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
return masked_fields;
}
// Field widths are [0,30].
// Limit here is 30 so that we consume at most 30*4 + 7 (for the initial align) = 127 bits from the source
// any more turns out to get messy
static inline __m128i multigetbits30_avx2(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
uint32_t bit_basepos = *pbit_basepos;
__m128i summed_widths = prefix_sum_u32(field_widths);
uint32_t total_width = _mm_extract_epi32(summed_widths, 3);
*pbit_basepos = bit_basepos + total_width;
__m128i basepos_u32 = _mm_shuffle_epi32(_mm_cvtsi32_si128(bit_basepos & 7), 0x00);
__m128i end_bit_index = _mm_add_epi32(basepos_u32, summed_widths);
__m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
__m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
// grab source bytes and shuffle
__m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
__m128i dwords0 = _mm_shuffle_epi8(src_bytes, byte_shuffle);
__m128i dwords1 = _mm_shuffle_epi8(src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
// shift the source dwords
// The high concept here is that the "l" values contain the low bits of the result, and
// the 'h' values contain the high bits of the result.
__m128i shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
__m128i rev_shift_amt = _mm_sub_epi32(_mm_set1_epi32(8), shift_amt);
__m128i shifted_dwordsl = _mm_srlv_epi32(dwords0, rev_shift_amt);
__m128i shifted_dwordsh = _mm_sllv_epi32(dwords1, shift_amt);
// combine the low and high parts
__m128i finished_bit_grab = _mm_or_si128(shifted_dwordsl, shifted_dwordsh);
// mask to desired field widths
__m128i width_mask = _mm_sllv_epi32(_mm_set1_epi32(-1), field_widths);
__m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
return masked_fields;
}
// Field widths are [0,32].
static inline __m128i multigetbits32_avx2(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
uint32_t bit_basepos = *pbit_basepos;
__m128i end_bit_index = prefix_sum_u32(field_widths);
uint32_t total_width = _mm_extract_epi32(end_bit_index, 3);
*pbit_basepos = bit_basepos + total_width;
__m128i shifted_src_bytes = big_endian_load_shift128(in_ptr, bit_basepos);
__m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
__m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
// shuffle source bytes
__m128i dwords0 = _mm_shuffle_epi8(shifted_src_bytes, byte_shuffle);
__m128i dwords1 = _mm_shuffle_epi8(shifted_src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
// shift the source dwords
// The high concept here is that the "l" values contain the low bits of the result, and
// the 'h' values contain the high bits of the result.
__m128i shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
__m128i rev_shift_amt = _mm_sub_epi32(_mm_set1_epi32(8), shift_amt);
__m128i shifted_dwordsl = _mm_srlv_epi32(dwords0, rev_shift_amt);
__m128i shifted_dwordsh = _mm_sllv_epi32(dwords1, shift_amt);
// combine the low and high parts
__m128i finished_bit_grab = _mm_or_si128(shifted_dwordsl, shifted_dwordsh);
// mask to desired field widths
__m128i width_mask = _mm_sllv_epi32(_mm_set1_epi32(-1), field_widths);
__m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
return masked_fields;
}
#endif
// ---- testbed
static void decode8_ref(uint8_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
static const uint8_t masks[9] = { 0,1,3,7, 15,31,63,127, 255 };
// we just gleefully over-read; not worrying about that right now.
uint64_t bits = 0, bitc = 0;
for (size_t i = 0; i < count; i += 7)
{
// refill bit buffer (so it contains at least 56 bits)
uint64_t bytes_consumed = (63 - bitc) >> 3;
bits |= bswap64(*(const uint64_t *) in_ptr) >> bitc;
bitc |= 56;
in_ptr += bytes_consumed;
// decode 7 values
uint32_t w;
uint64_t mask;
#if 1
// !!better!! (by 0.5 cycles/elem on SNB!)
#define DECONE(ind) \
w = width_arr[i + ind]; \
bits = rotl64(bits, w); \
mask = masks[w] & bits; \
out_ptr[i + ind] = (uint8_t) mask; \
bits ^= mask; \
bitc -= w
#else
#define DECONE(ind) \
w = width_arr[i + ind]; \
bits = rotl64(bits, w); \
mask = masks[w]; \
out_ptr[i + ind] = (uint8_t) (bits & mask); \
bits &= ~mask; \
bitc -= w
#endif
DECONE(0);
DECONE(1);
DECONE(2);
DECONE(3);
DECONE(4);
DECONE(5);
DECONE(6);
#undef DECONE
}
}
static void decode8_SSSE3(uint8_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
for (size_t i = 0; i < count; i += 16)
{
__m128i widths = _mm_loadu_si128((const __m128i *) (width_arr + i));
__m128i values = multigetbits8(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode8b_SSSE3(uint8_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
for (size_t i = 0; i < count; i += 16)
{
__m128i widths = _mm_loadu_si128((const __m128i *) (width_arr + i));
__m128i values = multigetbits8b(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode16_ref(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
static const uint16_t masks[17] = { 0x0000, 0x0001,0x0003,0x0007,0x000f, 0x001f,0x003f,0x007f,0x00ff, 0x01ff,0x03ff,0x07ff,0x0fff, 0x1fff,0x3fff,0x7fff,0xffff };
// we just gleefully over-read; not worrying about that right now.
uint64_t bits = 0, bitc = 0;
for (size_t i = 0; i < count; i += 3)
{
// refill bit buffer (so it contains at least 56 bits)
uint64_t bytes_consumed = (63 - bitc) >> 3;
bits |= bswap64(*(const uint64_t *) in_ptr) >> bitc;
bitc |= 56;
in_ptr += bytes_consumed;
// decode 3 values
uint32_t w;
uint64_t mask;
#define DECONE(ind) \
w = width_arr[i + ind]; \
bits = rotl64(bits, w); \
mask = masks[w] & bits; \
out_ptr[i + ind] = (uint16_t) mask; \
bits ^= mask; \
bitc -= w
DECONE(0);
DECONE(1);
DECONE(2);
#undef DECONE
}
}
static void decode16_SSSE3(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
__m128i byteswap_shuf = _mm_setr_epi8(1,0, 3,2, 5,4, 7,6, 9,8, 11,10, 13,12, 15,14);
for (size_t i = 0; i < count; i += 8)
{
__m128i widths = _mm_loadl_epi64((const __m128i *) (width_arr + i));
widths = _mm_unpacklo_epi8(widths, widths);
widths = _mm_subs_epu8(widths, _mm_set1_epi16(0x0008));
widths = _mm_min_epu8(widths, _mm_set1_epi8(8));
__m128i values = multigetbits8(in_ptr, &bit_pos, widths);
values = _mm_shuffle_epi8(values, byteswap_shuf);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode16b_SSSE3(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
for (size_t i = 0; i < count; i += 8)
{
__m128i widths = _mm_loadl_epi64((const __m128i *) (width_arr + i));
widths = _mm_unpacklo_epi8(widths, _mm_setzero_si128());
__m128i values = multigetbits16(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode15_SSSE3(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
for (size_t i = 0; i < count; i += 8)
{
__m128i widths = _mm_loadl_epi64((const __m128i *) (width_arr + i));
widths = _mm_unpacklo_epi8(widths, _mm_setzero_si128());
__m128i values = multigetbits15(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode32_ref(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
static const uint32_t masks[33] =
{
0x00000000,
0x00000001, 0x00000003, 0x00000007, 0x0000000f,
0x0000001f, 0x0000003f, 0x0000007f, 0x000000ff,
0x000001ff, 0x000003ff, 0x000007ff, 0x00000fff,
0x00001fff, 0x00003fff, 0x00007fff, 0x0000ffff,
0x0001ffff, 0x0003ffff, 0x0007ffff, 0x000fffff,
0x001fffff, 0x003fffff, 0x007fffff, 0x00ffffff,
0x01ffffff, 0x03ffffff, 0x07ffffff, 0x0fffffff,
0x1fffffff, 0x3fffffff, 0x7fffffff, 0xffffffff,
};
// we just gleefully over-read; not worrying about that right now.
uint64_t bitc = 0;
for (size_t i = 0; i < count; i++)
{
// grab value
uint64_t bits = bswap64(*(const uint64_t *) (in_ptr + (bitc >> 3)));
uint32_t w = width_arr[i];
out_ptr[i] = rotl64(bits, w + (bitc & 7)) & masks[w];
bitc += w;
}
}
static void decode32_SSSE3(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
__m128i broadcast_shuf = _mm_setr_epi8(0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3);
__m128i byteswap_shuf = _mm_setr_epi8(3,2,1,0, 7,6,5,4, 11,10,9,8, 15,14,13,12);
for (size_t i = 0; i < count; i += 4)
{
__m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
widths = _mm_shuffle_epi8(widths, broadcast_shuf);
widths = _mm_subs_epu8(widths, _mm_set1_epi32(0x00081018));
widths = _mm_min_epu8(widths, _mm_set1_epi8(8));
__m128i values = multigetbits8(in_ptr, &bit_pos, widths);
values = _mm_shuffle_epi8(values, byteswap_shuf);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode32b_SSSE3(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
__m128i expand_shuf = _mm_setr_epi8(0,-1,-1,-1, 1,-1,-1,-1, 2,-1,-1,-1, 3,-1,-1,-1);
for (size_t i = 0; i < count; i += 4)
{
__m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
widths = _mm_shuffle_epi8(widths, expand_shuf);
__m128i values = multigetbits32(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode32c_SSSE3(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
__m128i expand_shuf = _mm_setr_epi8(0,-1,-1,-1, 1,-1,-1,-1, 2,-1,-1,-1, 3,-1,-1,-1);
for (size_t i = 0; i < count; i += 4)
{
__m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
widths = _mm_shuffle_epi8(widths, expand_shuf);
__m128i values = multigetbits32c(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode24_ref(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
static const uint32_t masks[33] =
{
0x00000000,
0x00000001, 0x00000003, 0x00000007, 0x0000000f,
0x0000001f, 0x0000003f, 0x0000007f, 0x000000ff,
0x000001ff, 0x000003ff, 0x000007ff, 0x00000fff,
0x00001fff, 0x00003fff, 0x00007fff, 0x0000ffff,
0x0001ffff, 0x0003ffff, 0x0007ffff, 0x000fffff,
0x001fffff, 0x003fffff, 0x007fffff, 0x00ffffff,
0x01ffffff, 0x03ffffff, 0x07ffffff, 0x0fffffff,
0x1fffffff, 0x3fffffff, 0x7fffffff, 0xffffffff,
};
// we just gleefully over-read; not worrying about that right now.
uint64_t bitc = 0;
for (size_t i = 0; i < count; i += 2)
{
// grab value
uint64_t bits = bswap64(*(const uint64_t *) in_ptr);
uint32_t w;
w = width_arr[i + 0];
out_ptr[i + 0] = rotl64(bits, w + bitc) & masks[w];
bitc += w;
w = width_arr[i + 1];
out_ptr[i + 1] = rotl64(bits, w + bitc) & masks[w];
bitc += w;
in_ptr += bitc >> 3;
bitc &= 7;
}
}
static void decode24_SSE4_v1(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
for (size_t i = 0; i < count; i += 4)
{
__m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
widths = _mm_cvtepu8_epi32(widths);
__m128i values = multigetbits24a(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode24_SSE4_v2(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
for (size_t i = 0; i < count; i += 4)
{
uint32_t widths = *(const uint32_t *) (width_arr + i);
__m128i values = multigetbits24b(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode30_SSE4(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
for (size_t i = 0; i < count; i += 4)
{
__m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
widths = _mm_cvtepu8_epi32(widths);
__m128i values = multigetbits30(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
#ifdef __RADAVX__
static void decode24_AVX2_v1(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
for (size_t i = 0; i < count; i += 4)
{
__m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
widths = _mm_cvtepu8_epi32(widths);
__m128i values = multigetbits24a_avx2(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode30_AVX2(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
for (size_t i = 0; i < count; i += 4)
{
__m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
widths = _mm_cvtepu8_epi32(widths);
__m128i values = multigetbits30_avx2(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
static void decode32_AVX2(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
uint32_t bit_pos = 0;
for (size_t i = 0; i < count; i += 4)
{
__m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
widths = _mm_cvtepu8_epi32(widths);
__m128i values = multigetbits32_avx2(in_ptr, &bit_pos, widths);
_mm_storeu_si128((__m128i *) (out_ptr + i), values);
}
}
#endif
// ---- RNG (PCG XSH RR 64/32 MCG)
typedef struct {
uint64_t state;
} rng32;
rng32 rng32_seed(uint64_t seed)
{
rng32 r;
// state cannot be 0 (MCG)
// also do one multiply step in case the input is a small integer (which it often is)
r.state = (seed | 1) * 6364136223846793005ULL;
return r;
}
uint32_t rng32_random(rng32 *r)
{
// Generate output from old state
uint64_t oldstate = r->state;
uint32_t rot_input = (uint32_t) (((oldstate >> 18) ^ oldstate) >> 27);
uint32_t rot_amount = (uint32_t) (oldstate >> 59);
uint32_t output = (rot_input >> rot_amount) | (rot_input << ((0u - rot_amount) & 31)); // rotr(rot_input, rot_amount)
// Advance multiplicative congruential generator
// Constant from PCG reference impl.
r->state = oldstate * 6364136223846793005ull;
return output;
}
// ---- test driver
typedef void kernel8_func(uint8_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count);
typedef void kernel16_func(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count);
typedef void kernel32_func(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count);
static inline uint32_t cycle_timer()
{
#ifdef _MSC_VER
return (uint32_t)__rdtsc();
#else
uint32_t lo, hi;
asm volatile ("rdtsc" : "=a" (lo), "=d" (hi) );
return lo;
#endif
}
int g_sink = 0; // to prevent DCE
static int int_compare(const void *a, const void *b)
{
int ai = *(const int *)a;
int bi = *(const int *)b;
return (ai > bi) - (ai < bi);
}
static uint32_t run_test8(uint8_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count, kernel8_func *func)
{
uint32_t start_tsc = cycle_timer();
func(out_ptr, width_arr, in_ptr, count);
uint32_t cycle_count = cycle_timer() - start_tsc;
return cycle_count;
}
static void testone8(char const *name, kernel8_func *func)
{
static const size_t kNumValues = 7*16*128; // a bit under 16k, divisible by both 16 and 7
//static const size_t kNumValues = 7*16;
static const size_t kNumPadded = kNumValues + 16;
uint8_t out1_bytes[kNumPadded] = {};
uint8_t out2_bytes[kNumPadded] = {};
uint8_t in_bytes[kNumPadded] = {};
uint8_t widths[kNumPadded] = {};
// set up some test data
rng32 rng = rng32_seed(83467);
for (size_t i = 0; i < kNumValues; i++)
{
uint32_t val = rng32_random(&rng);
in_bytes[i] = val & 0xff;
widths[i] = (val >> 8) % 9; // values in 0..8
}
// verify by comparing against ref
run_test8(out1_bytes, widths, in_bytes, kNumValues, func);
run_test8(out2_bytes, widths, in_bytes, kNumValues, decode8_ref);
if (memcmp(out1_bytes, out2_bytes, kNumValues) != 0)
{
size_t i = 0;
while (out1_bytes[i] == out2_bytes[i])
++i;
printf("%20s: MISMATCH! (at byte %d)\n", name, (int)i);
return;
}
// warm-up
static const size_t nWarmupRuns = 100;
for (size_t run = 0; run < nWarmupRuns; ++run)
g_sink += run_test8(out1_bytes, widths, in_bytes, kNumValues, func);
// benchmark
static const size_t nRuns = 10000;
int *run_lens = new int[nRuns];
for (size_t run = 0; run < nRuns; ++run)
run_lens[run] = run_test8(out1_bytes, widths, in_bytes, kNumValues, func);
qsort(run_lens, nRuns, sizeof(*run_lens), int_compare);
double ratio = 1.0 / kNumValues;
printf("%20s: med %.2f/b, 1st%% %.2f/b, 95th%% %.2f/b\n", name, run_lens[nRuns/2]*ratio, run_lens[nRuns/100]*ratio, run_lens[nRuns-1-nRuns/20]*ratio);
delete[] run_lens;
}
static uint32_t run_test16(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count, kernel16_func *func)
{
uint32_t start_tsc = cycle_timer();
func(out_ptr, width_arr, in_ptr, count);
uint32_t cycle_count = cycle_timer() - start_tsc;
return cycle_count;
}
static void testone16(char const *name, kernel16_func *func, int max_width)
{
static const size_t kNumValues = 3*8*128; // 3072 values, divisible by both 8 and 3
static const size_t kNumPadded = kNumValues + 16;
uint16_t out1_buf[kNumPadded] = {};
uint16_t out2_buf[kNumPadded] = {};
uint8_t in_bytes[kNumPadded*2] = {};
uint8_t widths[kNumPadded] = {};
// set up some test data
rng32 rng = rng32_seed(83467);
for (size_t i = 0; i < kNumValues; i++)
{
uint32_t val = rng32_random(&rng);
in_bytes[i*2+0] = val & 0xff;
in_bytes[i*2+1] = (val >> 8) & 0xff;
widths[i] = (val >> 16) % (max_width + 1);
}
// verify by comparing against ref
run_test16(out1_buf, widths, in_bytes, kNumValues, func);
run_test16(out2_buf, widths, in_bytes, kNumValues, decode16_ref);
if (memcmp(out1_buf, out2_buf, kNumValues * sizeof(uint16_t)) != 0)
{
size_t i = 0;
while (out1_buf[i] == out2_buf[i])
++i;
printf("%20s: MISMATCH! (at word %d)\n", name, (int)i);
return;
}
// warm-up
static const size_t nWarmupRuns = 100;
for (size_t run = 0; run < nWarmupRuns; ++run)
g_sink += run_test16(out1_buf, widths, in_bytes, kNumValues, func);
// benchmark
static const size_t nRuns = 10000;
int *run_lens = new int[nRuns];
for (size_t run = 0; run < nRuns; ++run)
run_lens[run] = run_test16(out1_buf, widths, in_bytes, kNumValues, func);
qsort(run_lens, nRuns, sizeof(*run_lens), int_compare);
double ratio = 1.0 / kNumValues;
printf("%20s: med %.2f/b, 1st%% %.2f/b, 95th%% %.2f/b\n", name, run_lens[nRuns/2]*ratio, run_lens[nRuns/100]*ratio, run_lens[nRuns-1-nRuns/20]*ratio);
delete[] run_lens;
}
static uint32_t run_test32(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count, kernel32_func *func)
{
uint32_t start_tsc = cycle_timer();
func(out_ptr, width_arr, in_ptr, count);
uint32_t cycle_count = cycle_timer() - start_tsc;
return cycle_count;
}
static void testone32(char const *name, kernel32_func *func, int max_width)
{
static const size_t kNumValues = 3*8*128; // a bit under 16k bytes, divisible by both 8 and 3
static const size_t kNumPadded = kNumValues + 16;
uint32_t out1_buf[kNumPadded] = {};
uint32_t out2_buf[kNumPadded] = {};
uint8_t in_bytes[kNumPadded*2] = {};
uint8_t widths[kNumPadded] = {};
// set up some test data
rng32 rng = rng32_seed(83467);
for (size_t i = 0; i < kNumValues; i++)
{
uint32_t val = rng32_random(&rng);
in_bytes[i*2+0] = val & 0xff;
in_bytes[i*2+1] = (val >> 8) & 0xff;
widths[i] = (val >> 16) % (max_width + 1);
}
// verify by comparing against ref
run_test32(out1_buf, widths, in_bytes, kNumValues, func);
run_test32(out2_buf, widths, in_bytes, kNumValues, decode32_ref);
if (memcmp(out1_buf, out2_buf, kNumValues * sizeof(uint32_t)) != 0)
{
size_t i = 0;
while (out1_buf[i] == out2_buf[i])
++i;
printf("%20s: MISMATCH! (at word %d)\n", name, (int)i);
return;
}
// warm-up
static const size_t nWarmupRuns = 1000;
for (size_t run = 0; run < nWarmupRuns; ++run)
g_sink += run_test32(out1_buf, widths, in_bytes, kNumValues, func);
// benchmark
static const size_t nRuns = 30000;
int *run_lens = new int[nRuns];
for (size_t run = 0; run < nRuns; ++run)
run_lens[run] = run_test32(out1_buf, widths, in_bytes, kNumValues, func);
qsort(run_lens, nRuns, sizeof(*run_lens), int_compare);
double ratio = 1.0 / kNumValues;
printf("%20s: med %.2f/b, 1st%% %.2f/b, 95th%% %.2f/b\n", name, run_lens[nRuns/2]*ratio, run_lens[nRuns/100]*ratio, run_lens[nRuns-1-nRuns/20]*ratio);
delete[] run_lens;
}
int main()
{
#define TESTIT8(what) testone8(#what, what)
#define TESTIT16(what, width) testone16(#what, what, width)
#define TESTIT32(what, width) testone32(#what, what, width)
TESTIT8(decode8_ref);
TESTIT8(decode8_SSSE3);
TESTIT8(decode8b_SSSE3);
TESTIT16(decode15_SSSE3, 15);
TESTIT16(decode16_ref, 16);
TESTIT16(decode16_SSSE3, 16);
TESTIT16(decode16b_SSSE3, 16);
TESTIT32(decode24_ref, 24);
TESTIT32(decode24_SSE4_v1, 24);
TESTIT32(decode24_SSE4_v2, 24);
TESTIT32(decode30_SSE4, 30);
TESTIT32(decode32_ref, 32);
TESTIT32(decode32_SSSE3, 32);
TESTIT32(decode32b_SSSE3, 32);
TESTIT32(decode32c_SSSE3, 32);
#ifdef __RADAVX__
TESTIT32(decode24_AVX2_v1, 24);
TESTIT32(decode30_AVX2, 30);
TESTIT32(decode32_AVX2, 30);
#endif
#undef TESTIT8
#undef TESTIT16
#undef TESTIT32
return 0;
}