Skip to content

Instantly share code, notes, and snippets.

@CAFxX
Created November 11, 2022 06:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save CAFxX/403d7393ab4df4ff25eacc6e018a3aba to your computer and use it in GitHub Desktop.
Save CAFxX/403d7393ab4df4ff25eacc6e018a3aba to your computer and use it in GitHub Desktop.
AVX-512 vectorized associative lookup
#include <immintrin.h>
#include <stdint.h>
int getIndexOf(__m512i const *values, int64_t target)
{
__m512i valuesSimd = _mm512_loadu_si512(values);
__m512i targetSplatted = _mm512_set1_epi64(target);
__mmask8 equalMask = _mm512_cmpeq_epi64_mask(valuesSimd, targetSplatted);
uint32_t equalMaskInt = _cvtmask8_u32(equalMask);
int index = _tzcnt_u32(equalMaskInt);
return index < 8 ? index : -1;
}
int getIndexOf(__m512i const *values, int32_t target)
{
__m512i valuesSimd = _mm512_loadu_si512(values);
__m512i targetSplatted = _mm512_set1_epi32(target);
__mmask16 equalMask = _mm512_cmpeq_epi32_mask(valuesSimd, targetSplatted);
uint32_t equalMaskInt = _cvtmask16_u32(equalMask);
int index = _tzcnt_u32(equalMaskInt);
return index < 16 ? index : -1;
}
int getIndexOf(__m512i const *values, int16_t target)
{
__m512i valuesSimd = _mm512_loadu_si512(values);
__m512i targetSplatted = _mm512_set1_epi16(target);
__mmask32 equalMask = _mm512_cmpeq_epi16_mask(valuesSimd, targetSplatted);
uint32_t equalMaskInt = _cvtmask32_u32(equalMask);
int index = _tzcnt_u32(equalMaskInt);
return index < 32 ? index : -1;
}
int getIndexOf(__m512i const *values, int8_t target)
{
__m512i valuesSimd = _mm512_loadu_si512(values);
__m512i targetSplatted = _mm512_set1_epi8(target);
__mmask64 equalMask = _mm512_cmpeq_epi8_mask(valuesSimd, targetSplatted);
uint64_t equalMaskInt = _cvtmask64_u64(equalMask);
int index = _tzcnt_u64(equalMaskInt);
return index < 64 ? index : -1;
}
// Searches `n` consecutive 512-bit vectors for a 64-bit element equal to
// `target`, scanning the vectors in order and stopping at the first hit.
//
// values: address of the first vector.
// n:      number of vectors to scan; nothing is scanned when n <= 0.
// target: 64-bit value to search for.
// Returns the element index counted from the start of `values` (vector index
// times 8 lanes, plus the lane within that vector), or -1 if no vector
// contains `target`.
int getIndexOf(__m512i const *values, int n, int64_t target) {
    const int lanesPerVector = 512 / 64;
    int vec = 0;
    while (vec < n) {
        // Delegate the per-vector search to the single-vector overload.
        int lane = getIndexOf(values + vec, target);
        if (lane != -1) {
            return vec * lanesPerVector + lane;
        }
        ++vec;
    }
    return -1;
}
// Searches `n` consecutive 512-bit vectors for a 32-bit element equal to
// `target`, scanning the vectors in order and stopping at the first hit.
//
// values: address of the first vector.
// n:      number of vectors to scan; nothing is scanned when n <= 0.
// target: 32-bit value to search for.
// Returns the element index counted from the start of `values` (vector index
// times 16 lanes, plus the lane within that vector), or -1 if no vector
// contains `target`.
int getIndexOf(__m512i const *values, int n, int32_t target) {
    const int lanesPerVector = 512 / 32;
    int vec = 0;
    while (vec < n) {
        // Delegate the per-vector search to the single-vector overload.
        int lane = getIndexOf(values + vec, target);
        if (lane != -1) {
            return vec * lanesPerVector + lane;
        }
        ++vec;
    }
    return -1;
}
// Searches `n` consecutive 512-bit vectors for a 16-bit element equal to
// `target`, scanning the vectors in order and stopping at the first hit.
//
// values: address of the first vector.
// n:      number of vectors to scan; nothing is scanned when n <= 0.
// target: 16-bit value to search for.
// Returns the element index counted from the start of `values` (vector index
// times 32 lanes, plus the lane within that vector), or -1 if no vector
// contains `target`.
int getIndexOf(__m512i const *values, int n, int16_t target) {
    const int lanesPerVector = 512 / 16;
    int vec = 0;
    while (vec < n) {
        // Delegate the per-vector search to the single-vector overload.
        int lane = getIndexOf(values + vec, target);
        if (lane != -1) {
            return vec * lanesPerVector + lane;
        }
        ++vec;
    }
    return -1;
}
// Searches `n` consecutive 512-bit vectors for an 8-bit element equal to
// `target`, scanning the vectors in order and stopping at the first hit.
//
// values: address of the first vector.
// n:      number of vectors to scan; nothing is scanned when n <= 0.
// target: 8-bit value to search for.
// Returns the element index counted from the start of `values` (vector index
// times 64 lanes, plus the lane within that vector), or -1 if no vector
// contains `target`.
int getIndexOf(__m512i const *values, int n, int8_t target) {
    const int lanesPerVector = 512 / 8;
    int vec = 0;
    while (vec < n) {
        // Delegate the per-vector search to the single-vector overload.
        int lane = getIndexOf(values + vec, target);
        if (lane != -1) {
            return vec * lanesPerVector + lane;
        }
        ++vec;
    }
    return -1;
}
@CAFxX
Copy link
Author

CAFxX commented Nov 11, 2022

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment