Skip to content

Instantly share code, notes, and snippets.

@matsui528
Last active October 23, 2023 05:17
Show Gist options
  • Star 8 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save matsui528/583925f88fcb08240319030202588c74 to your computer and use it in GitHub Desktop.
Save matsui528/583925f88fcb08240319030202588c74 to your computer and use it in GitHub Desktop.
Runtime evaluation for squared Euclidean distances with SSE, AVX, AVX512 implementations
#include <iostream>
#include <random>
#include <chrono>
#include <x86intrin.h>
#include <cassert>
// Runtime evaluation for squared Eucliden distance functions
// - fvec_L2_sqr_ref: naive reference impl from Faiss
// - fvec_L2_sqr_sse: SSE impl from Faiss
// - fvec_L2_sqr_avx: AVX impl from Faiss
// - fvec_L2_sqr_avx512: AVX512 impl
// Note that fvec_L2_sqr_{ref, sse, avx} are from Faiss:
// https://github.com/facebookresearch/faiss/blob/master/utils.cpp
// Compile:
// $ g++ -O3 -Wall --std=c++14 -march=native -o main main.cpp
// Result on c5.2xlarge instance on AWS EC2:
// ref: 1345 msec
// sse: 371 msec
// avx: 262 msec
// avx512: 255 msec
const uint uint32_max = 4294967295;
std::mt19937 mt(123);
std::vector<std::vector<float>> gen_random_fvectors(int N, int D) {
std::vector<std::vector<float>> vecs(N, std::vector<float>(D));
for (int n = 0; n < N; ++n) {
for(int d = 0; d < D; ++d) {
vecs[n][d] = (float) mt() / uint32_max;
}
}
return vecs;
}
float fvec_L2sqr_ref(const float *x, const float *y, size_t d)
{
size_t i;
float res_ = 0;
for (i = 0; i < d; i++) {
const float tmp = x[i] - y[i];
res_ += tmp * tmp;
}
return res_;
}
// reads 0 <= d < 4 floats as __m128
static inline __m128 masked_read (int d, const float *x)
{
assert (0 <= d && d < 4);
__attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
switch (d) {
case 3:
buf[2] = x[2];
case 2:
buf[1] = x[1];
case 1:
buf[0] = x[0];
}
return _mm_load_ps (buf);
// cannot use AVX2 _mm_mask_set1_epi32
}
// SSE implementation
float fvec_L2sqr_sse (const float *x, const float *y, size_t d)
{
__m128 msum1 = _mm_setzero_ps();
while (d >= 4) {
__m128 mx = _mm_loadu_ps (x); x += 4;
__m128 my = _mm_loadu_ps (y); y += 4;
const __m128 a_m_b1 = mx - my;
msum1 += a_m_b1 * a_m_b1;
d -= 4;
}
if (d > 0) {
// add the last 1, 2 or 3 values
__m128 mx = masked_read (d, x);
__m128 my = masked_read (d, y);
__m128 a_m_b1 = mx - my;
msum1 += a_m_b1 * a_m_b1;
}
msum1 = _mm_hadd_ps (msum1, msum1);
msum1 = _mm_hadd_ps (msum1, msum1);
return _mm_cvtss_f32 (msum1);
}
// reads 0 <= d < 8 floats as __m256
static inline __m256 masked_read_8 (int d, const float *x)
{
assert (0 <= d && d < 8);
if (d < 4) {
__m256 res = _mm256_setzero_ps ();
res = _mm256_insertf128_ps (res, masked_read (d, x), 0);
return res;
} else {
__m256 res = _mm256_setzero_ps ();
res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0);
res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1);
return res;
}
}
float fvec_L2sqr_avx (const float *x, const float *y, size_t d)
{
__m256 msum1 = _mm256_setzero_ps();
while (d >= 8) {
__m256 mx = _mm256_loadu_ps (x); x += 8;
__m256 my = _mm256_loadu_ps (y); y += 8;
const __m256 a_m_b1 = mx - my;
msum1 += a_m_b1 * a_m_b1;
d -= 8;
}
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
msum2 += _mm256_extractf128_ps(msum1, 0);
if (d >= 4) {
__m128 mx = _mm_loadu_ps (x); x += 4;
__m128 my = _mm_loadu_ps (y); y += 4;
const __m128 a_m_b1 = mx - my;
msum2 += a_m_b1 * a_m_b1;
d -= 4;
}
if (d > 0) {
__m128 mx = masked_read (d, x);
__m128 my = masked_read (d, y);
__m128 a_m_b1 = mx - my;
msum2 += a_m_b1 * a_m_b1;
}
msum2 = _mm_hadd_ps (msum2, msum2);
msum2 = _mm_hadd_ps (msum2, msum2);
return _mm_cvtss_f32 (msum2);
}
#ifdef __AVX512F__
// reads 0 <= d < 16 floats as __m512
static inline __m512 masked_read_16 (int d, const float *x)
{
assert (0 <= d && d < 16);
if (d < 8) {
__m512 res = _mm512_setzero_ps ();
res = _mm512_insertf32x8 (res, masked_read_8 (d, x), 0);
return res;
} else {
__m512 res = _mm512_setzero_ps ();
res = _mm512_insertf32x8 (res, _mm256_loadu_ps (x), 0);
res = _mm512_insertf32x8 (res, masked_read_8 (d - 8, x + 8), 1);
return res;
}
}
float fvec_L2sqr_avx512 (const float *x, const float *y, size_t d)
{
__m512 msum1 = _mm512_setzero_ps();
while (d >= 16) {
__m512 mx = _mm512_loadu_ps (x); x += 16;
__m512 my = _mm512_loadu_ps (y); y += 16;
const __m512 a_m_b1 = mx - my;
msum1 += a_m_b1 * a_m_b1;
d -= 16;
}
__m256 msum2 = _mm512_extractf32x8_ps(msum1, 1);
msum2 += _mm512_extractf32x8_ps(msum1, 0);
while (d >= 8) {
__m256 mx = _mm256_loadu_ps (x); x += 8;
__m256 my = _mm256_loadu_ps (y); y += 8;
const __m256 a_m_b1 = mx - my;
msum2 += a_m_b1 * a_m_b1;
d -= 8;
}
__m128 msum3 = _mm256_extractf128_ps(msum2, 1);
msum3 += _mm256_extractf128_ps(msum2, 0);
if (d >= 4) {
__m128 mx = _mm_loadu_ps (x); x += 4;
__m128 my = _mm_loadu_ps (y); y += 4;
const __m128 a_m_b1 = mx - my;
msum3 += a_m_b1 * a_m_b1;
d -= 4;
}
if (d > 0) {
__m128 mx = masked_read (d, x);
__m128 my = masked_read (d, y);
__m128 a_m_b1 = mx - my;
msum3 += a_m_b1 * a_m_b1;
}
msum3 = _mm_hadd_ps (msum3, msum3);
msum3 = _mm_hadd_ps (msum3, msum3);
return _mm_cvtss_f32 (msum3);
}
#endif
int main() {
size_t N = 10000;
size_t Nq = 1000;
size_t D = 128;
std::vector<std::vector<float>> queries = gen_random_fvectors(Nq, D);
std::vector<std::vector<float>> vecs = gen_random_fvectors(N, D);
// ============== Value check ===============
for (const auto &q : queries) {
for (size_t n = 0; n < N; ++n) {
float d1 = fvec_L2sqr_ref(q.data(), vecs[n].data(), D);
float d2 = fvec_L2sqr_sse(q.data(), vecs[n].data(), D);
float d3 = fvec_L2sqr_avx(q.data(), vecs[n].data(), D);
assert(std::abs(d1 - d2) < 0.001);
assert(std::abs(d1 - d3) < 0.001);
#ifdef __AVX512F__
float d4 = fvec_L2sqr_avx512(q.data(), vecs[n].data(), D);
assert(std::abs(d1 - d4) < 0.001);
#endif
}
}
// ===================== Runtime evaluation ======================
auto eval = [&](std::string method, auto fvec_L2sqr){
std::vector<float> rets(N);
std::chrono::system_clock::time_point start, end;
start = std::chrono::system_clock::now();
for (const auto &q : queries) {
for (size_t n = 0; n < N; ++n) {
rets[n] = fvec_L2sqr(q.data(), vecs[n].data(), D);
}
}
end = std::chrono::system_clock::now();
double elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
std::cout << method << ": " << elapsed << " msec" << std::endl;
};
eval("ref", fvec_L2sqr_ref);
eval("sse", fvec_L2sqr_sse);
eval("avx", fvec_L2sqr_avx);
#ifdef __AVX512F__
eval("avx512", fvec_L2sqr_avx512);
#endif
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment