// See companion blog post <https://mazzo.li/posts/crc-tips.html> for
// extended commentary.
//
// Compiled with
//
// % clang++ --version
// clang version 17.0.6
// % clang++ -std=c++20 -march=raptorlake -Wall crc32c-tips.cpp -o crc32c-tips
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h> // for exit(), used by the CHECK macro
#include <nmmintrin.h>
#include <wmmintrin.h>
#include <vector>
static const uint32_t CASTAGNOLI_POLY = 0x82F63B78u;
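// All polynomial arithmetic below is in the bit-reflected representation
// used by the crc32 instruction: bit 31 holds the coefficient of x^0 and
// bit 0 the coefficient of x^31, so x^0 is 1u<<31 and x^1 is 1u<<30. The
// constant above is the Castagnoli polynomial 0x1EDC6F41 with its 32 bits
// reversed (the x^32 term is implicit).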
uint32_t crc32c_u8(uint32_t crc, uint8_t a) {
  crc ^= a;
  for (int i = 0; i < 8; i++) {
    crc = (crc >> 1) ^ ((crc&1) ? CASTAGNOLI_POLY : 0);
  }
  return crc;
}
uint32_t crc32c_reference(uint32_t crc, size_t size, const char* data) {
  crc = ~crc;
  for (size_t i = 0; i < size; i++) {
    crc = crc32c_u8(crc, data[i]);
  }
  return ~crc;
}
// See <https://www.corsix.org/content/fast-crc32c-4k>
uint32_t crc32_4k_fusion(uint32_t acc_a, const char* buf, size_t n_blocks) {
  size_t stride = n_blocks * 24 + 8;
  // Four chunks:
  //   Chunk A: 0 through stride
  //   Chunk B: stride through stride*2
  //   Chunk C: stride*2 through stride*3-8
  //   Chunk D: stride*3-8 through n_blocks*136+16
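  // Chunks A/B/C are consumed with three interleaved scalar crc32 streams
  // (the instruction has 3-cycle latency but 1/cycle throughput on recent
  // Intel cores), while D is folded with pclmulqdq on a separate execution
  // port; running both engines in parallel is what the "fusion" refers to.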
  // First block of 64 from D is easy.
  const char* buf2 = buf + n_blocks * 72 + 16;
  __m128i x1 = _mm_loadu_si128((__m128i*)buf2);
  __m128i x2 = _mm_loadu_si128((__m128i*)(buf2 + 16));
  __m128i x3 = _mm_loadu_si128((__m128i*)(buf2 + 32));
  __m128i x4 = _mm_loadu_si128((__m128i*)(buf2 + 48));
  uint32_t acc_b = 0;
  uint32_t acc_c = 0;
  // Parallel fold remaining blocks of 64 from D, and 24 from each of A/B/C.
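  // Here magic(n) is x^n mod P in the reflected encoding (the -1 in the
  // exponents compensates for the bit-reflected clmul); see the corsix.org
  // article above for the exact derivation of the constants.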
  // k1 == magic(4*128+32-1)
  // k2 == magic(4*128-32-1)
  __m128i k1k2 = _mm_setr_epi32(/*k1*/ 0x740EEF02, 0, /*k2*/ 0x9E4ADDF8, 0);
  const char* end = buf + (n_blocks * 136 + 16) - 64;
  while (buf2 < end) {
    acc_a = _mm_crc32_u64(acc_a, *(uint64_t*)buf);
    __m128i x5 = _mm_clmulepi64_si128(x1, k1k2, 0x00);
    acc_b = _mm_crc32_u64(acc_b, *(uint64_t*)(buf + stride));
    x1 = _mm_clmulepi64_si128(x1, k1k2, 0x11);
    acc_c = _mm_crc32_u64(acc_c, *(uint64_t*)(buf + stride*2));
    __m128i x6 = _mm_clmulepi64_si128(x2, k1k2, 0x00);
    acc_a = _mm_crc32_u64(acc_a, *(uint64_t*)(buf + 8));
    x2 = _mm_clmulepi64_si128(x2, k1k2, 0x11);
    acc_b = _mm_crc32_u64(acc_b, *(uint64_t*)(buf + stride + 8));
    __m128i x7 = _mm_clmulepi64_si128(x3, k1k2, 0x00);
    acc_c = _mm_crc32_u64(acc_c, *(uint64_t*)(buf + stride*2 + 8));
    x3 = _mm_clmulepi64_si128(x3, k1k2, 0x11);
    acc_a = _mm_crc32_u64(acc_a, *(uint64_t*)(buf + 16));
    __m128i x8 = _mm_clmulepi64_si128(x4, k1k2, 0x00);
    acc_b = _mm_crc32_u64(acc_b, *(uint64_t*)(buf + stride + 16));
    x4 = _mm_clmulepi64_si128(x4, k1k2, 0x11);
    acc_c = _mm_crc32_u64(acc_c, *(uint64_t*)(buf + stride*2 + 16));
    x5 = _mm_xor_si128(x5, _mm_loadu_si128((__m128i*)(buf2 + 64)));
    x1 = _mm_xor_si128(x1, x5);
    x6 = _mm_xor_si128(x6, _mm_loadu_si128((__m128i*)(buf2 + 80)));
    x2 = _mm_xor_si128(x2, x6);
    x7 = _mm_xor_si128(x7, _mm_loadu_si128((__m128i*)(buf2 + 96)));
    x3 = _mm_xor_si128(x3, x7);
    x8 = _mm_xor_si128(x8, _mm_loadu_si128((__m128i*)(buf2 + 112)));
    x4 = _mm_xor_si128(x4, x8);
    buf2 += 64;
    buf += 24;
  }
  // Next 24 bytes from A/B/C, and 8 more from A/B, then merge A/B/C.
  // Meanwhile, fold together D's four parallel streams.
  // k3 == magic(128+32-1)
  // k4 == magic(128-32-1)
  __m128i k3k4 = _mm_setr_epi32(/*k3*/ 0xF20C0DFE, 0, /*k4*/ 0x493C7D27, 0);
  acc_a = _mm_crc32_u64(acc_a, *(uint64_t*)buf);
  __m128i x5 = _mm_clmulepi64_si128(x1, k3k4, 0x00);
  acc_b = _mm_crc32_u64(acc_b, *(uint64_t*)(buf + stride));
  x1 = _mm_clmulepi64_si128(x1, k3k4, 0x11);
  acc_c = _mm_crc32_u64(acc_c, *(uint64_t*)(buf + stride*2));
  __m128i x6 = _mm_clmulepi64_si128(x3, k3k4, 0x00);
  acc_a = _mm_crc32_u64(acc_a, *(uint64_t*)(buf + 8));
  x3 = _mm_clmulepi64_si128(x3, k3k4, 0x11);
  acc_b = _mm_crc32_u64(acc_b, *(uint64_t*)(buf + stride + 8));
  acc_c = _mm_crc32_u64(acc_c, *(uint64_t*)(buf + stride*2 + 8));
  acc_a = _mm_crc32_u64(acc_a, *(uint64_t*)(buf + 16));
  acc_b = _mm_crc32_u64(acc_b, *(uint64_t*)(buf + stride + 16));
  x5 = _mm_xor_si128(x5, x2);
  acc_c = _mm_crc32_u64(acc_c, *(uint64_t*)(buf + stride*2 + 16));
  x1 = _mm_xor_si128(x1, x5);
  acc_a = _mm_crc32_u64(acc_a, *(uint64_t*)(buf + 24));
  // k5 == magic(2*128+32-1)
  // k6 == magic(2*128-32-1)
  __m128i k5k6 = _mm_setr_epi32(/*k5*/ 0x3DA6D0CB, 0, /*k6*/ 0xBA4FC28E, 0);
  x6 = _mm_xor_si128(x6, x4);
  x3 = _mm_xor_si128(x3, x6);
  x5 = _mm_clmulepi64_si128(x1, k5k6, 0x00);
  acc_b = _mm_crc32_u64(acc_b, *(uint64_t*)(buf + stride + 24));
  x1 = _mm_clmulepi64_si128(x1, k5k6, 0x11);
  // Compute the magic numbers which depend upon n_blocks
  // (required for merging A/B/C/D).
  uint64_t bits_c = n_blocks*64 - 33;
  uint64_t bits_b = bits_c + stride - 8;
  uint64_t bits_a = bits_b + stride;
  uint64_t stack_a = ~(uint64_t)8;
  uint64_t stack_b = stack_a;
  uint64_t stack_c = stack_a;
  while (bits_a > 191) {
    stack_a = (stack_a << 1) + (bits_a & 1); bits_a = (bits_a >> 1) - 16;
    stack_b = (stack_b << 1) + (bits_b & 1); bits_b = (bits_b >> 1) - 16;
    stack_c = (stack_c << 1) + (bits_c & 1); bits_c = (bits_c >> 1) - 16;
  }
  stack_a = ~stack_a;
  stack_b = ~stack_b;
  stack_c = ~stack_c;
  uint32_t magic_a = ((uint32_t)0x80000000) >> (bits_a & 31); bits_a >>= 5;
  uint32_t magic_b = ((uint32_t)0x80000000) >> (bits_b & 31); bits_b >>= 5;
  uint32_t magic_c = ((uint32_t)0x80000000) >> (bits_c & 31); bits_c >>= 5;
  bits_a -= bits_b;
  bits_b -= bits_c;
  for (; bits_c; --bits_c) {
    magic_a = _mm_crc32_u32(magic_a, 0);
    magic_b = _mm_crc32_u32(magic_b, 0);
    magic_c = _mm_crc32_u32(magic_c, 0);
  }
  for (; bits_b; --bits_b) {
    magic_a = _mm_crc32_u32(magic_a, 0);
    magic_b = _mm_crc32_u32(magic_b, 0);
  }
  for (; bits_a; --bits_a) {
    magic_a = _mm_crc32_u32(magic_a, 0);
  }
  for (;;) {
    uint32_t low = stack_a & 1;
    if (!(stack_a >>= 1)) break;
    __m128i x = _mm_cvtsi32_si128(magic_a);
    uint64_t y = _mm_cvtsi128_si64(_mm_clmulepi64_si128(x, x, 0));
    magic_a = _mm_crc32_u64(0, y << low);
    x = _mm_cvtsi32_si128(magic_c);
    y = _mm_cvtsi128_si64(_mm_clmulepi64_si128(x, x, 0));
    magic_c = _mm_crc32_u64(0, y << (stack_c & 1));
    stack_c >>= 1;
    x = _mm_cvtsi32_si128(magic_b);
    y = _mm_cvtsi128_si64(_mm_clmulepi64_si128(x, x, 0));
    magic_b = _mm_crc32_u64(0, y << (stack_b & 1));
    stack_b >>= 1;
  }
  __m128i vec_c = _mm_clmulepi64_si128(_mm_cvtsi32_si128(acc_c), _mm_cvtsi32_si128(magic_c), 0x00);
  __m128i vec_a = _mm_clmulepi64_si128(_mm_cvtsi32_si128(acc_a), _mm_cvtsi32_si128(magic_a), 0x00);
  __m128i vec_b = _mm_clmulepi64_si128(_mm_cvtsi32_si128(acc_b), _mm_cvtsi32_si128(magic_b), 0x00);
  x5 = _mm_xor_si128(x5, x3);
  x1 = _mm_xor_si128(x1, x5);
  uint64_t abc = _mm_cvtsi128_si64(_mm_xor_si128(_mm_xor_si128(vec_c, vec_a), vec_b));
  // Apply missing <<32 and fold down to 32 bits.
  uint32_t crc = _mm_crc32_u64(0, _mm_extract_epi64(x1, 0));
  crc = _mm_crc32_u64(crc, abc ^ _mm_extract_epi64(x1, 1));
  return crc;
}
template<bool CLMUL>
uint32_t crc32c(uint32_t crc, size_t size, const char* buf);
template<>
uint32_t crc32c<true>(uint32_t crc, size_t size, const char* buf) {
  crc = ~crc;
  if (size >= 31) {
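    // Carve the buffer into a scalar head (sized so that the kernel's SIMD
    // loads end up 16-byte aligned), a fusion kernel of n_blocks*136+16
    // bytes, and a scalar tail; n_blocks is shrunk by one if the alignment
    // padding would push the kernel past the end of the buffer.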
    size_t n_blocks = (size - 16) / 136;
    size_t kernel_length = n_blocks * 136 + 16;
    if (kernel_length + (-((uintptr_t)buf + n_blocks * 8) & 15) > size) {
      n_blocks -= 1;
      kernel_length -= 136;
    }
    const char* kernel_end = (const char*)((uintptr_t)(buf + kernel_length + 15) & ~(uintptr_t)15);
    const char* kernel_start = kernel_end - kernel_length;
    size -= kernel_start - buf;
    for (; buf != kernel_start; ++buf) {
      crc = _mm_crc32_u8(crc, *(const uint8_t*)buf);
    }
    if (n_blocks) {
      size -= kernel_length;
      crc = crc32_4k_fusion(crc, buf, n_blocks);
      buf = kernel_end;
    }
  }
  for (; size >= 8; size -= 8, buf += 8) {
    crc = _mm_crc32_u64(crc, *(const uint64_t*)buf);
  }
  for (; size; --size, ++buf) {
    crc = _mm_crc32_u8(crc, *(const uint8_t*)buf);
  }
  return ~crc;
}
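// "Non-CLMUL" fallback: it still uses the SSE4.2 crc32 instruction, just
// not pclmulqdq, so it processes a single 8-bytes-per-instruction stream.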
template<>
uint32_t crc32c<false>(uint32_t crc, size_t size, const char* buf) {
  crc = ~crc;
  const char* end = buf+size;
  // Mask with ~(uintptr_t)7, not ~7u: the latter zero-extends to 64 bits
  // and would clear the high half of the pointer. Clamp to `end` so short
  // unaligned buffers aren't read past their end.
  const char* unaligned_end = (const char*)(((uintptr_t)buf - 1 + 8) & ~(uintptr_t)7);
  if (unaligned_end > end) unaligned_end = end;
  for (; buf < unaligned_end; buf++) {
    crc = _mm_crc32_u8(crc, *(const uint8_t*)buf);
  }
  for (; buf+8 <= end; buf += 8) {
    crc = _mm_crc32_u64(crc, *(const uint64_t*)buf);
  }
  for (; buf < end; buf++) {
    crc = _mm_crc32_u8(crc, *(const uint8_t*)buf);
  }
  return ~crc;
}
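// Carry-less multiply of two polynomials in the reflected representation
// (x^0 in bit 31 of each input). The 64-bit product is returned with x^0 in
// bit 63, which is what crc32c_mod_p below expects.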
template<bool CLMUL>
uint64_t crc32c_mul(uint32_t a, uint32_t b);
template<>
uint64_t crc32c_mul<false>(uint32_t a, uint32_t b32) {
  uint64_t b = (uint64_t)b32 << 32;
  uint64_t c = 0;
  for (int i = 0; i < 32; i++, a <<= 1, b >>= 1) {
    c ^= (a & (1u<<31)) ? b : 0;
  }
  return c;
}
template<>
uint64_t crc32c_mul<true>(uint32_t a, uint32_t b) {
  uint64_t c = _mm_cvtsi128_si64(
    _mm_clmulepi64_si128(_mm_set_epi32(0, 0, 0, a), _mm_set_epi32(0, 0, 0, b), 0)
  );
  // The reflected product of two 32-bit polynomials only occupies bits
  // 0..62; shift by one so x^0 lands in bit 63, as in crc32c_mul<false>.
  return c << 1;
}
// `a mod P`: `a` carries x^0 in bit 63 (as produced by crc32c_mul), the
// result carries x^0 in bit 31.
uint32_t crc32c_mod_p(uint64_t a) {
  for (int i = 0; i < 32; i++) {
    a = (a >> 1) ^ ((a&1) ? CASTAGNOLI_POLY : 0);
  }
  return (uint32_t)a;
}
// crc32c_mul_mod_p<false> could also be implemented as a single loop.
template<bool CLMUL>
uint32_t crc32c_mul_mod_p(uint32_t a, uint32_t b) {
  return crc32c_mod_p(crc32c_mul<CLMUL>(a, b));
}
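// For example, x^8 is 0x00800000 and x^16 is 0x00008000 in the reflected
// encoding (see CRC_POWER_TABLE below), and indeed
// crc32c_mul_mod_p<true>(0x00800000, 0x00800000) == 0x00008000.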
void crc32c_generate_power_table() {
  printf("// x^2^3, x^2^4, ..., x^2^66\n");
  printf("static uint32_t CRC_POWER_TABLE[64] = {");
  uint32_t p = 1u << 30; // start with x
  for (int i = 0; i < 67; i++) {
    if ((i-3) % 8 == 0) {
      printf("\n  ");
    }
    if (i > 2) {
      printf("0x%08x, ", p);
    }
    p = crc32c_mul_mod_p<true>(p, p);
  }
  printf("\n};\n\n");
}
// x^2^3, x^2^4, ..., x^2^66
static uint32_t CRC_POWER_TABLE[64] = {
  0x00800000, 0x00008000, 0x82f63b78, 0x6ea2d55c, 0x18b8ea18, 0x510ac59a, 0xb82be955, 0xb8fdb1e7,
  0x88e56f72, 0x74c360a4, 0xe4172b16, 0x0d65762a, 0x35d73a62, 0x28461564, 0xbf455269, 0xe2ea32dc,
  0xfe7740e6, 0xf946610b, 0x3c204f8f, 0x538586e3, 0x59726915, 0x734d5309, 0xbc1ac763, 0x7d0722cc,
  0xd289cabe, 0xe94ca9bc, 0x05b74f3f, 0xa51e1f42, 0x40000000, 0x20000000, 0x08000000, 0x00800000,
  0x00008000, 0x82f63b78, 0x6ea2d55c, 0x18b8ea18, 0x510ac59a, 0xb82be955, 0xb8fdb1e7, 0x88e56f72,
  0x74c360a4, 0xe4172b16, 0x0d65762a, 0x35d73a62, 0x28461564, 0xbf455269, 0xe2ea32dc, 0xfe7740e6,
  0xf946610b, 0x3c204f8f, 0x538586e3, 0x59726915, 0x734d5309, 0xbc1ac763, 0x7d0722cc, 0xd289cabe,
  0xe94ca9bc, 0x05b74f3f, 0xa51e1f42, 0x40000000, 0x20000000, 0x08000000, 0x00800000, 0x00008000,
};
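// Entry k of the table is x^(2^(k+3)) = x^(8*2^k), so bit k of a byte count
// n contributes x^(8*2^k), and the product over the set bits of n is
// x^(8n). In other words, crc32c_x_pow_n(n) returns x raised to n *bytes*.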
template<bool CLMUL>
uint32_t crc32c_x_pow_n(size_t n) {
  uint32_t x_pow_n = 1u << 31; // x^0 == 1
  for (int k = 0; n; k++, n >>= 1) {
    if (n&1) {
      x_pow_n = crc32c_mul_mod_p<CLMUL>(x_pow_n, CRC_POWER_TABLE[k]);
    }
  }
  return x_pow_n;
}
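// Appending n zero bytes multiplies the raw (pre-inversion) remainder by
// x^(8n) mod P; the ~ on the way in and out undoes the usual CRC32C
// complement conditioning.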
template<bool CLMUL>
uint32_t crc32c_add_zeros(uint32_t crc, size_t size) {
  return ~crc32c_mul_mod_p<CLMUL>(~crc, crc32c_x_pow_n<CLMUL>(size));
}
void crc32c_generate_inverse_power_table() {
  // Find x^-1 by brute force: the p such that x*p == 1.
  uint32_t p;
  for (uint64_t i = 0; i < 1ull<<32; i++) {
    p = (uint32_t)i;
    if (crc32c_mul_mod_p<true>(1u<<30, p) == 1u<<31) {
      break;
    }
  }
  printf("// x^-(2^3), x^-(2^4), ..., x^-(2^66)\n");
  printf("static uint32_t CRC_INVERSE_POWER_TABLE[64] = {");
  for (int i = 0; i < 67; i++) {
    if ((i-3) % 8 == 0) {
      printf("\n  ");
    }
    if (i > 2) {
      printf("0x%08x, ", p);
    }
    p = crc32c_mul_mod_p<true>(p, p);
  }
  printf("\n};\n\n");
}
// x^-(2^3), x^-(2^4), ..., x^-(2^66)
static uint32_t CRC_INVERSE_POWER_TABLE[64] = {
  0xfde39562, 0xbef0965e, 0xd610d67e, 0xe67cce65, 0xa268b79e, 0x134fb088, 0x32998d96, 0xcedac2cc,
  0x70118575, 0x0e004a40, 0xa7864c8b, 0xbc7be916, 0x10ba2894, 0x6077197b, 0x98448e4e, 0x8baf845d,
  0xe93e07fc, 0xf58027d7, 0x5e2b422d, 0x9db2851c, 0x9270ed25, 0x5984e7b3, 0x7af026f1, 0xe0f4116b,
  0xace8a6b0, 0x9e09f006, 0x6a60ea71, 0x4fd04875, 0x05ec76f1, 0x0bd8ede2, 0x2f63b788, 0xfde39562,
  0xbef0965e, 0xd610d67e, 0xe67cce65, 0xa268b79e, 0x134fb088, 0x32998d96, 0xcedac2cc, 0x70118575,
  0x0e004a40, 0xa7864c8b, 0xbc7be916, 0x10ba2894, 0x6077197b, 0x98448e4e, 0x8baf845d, 0xe93e07fc,
  0xf58027d7, 0x5e2b422d, 0x9db2851c, 0x9270ed25, 0x5984e7b3, 0x7af026f1, 0xe0f4116b, 0xace8a6b0,
  0x9e09f006, 0x6a60ea71, 0x4fd04875, 0x05ec76f1, 0x0bd8ede2, 0x2f63b788, 0xfde39562, 0xbef0965e,
};
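// Same indexing as CRC_POWER_TABLE, but entry k is x^-(8*2^k), so this
// returns x^(-8n) for a byte count n.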
template<bool CLMUL>
uint32_t crc32c_x_pow_neg_n(size_t n) {
  uint32_t x_pow_n = 1u << 31;
  for (int k = 0; n; k++, n >>= 1) {
    if (n&1) {
      x_pow_n = crc32c_mul_mod_p<CLMUL>(x_pow_n, CRC_INVERSE_POWER_TABLE[k]);
    }
  }
  return x_pow_n;
}
template<bool CLMUL>
uint32_t crc32c_remove_zeros(uint32_t crc, size_t size) {
  return ~crc32c_mul_mod_p<CLMUL>(~crc, crc32c_x_pow_neg_n<CLMUL>(size));
}
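// CRC32C is affine rather than linear over GF(2) because of the complement
// conditioning: crc(a ^ b) == crc(a) ^ crc(b) ^ crc(zeros of the same
// length), and crc32c_add_zeros<CLMUL>(0, size) is exactly the CRC of
// `size` zero bytes.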
template<bool CLMUL>
uint32_t crc32c_xor(size_t size, uint32_t crc_a, uint32_t crc_b) {
  return crc_a ^ crc_b ^ crc32c_add_zeros<CLMUL>(0, size);
}
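// The complement conditioning cancels when concatenating:
// crc(A || B) == crc(A) * x^(8*|B|) ^ crc(B) (mod P), the same identity
// behind zlib's crc32_combine.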
template<bool CLMUL>
uint32_t crc32c_append(uint32_t crc_a, size_t size_b, uint32_t crc_b) {
  return crc32c_mul_mod_p<CLMUL>(crc_a, crc32c_x_pow_n<CLMUL>(size_b)) ^ crc_b;
}
// Tests
static inline uint64_t wyhash64(uint64_t& state) {
  state += UINT64_C(0x60bee2bee120fc15);
  __uint128_t tmp;
  tmp = (__uint128_t)state * UINT64_C(0xa3b195354a39b70d);
  uint64_t m1 = (tmp >> 64) ^ tmp;
  tmp = (__uint128_t)m1 * UINT64_C(0x1b03738712fad5c9);
  uint64_t m2 = (tmp >> 64) ^ tmp;
  return m2;
}
static void wyhash64_bytes(uint64_t& state, size_t size, char* bytes) {
  char* end = bytes+size;
  // As in crc32c<false>: mask with ~(uintptr_t)7 (not ~7u, which would
  // clear the pointer's high bits), and clamp so short buffers aren't
  // overrun.
  char* unaligned_end = (char*)(((uintptr_t)bytes - 1 + 8) & ~(uintptr_t)7);
  if (unaligned_end > end) unaligned_end = end;
  for (; bytes < unaligned_end; bytes++) {
    *bytes = wyhash64(state) & 0xFF;
  }
  uint64_t* words = (uint64_t*)bytes;
  for (; (char*)(words + 1) <= end; words++) {
    *words = wyhash64(state);
  }
  for (bytes = (char*)words; bytes < end; bytes++) {
    *bytes = wyhash64(state) & 0xFF;
  }
}
#define CHECK(__b) do { \
    if (!(__b)) { \
      printf("%s:%d assertion failure: " #__b "\n", __FILE__, __LINE__); \
      exit(1); \
    } \
  } while (0)
int main(void) {
  crc32c_generate_power_table();
  crc32c_generate_inverse_power_table();
  uint64_t seed = 0;
  // Test CLMUL vs. non-CLMUL vs. reference
  for (int i = 0; i < 1000; i++) {
    size_t len = wyhash64(seed)%(4096*10);
    std::vector<char> buf(len);
    wyhash64_bytes(seed, len, buf.data());
    CHECK(crc32c<true>(0, len, buf.data()) == crc32c<false>(0, len, buf.data()));
    CHECK(crc32c<true>(0, len, buf.data()) == crc32c_reference(0, len, buf.data()));
  }
  for (int i = 0; i < 1000; i++) {
    uint32_t a = wyhash64(seed);
    uint32_t b = wyhash64(seed);
    CHECK(crc32c_mul<true>(a, b) == crc32c_mul<false>(a, b));
  }
  for (int i = 0; i < 1000; i++) {
    uint32_t a = wyhash64(seed);
    uint32_t b = wyhash64(seed);
    CHECK(crc32c_mul_mod_p<true>(a, b) == crc32c_mul_mod_p<false>(a, b));
  }
  // crc32c_add_zeros
  for (int i = 0; i < 1000; i++) {
    size_t len1 = wyhash64(seed)%(4096*10);
    size_t len2 = wyhash64(seed)%(4096*10);
    std::vector<char> buf(len1);
    wyhash64_bytes(seed, len1, buf.data());
    std::vector<char> zeros(len2, 0);
    uint32_t buf_crc = crc32c<true>(0, len1, buf.data());
    CHECK(crc32c<true>(buf_crc, len2, zeros.data()) == crc32c_add_zeros<true>(buf_crc, len2));
  }
  // crc32c_remove_zeros
  for (int i = 0; i < 1000; i++) {
    size_t len1 = wyhash64(seed)%(4096*10);
    size_t len2 = wyhash64(seed)%(4096*10);
    std::vector<char> buf(len1);
    wyhash64_bytes(seed, len1, buf.data());
    std::vector<char> zeros(len2, 0);
    uint32_t buf_crc = crc32c<true>(0, len1, buf.data());
    uint32_t crc = crc32c<true>(buf_crc, len2, zeros.data());
    CHECK(buf_crc == crc32c_remove_zeros<true>(crc, len2));
  }
  // crc32c_xor
  for (int i = 0; i < 1000; i++) {
    size_t len = wyhash64(seed)%(4096*10);
    std::vector<char> buf1(len);
    wyhash64_bytes(seed, len, buf1.data());
    std::vector<char> buf2(len);
    wyhash64_bytes(seed, len, buf2.data());
    std::vector<char> buf3(len);
    for (size_t j = 0; j < len; j++) {
      buf3[j] = buf1[j] ^ buf2[j];
    }
    CHECK(crc32c<true>(0, len, buf3.data()) == crc32c_xor<true>(len, crc32c<true>(0, len, buf1.data()), crc32c<true>(0, len, buf2.data())));
  }
  // crc32c_append
  for (int i = 0; i < 1000; i++) {
    size_t len1 = wyhash64(seed)%(4096*10);
    size_t len2 = wyhash64(seed)%(4096*10);
    std::vector<char> buf1(len1);
    wyhash64_bytes(seed, len1, buf1.data());
    std::vector<char> buf2(len2);
    wyhash64_bytes(seed, len2, buf2.data());
    std::vector<char> buf3 = buf1;
    buf3.insert(buf3.end(), buf2.begin(), buf2.end());
    CHECK(crc32c<true>(0, len1+len2, buf3.data()) == crc32c_append<true>(crc32c<true>(0, len1, buf1.data()), len2, crc32c<true>(0, len2, buf2.data())));
  }
  printf("All tests pass!\n");
  return 0;
}