@mmozeiko, created November 14, 2018
Go's AES-NI based hash
#include <stddef.h>
#include <stdint.h>
#include <intrin.h>
// See aeshashbody in https://github.com/golang/go/blob/master/src/runtime/asm_amd64.s
// for the original implementation. The key schedule below must be filled with
// random bytes from the system once at process startup, as the Go runtime does.
static __declspec(align(16)) uint8_t aeskeysched[128];
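
// One possible way to fill aeskeysched at startup: a minimal sketch assuming
// Windows and the system-preferred RNG. go_aeshash_init is a hypothetical
// helper (not part of the original gist); link against bcrypt.lib.
#include <windows.h>
#include <bcrypt.h>
#pragma comment(lib, "bcrypt.lib")

static void go_aeshash_init(void)
{
    // with a NULL algorithm handle, BCRYPT_USE_SYSTEM_PREFERRED_RNG draws
    // directly from the OS RNG; error handling omitted for brevity
    BCryptGenRandom(NULL, aeskeysched, (ULONG)sizeof(aeskeysched),
                    BCRYPT_USE_SYSTEM_PREFERRED_RNG);
}

// masks[n] keeps the first n bytes of a 16-byte register and zeroes the rest,
// so a short input can be fetched with one full unaligned load and trimmed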
static __declspec(align(16)) const uint8_t masks[16][16] =
{
    { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00 },
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00 },
};
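
// shifts[n] is a pshufb control that moves the last n bytes of a register to
// the front (0xff lanes become zero); used when the input sits so close to a
// page end that a full 16-byte load from its start could fault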
static __declspec(align(16)) const uint8_t shifts[16][16] =
{
    { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff },
    { 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff },
    { 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff },
    { 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff },
    { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff },
};
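
// 64-bit hash of data[0..size) mixed with seed, mirroring Go's runtime
// aeshash; requires AES-NI and SSSE3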
static uint64_t go_aeshash(uint64_t seed, const void* data, size_t size)
{
    const uint8_t* bytes = data;
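    // pack the state: low 8 bytes hold the seed, the high four 16-bit lanes
    // hold the low 16 bits of size; then mix in the per-process key and scramble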
    __m128i x0 = _mm_cvtsi64_si128(seed);
    x0 = _mm_insert_epi16(x0, (uint16_t)size, 4);
    x0 = _mm_shufflehi_epi16(x0, 0);
    __m128i x1 = x0;
    x0 = _mm_xor_si128(x0, ((__m128i*)&aeskeysched)[0]);
    x0 = _mm_aesenc_si128(x0, x0);
    if (size == 0)
    {
        x0 = _mm_aesenc_si128(x0, x0);
        return _mm_cvtsi128_si64(x0);
    }
    else if (size < 16)
    {
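        // a full 16-byte load starting at data is safe only if it cannot cross
        // into the next (possibly unmapped) page; otherwise load the 16 bytes
        // ending at data + size and shift the tail bytes to the front instead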
        if (((uintptr_t)data % 4096) <= 4096 - 16)
        {
            x1 = _mm_loadu_si128((__m128i*)bytes);
            x1 = _mm_and_si128(x1, ((__m128i*)masks)[size]);
        }
        else
        {
            x1 = _mm_loadu_si128((__m128i*)(bytes - 16 + size));
            x1 = _mm_shuffle_epi8(x1, ((__m128i*)shifts)[size]);
        }
        x1 = _mm_xor_si128(x1, x0);
        x1 = _mm_aesenc_si128(x1, x1);
        x1 = _mm_aesenc_si128(x1, x1);
        x1 = _mm_aesenc_si128(x1, x1);
        return _mm_cvtsi128_si64(x1);
    }
    else if (size == 16)
    {
        x1 = _mm_loadu_si128((__m128i*)bytes);
        x1 = _mm_xor_si128(x1, x0);
        x1 = _mm_aesenc_si128(x1, x1);
        x1 = _mm_aesenc_si128(x1, x1);
        x1 = _mm_aesenc_si128(x1, x1);
        return _mm_cvtsi128_si64(x1);
    }
    else if (size <= 32)
    {
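        // derive a second seed from the next 16 bytes of the key schedule,
        // then hash the first and last 16 bytes (they overlap when size < 32)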
        x1 = _mm_xor_si128(x1, ((__m128i*)&aeskeysched)[1]);
        x1 = _mm_aesenc_si128(x1, x1);
        __m128i x2 = _mm_loadu_si128((__m128i*)bytes);
        __m128i x3 = _mm_loadu_si128((__m128i*)(bytes + size - 16));
        x2 = _mm_xor_si128(x2, x0);
        x3 = _mm_xor_si128(x3, x1);
        x2 = _mm_aesenc_si128(x2, x2);
        x3 = _mm_aesenc_si128(x3, x3);
        x2 = _mm_aesenc_si128(x2, x2);
        x3 = _mm_aesenc_si128(x3, x3);
        x2 = _mm_aesenc_si128(x2, x2);
        x3 = _mm_aesenc_si128(x3, x3);
        x2 = _mm_xor_si128(x2, x3);
        return _mm_cvtsi128_si64(x2);
    }
    else if (size <= 64)
    {
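        // four seeds, four 16-byte blocks: the first 32 bytes and the last
        // 32 bytes, which overlap when size < 64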
        __m128i x2 = x1;
        __m128i x3 = x1;
        x1 = _mm_xor_si128(x1, ((__m128i*)&aeskeysched)[1]);
        x2 = _mm_xor_si128(x2, ((__m128i*)&aeskeysched)[2]);
        x3 = _mm_xor_si128(x3, ((__m128i*)&aeskeysched)[3]);
        x1 = _mm_aesenc_si128(x1, x1);
        x2 = _mm_aesenc_si128(x2, x2);
        x3 = _mm_aesenc_si128(x3, x3);
        __m128i x4 = _mm_loadu_si128((__m128i*)bytes);
        __m128i x5 = _mm_loadu_si128((__m128i*)(bytes + 16));
        __m128i x6 = _mm_loadu_si128((__m128i*)(bytes + size - 32));
        __m128i x7 = _mm_loadu_si128((__m128i*)(bytes + size - 16));
        x4 = _mm_xor_si128(x4, x0);
        x5 = _mm_xor_si128(x5, x1);
        x6 = _mm_xor_si128(x6, x2);
        x7 = _mm_xor_si128(x7, x3);
        x4 = _mm_aesenc_si128(x4, x4);
        x5 = _mm_aesenc_si128(x5, x5);
        x6 = _mm_aesenc_si128(x6, x6);
        x7 = _mm_aesenc_si128(x7, x7);
        x4 = _mm_aesenc_si128(x4, x4);
        x5 = _mm_aesenc_si128(x5, x5);
        x6 = _mm_aesenc_si128(x6, x6);
        x7 = _mm_aesenc_si128(x7, x7);
        x4 = _mm_aesenc_si128(x4, x4);
        x5 = _mm_aesenc_si128(x5, x5);
        x6 = _mm_aesenc_si128(x6, x6);
        x7 = _mm_aesenc_si128(x7, x7);
        x4 = _mm_xor_si128(x4, x6);
        x5 = _mm_xor_si128(x5, x7);
        x4 = _mm_xor_si128(x4, x5);
        return _mm_cvtsi128_si64(x4);
    }
    else if (size <= 128)
    {
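        // eight seeds, eight 16-byte blocks: the first 64 bytes and the last
        // 64 bytes, which overlap when size < 128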
        __m128i x2 = x1;
        __m128i x3 = x1;
        __m128i x4 = x1;
        __m128i x5 = x1;
        __m128i x6 = x1;
        __m128i x7 = x1;
        x1 = _mm_xor_si128(x1, ((__m128i*)&aeskeysched)[1]);
        x2 = _mm_xor_si128(x2, ((__m128i*)&aeskeysched)[2]);
        x3 = _mm_xor_si128(x3, ((__m128i*)&aeskeysched)[3]);
        x4 = _mm_xor_si128(x4, ((__m128i*)&aeskeysched)[4]);
        x5 = _mm_xor_si128(x5, ((__m128i*)&aeskeysched)[5]);
        x6 = _mm_xor_si128(x6, ((__m128i*)&aeskeysched)[6]);
        x7 = _mm_xor_si128(x7, ((__m128i*)&aeskeysched)[7]);
        x1 = _mm_aesenc_si128(x1, x1);
        x2 = _mm_aesenc_si128(x2, x2);
        x3 = _mm_aesenc_si128(x3, x3);
        x4 = _mm_aesenc_si128(x4, x4);
        x5 = _mm_aesenc_si128(x5, x5);
        x6 = _mm_aesenc_si128(x6, x6);
        x7 = _mm_aesenc_si128(x7, x7);
        __m128i x8 = _mm_loadu_si128((__m128i*)bytes);
        __m128i x9 = _mm_loadu_si128((__m128i*)(bytes + 16));
        __m128i x10 = _mm_loadu_si128((__m128i*)(bytes + 32));
        __m128i x11 = _mm_loadu_si128((__m128i*)(bytes + 48));
        __m128i x12 = _mm_loadu_si128((__m128i*)(bytes + size - 64));
        __m128i x13 = _mm_loadu_si128((__m128i*)(bytes + size - 48));
        __m128i x14 = _mm_loadu_si128((__m128i*)(bytes + size - 32));
        __m128i x15 = _mm_loadu_si128((__m128i*)(bytes + size - 16));
        x8 = _mm_xor_si128(x8, x0);
        x9 = _mm_xor_si128(x9, x1);
        x10 = _mm_xor_si128(x10, x2);
        x11 = _mm_xor_si128(x11, x3);
        x12 = _mm_xor_si128(x12, x4);
        x13 = _mm_xor_si128(x13, x5);
        x14 = _mm_xor_si128(x14, x6);
        x15 = _mm_xor_si128(x15, x7);
        x8 = _mm_aesenc_si128(x8, x8);
        x9 = _mm_aesenc_si128(x9, x9);
        x10 = _mm_aesenc_si128(x10, x10);
        x11 = _mm_aesenc_si128(x11, x11);
        x12 = _mm_aesenc_si128(x12, x12);
        x13 = _mm_aesenc_si128(x13, x13);
        x14 = _mm_aesenc_si128(x14, x14);
        x15 = _mm_aesenc_si128(x15, x15);
        x8 = _mm_aesenc_si128(x8, x8);
        x9 = _mm_aesenc_si128(x9, x9);
        x10 = _mm_aesenc_si128(x10, x10);
        x11 = _mm_aesenc_si128(x11, x11);
        x12 = _mm_aesenc_si128(x12, x12);
        x13 = _mm_aesenc_si128(x13, x13);
        x14 = _mm_aesenc_si128(x14, x14);
        x15 = _mm_aesenc_si128(x15, x15);
        x8 = _mm_aesenc_si128(x8, x8);
        x9 = _mm_aesenc_si128(x9, x9);
        x10 = _mm_aesenc_si128(x10, x10);
        x11 = _mm_aesenc_si128(x11, x11);
        x12 = _mm_aesenc_si128(x12, x12);
        x13 = _mm_aesenc_si128(x13, x13);
        x14 = _mm_aesenc_si128(x14, x14);
        x15 = _mm_aesenc_si128(x15, x15);
        x8 = _mm_xor_si128(x8, x12);
        x9 = _mm_xor_si128(x9, x13);
        x10 = _mm_xor_si128(x10, x14);
        x11 = _mm_xor_si128(x11, x15);
        x8 = _mm_xor_si128(x8, x10);
        x9 = _mm_xor_si128(x9, x11);
        x8 = _mm_xor_si128(x8, x9);
        return _mm_cvtsi128_si64(x8);
    }
    else // size > 128
    {
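        // make 7 more seeds, start with the last (possibly overlapping)
        // 128 bytes, then absorb whole 128-byte blocks from the front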
        __m128i x2 = x1;
        __m128i x3 = x1;
        __m128i x4 = x1;
        __m128i x5 = x1;
        __m128i x6 = x1;
        __m128i x7 = x1;
        x1 = _mm_xor_si128(x1, ((__m128i*)&aeskeysched)[1]);
        x2 = _mm_xor_si128(x2, ((__m128i*)&aeskeysched)[2]);
        x3 = _mm_xor_si128(x3, ((__m128i*)&aeskeysched)[3]);
        x4 = _mm_xor_si128(x4, ((__m128i*)&aeskeysched)[4]);
        x5 = _mm_xor_si128(x5, ((__m128i*)&aeskeysched)[5]);
        x6 = _mm_xor_si128(x6, ((__m128i*)&aeskeysched)[6]);
        x7 = _mm_xor_si128(x7, ((__m128i*)&aeskeysched)[7]);
        x1 = _mm_aesenc_si128(x1, x1);
        x2 = _mm_aesenc_si128(x2, x2);
        x3 = _mm_aesenc_si128(x3, x3);
        x4 = _mm_aesenc_si128(x4, x4);
        x5 = _mm_aesenc_si128(x5, x5);
        x6 = _mm_aesenc_si128(x6, x6);
        x7 = _mm_aesenc_si128(x7, x7);
        __m128i x8 = _mm_loadu_si128((__m128i*)(bytes + size - 128));
        __m128i x9 = _mm_loadu_si128((__m128i*)(bytes + size - 112));
        __m128i x10 = _mm_loadu_si128((__m128i*)(bytes + size - 96));
        __m128i x11 = _mm_loadu_si128((__m128i*)(bytes + size - 80));
        __m128i x12 = _mm_loadu_si128((__m128i*)(bytes + size - 64));
        __m128i x13 = _mm_loadu_si128((__m128i*)(bytes + size - 48));
        __m128i x14 = _mm_loadu_si128((__m128i*)(bytes + size - 32));
        __m128i x15 = _mm_loadu_si128((__m128i*)(bytes + size - 16));
        x8 = _mm_xor_si128(x8, x0);
        x9 = _mm_xor_si128(x9, x1);
        x10 = _mm_xor_si128(x10, x2);
        x11 = _mm_xor_si128(x11, x3);
        x12 = _mm_xor_si128(x12, x4);
        x13 = _mm_xor_si128(x13, x5);
        x14 = _mm_xor_si128(x14, x6);
        x15 = _mm_xor_si128(x15, x7);
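        // count of whole 128-byte blocks to absorb from the front; the tail
        // block was already xor-ed in above, and size > 128 guarantees the
        // do-while body runs at least once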
        size = (size - 1) / 128;
        do
        {
            x8 = _mm_aesenc_si128(x8, x8);
            x9 = _mm_aesenc_si128(x9, x9);
            x10 = _mm_aesenc_si128(x10, x10);
            x11 = _mm_aesenc_si128(x11, x11);
            x12 = _mm_aesenc_si128(x12, x12);
            x13 = _mm_aesenc_si128(x13, x13);
            x14 = _mm_aesenc_si128(x14, x14);
            x15 = _mm_aesenc_si128(x15, x15);
            x8 = _mm_aesenc_si128(x8, _mm_loadu_si128((__m128i*)(bytes + 0)));
            x9 = _mm_aesenc_si128(x9, _mm_loadu_si128((__m128i*)(bytes + 16)));
            x10 = _mm_aesenc_si128(x10, _mm_loadu_si128((__m128i*)(bytes + 32)));
            x11 = _mm_aesenc_si128(x11, _mm_loadu_si128((__m128i*)(bytes + 48)));
            x12 = _mm_aesenc_si128(x12, _mm_loadu_si128((__m128i*)(bytes + 64)));
            x13 = _mm_aesenc_si128(x13, _mm_loadu_si128((__m128i*)(bytes + 80)));
            x14 = _mm_aesenc_si128(x14, _mm_loadu_si128((__m128i*)(bytes + 96)));
            x15 = _mm_aesenc_si128(x15, _mm_loadu_si128((__m128i*)(bytes + 112)));
            bytes += 128;
        }
        while (--size > 0);
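        // three finishing rounds per lane, then fold all eight lanes down to
        // a single 64-bit result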
        x8 = _mm_aesenc_si128(x8, x8);
        x9 = _mm_aesenc_si128(x9, x9);
        x10 = _mm_aesenc_si128(x10, x10);
        x11 = _mm_aesenc_si128(x11, x11);
        x12 = _mm_aesenc_si128(x12, x12);
        x13 = _mm_aesenc_si128(x13, x13);
        x14 = _mm_aesenc_si128(x14, x14);
        x15 = _mm_aesenc_si128(x15, x15);
        x8 = _mm_aesenc_si128(x8, x8);
        x9 = _mm_aesenc_si128(x9, x9);
        x10 = _mm_aesenc_si128(x10, x10);
        x11 = _mm_aesenc_si128(x11, x11);
        x12 = _mm_aesenc_si128(x12, x12);
        x13 = _mm_aesenc_si128(x13, x13);
        x14 = _mm_aesenc_si128(x14, x14);
        x15 = _mm_aesenc_si128(x15, x15);
        x8 = _mm_aesenc_si128(x8, x8);
        x9 = _mm_aesenc_si128(x9, x9);
        x10 = _mm_aesenc_si128(x10, x10);
        x11 = _mm_aesenc_si128(x11, x11);
        x12 = _mm_aesenc_si128(x12, x12);
        x13 = _mm_aesenc_si128(x13, x13);
        x14 = _mm_aesenc_si128(x14, x14);
        x15 = _mm_aesenc_si128(x15, x15);
        x8 = _mm_xor_si128(x8, x12);
        x9 = _mm_xor_si128(x9, x13);
        x10 = _mm_xor_si128(x10, x14);
        x11 = _mm_xor_si128(x11, x15);
        x8 = _mm_xor_si128(x8, x10);
        x9 = _mm_xor_si128(x9, x11);
        x8 = _mm_xor_si128(x8, x9);
        return _mm_cvtsi128_si64(x8);
    }
}
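
// Hypothetical usage sketch (not part of the original gist): fill the key
// schedule once at startup, then hash; any 64-bit value can serve as the seed.
#include <stdio.h>

int main(void)
{
    go_aeshash_init(); // sketch above; must run before the first hash

    static const char msg[] = "hello, world";
    uint64_t h = go_aeshash(0, msg, sizeof(msg) - 1);
    printf("hash = %016llx\n", (unsigned long long)h);
    return 0;
}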