Skip to content

Instantly share code, notes, and snippets.

@aqrit
Last active August 4, 2022 22:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save aqrit/c729815b0165c139d0bac642ab7ee104 to your computer and use it in GitHub Desktop.
Save aqrit/c729815b0165c139d0bac642ab7ee104 to your computer and use it in GitHub Desktop.
#include <stdint.h>
#include <string.h> // memcpy, memset
/*
 * Portable 64x64 -> 128-bit unsigned multiply built from four 32x32 -> 64-bit
 * partial products. Returns the low 64 bits of a*b and stores the high 64 bits
 * in *hi. Used when the compiler has no `unsigned __int128` (e.g. MSVC or
 * 32-bit targets).
 */
static inline
uint64_t umul128_portable( uint64_t a, uint64_t b, uint64_t* hi) {
    const uint64_t lo32 = UINT64_C(0xFFFFFFFF);
    uint64_t a_lo = a & lo32, a_hi = a >> 32;
    uint64_t b_lo = b & lo32, b_hi = b >> 32;
    uint64_t p_ll = a_lo * b_lo;
    uint64_t p_lh = a_lo * b_hi;
    uint64_t p_hl = a_hi * b_lo;
    uint64_t p_hh = a_hi * b_hi;
    // Everything that lands in bits 32..63 of the product; at most
    // 3*(2^32-1), so this sum cannot overflow 64 bits.
    uint64_t mid = (p_ll >> 32) + (p_lh & lo32) + (p_hl & lo32);
    *hi = p_hh + (p_lh >> 32) + (p_hl >> 32) + (mid >> 32);
    return (mid << 32) | (p_ll & lo32);
}

/*
 * Full 64x64 -> 128-bit unsigned multiply: returns the low 64 bits of a*b and
 * stores the high 64 bits in *hi. Uses the compiler's 128-bit integer type
 * when available, otherwise falls back to umul128_portable.
 */
static inline
uint64_t umul128( uint64_t a, uint64_t b, uint64_t* hi) {
#if defined(__SIZEOF_INT128__)
    unsigned __int128 x = (unsigned __int128)a * (unsigned __int128)b;
    *hi = (uint64_t)(x >> 64);
    return (uint64_t)x;
#else
    return umul128_portable(a, b, hi);
#endif
}
/*
 * Read 8 bytes starting at `ptr` as a native-endian uint64_t.
 * Going through memcpy means `ptr` need not be 8-byte aligned and the
 * access does not depend on the pointee's declared type.
 */
static inline
uint64_t loadu_u64 (const void *ptr) {
    const unsigned char *src = (const unsigned char *)ptr;
    uint64_t word = 0;
    memcpy(&word, src, sizeof word);
    return word;
}
/*
 * Accumulate the positional popcount of one chunk of `len` 16-bit words into
 * out[0..15].
 *
 * `len` must be a multiple of 8 and at most 4088 (0x0FFF & ~7). Each per-bit
 * total lives in a 12-bit field of a packed counter and grows by at most 8
 * per iteration, so 4088 words keep every field <= 4088 < 4096 and no field
 * carries into its neighbour.
 */
static void pospopcnt_u16_chunk_umul128 (const uint16_t* in, uint32_t len, uint32_t* out)
{
    const uint64_t mask_a = UINT64_C(0x1111111111111111); // bits 0,4,8,12 of each 16-bit lane
    const uint64_t mask_b = mask_a + mask_a;              // bits 1,5,9,13
    const uint64_t mask_c = mask_b + mask_b;              // bits 2,6,10,14
    const uint64_t mask_0001 = UINT64_C(0x0001000100010001);
    // Keeps nibbles 0 and 3 of lane 0, nibble 2 of lane 1 and nibble 1 of
    // lane 2. For counter_a these are the totals for bits 0, 12, 8 and 4,
    // placed at 12-bit offsets 0, 12, 24 and 36.
    const uint64_t mask_cnts = UINT64_C(0x000000F00F00F00F);
    uint64_t counter_a = 0; // 4 packed 12-bit counters: bits 0, 12, 8, 4
    uint64_t counter_b = 0; // bits 1, 13, 9, 5
    uint64_t counter_c = 0; // bits 2, 14, 10, 6
    uint64_t counter_d = 0; // bits 3, 15, 11, 7

    for (const uint16_t* end = &in[len]; in != end; in += 8) {
        // Two 64-bit loads cover 8 words, one word per 16-bit lane.
        uint64_t v0 = loadu_u64(&in[0]);
        uint64_t v1 = loadu_u64(&in[4]);
        // Move each bit group to bit 0 of its nibble and add the two loads:
        // every nibble now holds a count in 0..2.
        uint64_t a = (v0 & mask_a) + (v1 & mask_a);
        uint64_t b = ((v0 & mask_b) + (v1 & mask_b)) >> 1;
        uint64_t c = ((v0 & mask_c) + (v1 & mask_c)) >> 2;
        uint64_t d = ((v0 >> 3) & mask_a) + ((v1 >> 3) & mask_a);
        uint64_t hi;
        // low + high half of x * 0x0001000100010001 (full 128-bit product)
        // leaves the sum of all four 16-bit lanes in every lane. Nibbles
        // reach at most 8, so nothing carries between them.
        a = umul128(a, mask_0001, &hi);
        a += hi;
        b = umul128(b, mask_0001, &hi);
        b += hi;
        c = umul128(c, mask_0001, &hi);
        c += hi;
        d = umul128(d, mask_0001, &hi);
        d += hi;
        counter_a += a & mask_cnts;
        counter_b += b & mask_cnts;
        counter_c += c & mask_cnts;
        counter_d += d & mask_cnts;
    }
    out[0] += counter_a & 0x0FFF;
    out[1] += counter_b & 0x0FFF;
    out[2] += counter_c & 0x0FFF;
    out[3] += counter_d & 0x0FFF;
    out[4] += (counter_a >> 36); // top field: nothing above bit 47 is ever set
    out[5] += (counter_b >> 36);
    out[6] += (counter_c >> 36);
    out[7] += (counter_d >> 36);
    out[8] += (counter_a >> 24) & 0x0FFF;
    out[9] += (counter_b >> 24) & 0x0FFF;
    out[10] += (counter_c >> 24) & 0x0FFF;
    out[11] += (counter_d >> 24) & 0x0FFF;
    out[12] += (counter_a >> 12) & 0x0FFF;
    out[13] += (counter_b >> 12) & 0x0FFF;
    out[14] += (counter_c >> 12) & 0x0FFF;
    out[15] += (counter_d >> 12) & 0x0FFF;
}

/*
 * Accumulate the positional popcount of a short run of n 16-bit words
 * (n <= 255, because the counters are 8 bits wide) into out[0..15].
 * The entry point only calls this with 1..7 words.
 *
 * Each word's bits are spread one per byte (emulating pdep). Multiplying the
 * even bits (x & 0x5555) by `magic` adds copies shifted by 0, 14, 28 and 42.
 * Bits 0, 8, 2, 10, 4, 12, 6, 14 therefore land on bit 0 of bytes 0..7, and
 * mask_01 keeps only those. Overlapping copies collide only at bits 14, 28
 * and 42, which carry into unmasked odd bits. The odd bits get the same
 * treatment after a shift right by one.
 */
static void pospopcnt_u16_tail_scalar (const uint16_t* in, uint32_t n, uint32_t* out)
{
    // Source bit (even) that ends up in byte k of tail_counter_a; byte k of
    // tail_counter_b holds that bit + 1.
    static const unsigned char even_bit_of_byte[8] = { 0, 8, 2, 10, 4, 12, 6, 14 };
    const uint64_t mask_01 = UINT64_C(0x0101010101010101);
    const uint64_t magic = UINT64_C(0x0000040010004001); // 1+(1<<14)+(1<<28)+(1<<42)
    uint64_t tail_counter_a = 0; // even bits, one 8-bit counter per byte
    uint64_t tail_counter_b = 0; // odd bits, one 8-bit counter per byte
    for (uint32_t i = 0; i < n; i++) {
        uint64_t x = in[i];
        tail_counter_a += ((x & 0x5555) * magic) & mask_01;
        tail_counter_b += (((x >> 1) & 0x5555) * magic) & mask_01;
    }
    for (unsigned k = 0; k < 8; k++) {
        unsigned bit = even_bit_of_byte[k];
        out[bit]     += (uint32_t)((tail_counter_a >> (8 * k)) & 0xFF);
        out[bit + 1] += (uint32_t)((tail_counter_b >> (8 * k)) & 0xFF);
    }
}

/*
 * Positional population count over 16-bit words. For each bit position b in
 * 0..15, out[b] receives the number of elements of in[0..n-1] with bit b set.
 *
 * in  : n 16-bit words. Groups of 8 are read through loadu_u64 (memcpy).
 * n   : number of 16-bit elements, not bytes.
 * out : at least 16 counters. They are zeroed first, so prior contents are
 *       overwritten, not accumulated into.
 *
 * Most of the input is handled 8 words at a time with packed 12-bit counters
 * (pospopcnt_u16_chunk_umul128). The final n % 8 words go through
 * pospopcnt_u16_tail_scalar.
 */
void pospopcnt_u16_scalar_umul128 (const uint16_t* in, uint32_t n, uint32_t* out)
{
    memset(out, 0, 16*sizeof(uint32_t));
    while (n >= 8) {
        // Largest multiple of 8 that is <= min(n, 0x0FFF). This ends the
        // chunk before its 12-bit packed counters could overflow.
        uint32_t len = ((n < 0x0FFF) ? n : 0x0FFF) & ~UINT32_C(7);
        pospopcnt_u16_chunk_umul128(in, len, out);
        in += len;
        n -= len;
    }
    // Fewer than 8 words remain here.
    if (n != 0) {
        pospopcnt_u16_tail_scalar(in, n, out);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment