Last active
August 17, 2018 16:16
-
-
Save mdouze/a384a01d0e205bee6d39d52170fb3588 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/IndexIVFPQ.cpp b/IndexIVFPQ.cpp | |
--- a/IndexIVFPQ.cpp | |
+++ b/IndexIVFPQ.cpp | |
@@ -414,11 +414,15 @@ | |
namespace { | |
static uint64_t get_cycles () { | |
+#if defined(__x86_64__) || defined(__i386__) // rdtsc + "=a"/"=d" outputs are valid on 32- and 64-bit x86 | |
uint32_t high, low; | |
asm volatile("rdtsc \n\t" | |
: "=a" (low), | |
"=d" (high)); | |
return ((uint64_t)high << 32) | (low); | |
+#else | |
+ return 0; // no cycle counter on this target: TIC-based cycle counts read as 0 | |
+#endif | |
} | |
#define TIC t0 = get_cycles() | |
diff --git a/IndexScalarQuantizer.cpp b/IndexScalarQuantizer.cpp | |
--- a/IndexScalarQuantizer.cpp | |
+++ b/IndexScalarQuantizer.cpp | |
@@ -8,7 +8,9 @@ | |
#include <omp.h> | |
+#ifdef __SSE__ | |
#include <immintrin.h> | |
+#endif | |
#include "utils.h" | |
diff --git a/utils.cpp b/utils.cpp | |
--- a/utils.cpp | |
+++ b/utils.cpp | |
@@ -8,8 +8,9 @@ | |
#include <cstring> | |
#include <cmath> | |
+#ifdef __SSE__ | |
#include <immintrin.h> | |
- | |
+#endif | |
#include <sys/time.h> | |
#include <sys/types.h> | |
@@ -412,8 +413,6 @@ | |
* Reference implementations | |
*/ | |
- | |
- | |
/* same without SSE */ | |
float fvec_L2sqr_ref (const float * x, | |
const float * y, | |
@@ -439,8 +438,7 @@ | |
return res; | |
} | |
-float fvec_norm_L2sqr_ref (const float * __restrict x, | |
- size_t d) | |
+float fvec_norm_L2sqr_ref (const float *x, size_t d) | |
{ | |
size_t i; | |
double res = 0; | |
@@ -454,6 +452,8 @@ | |
* SSE and AVX implementations | |
*/ | |
+#ifdef __SSE__ | |
+ | |
// reads 0 <= d < 4 floats as __m128 | |
static inline __m128 masked_read (int d, const float *x) | |
{ | |
@@ -471,6 +471,28 @@ | |
// cannot use AVX2 _mm_mask_set1_epi32 | |
} | |
+float fvec_norm_L2sqr (const float * x, | |
+ size_t d) | |
+{ | |
+ __m128 mx; | |
+ __m128 msum1 = _mm_setzero_ps(); | |
+ | |
+ while (d >= 4) { | |
+ mx = _mm_loadu_ps (x); x += 4; | |
+ msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx)); | |
+ d -= 4; | |
+ } | |
+ | |
+ mx = masked_read (d, x); | |
+ msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx)); | |
+ | |
+ msum1 = _mm_add_ps (msum1, _mm_movehl_ps (msum1, msum1)); // lanes 0,1 += lanes 2,3 (SSE1; avoids SSE3 _mm_hadd_ps under an __SSE__ guard) | |
+ msum1 = _mm_add_ss (msum1, _mm_shuffle_ps (msum1, msum1, 1)); // lane 0 += lane 1 -> full horizontal sum | |
+ return _mm_cvtss_f32 (msum1); | |
+} | |
+ | |
+#endif | |
+ | |
#ifdef USE_AVX | |
// reads 0 <= d < 8 floats as __m256 | |
@@ -560,7 +582,7 @@ | |
return _mm_cvtss_f32 (msum2); | |
} | |
-#else | |
+#elif defined(__SSE__) | |
/* SSE-implementation of L2 distance */ | |
float fvec_L2sqr (const float * x, | |
@@ -618,29 +640,30 @@ | |
} | |
+#else | |
+// CPU-only implementation | |
-#endif | |
- | |
-float fvec_norm_L2sqr (const float * x, | |
- size_t d) | |
+float fvec_L2sqr (const float * x, | |
+ const float * y, | |
+ size_t d) | |
{ | |
- __m128 mx; | |
- __m128 msum1 = _mm_setzero_ps(); | |
- | |
- while (d >= 4) { | |
- mx = _mm_loadu_ps (x); x += 4; | |
- msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx)); | |
- d -= 4; | |
- } | |
+ return fvec_L2sqr_ref (x, y, d); | |
+} | |
- mx = masked_read (d, x); | |
- msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx)); | |
+float fvec_inner_product (const float * x, | |
+ const float * y, | |
+ size_t d) | |
+{ | |
+ return fvec_inner_product_ref (x, y, d); | |
+} | |
- msum1 = _mm_hadd_ps (msum1, msum1); | |
- msum1 = _mm_hadd_ps (msum1, msum1); | |
- return _mm_cvtss_f32 (msum1); | |
+float fvec_norm_L2sqr (const float *x, size_t d) | |
+{ | |
+ return fvec_norm_L2sqr_ref (x, d); | |
} | |
+#endif | |
+ | |
@@ -1861,6 +1884,7 @@ | |
c[i] = a[i] + bf * b[i]; | |
} | |
+#ifdef __SSE__ | |
static inline void fvec_madd_sse (size_t n, const float *a, | |
float bf, const float *b, float *c) { | |
@@ -1879,15 +1903,25 @@ | |
} | |
void fvec_madd (size_t n, const float *a, | |
- float bf, const float *b, float *c) | |
+ float bf, const float *b, float *c) | |
{ | |
if ((n & 3) == 0 && | |
((((long)a) | ((long)b) | ((long)c)) & 15) == 0) | |
fvec_madd_sse (n, a, bf, b, c); | |
else | |
fvec_madd_ref (n, a, bf, b, c); | |
} | |
+#else | |
+ | |
+void fvec_madd (size_t n, const float *a, | |
+ float bf, const float *b, float *c) | |
+{ | |
+ fvec_madd_ref (n, a, bf, b, c); | |
+} | |
+ | |
+#endif | |
+ | |
static inline int fvec_madd_and_argmin_ref (size_t n, const float *a, | |
float bf, const float *b, float *c) { | |
float vmin = 1e20; | |
@@ -1903,8 +1937,11 @@ | |
return imin; | |
} | |
-static inline int fvec_madd_and_argmin_sse (size_t n, const float *a, | |
- float bf, const float *b, float *c) { | |
+#ifdef __SSE__ | |
+ | |
+static inline int fvec_madd_and_argmin_sse ( | |
+ size_t n, const float *a, | |
+ float bf, const float *b, float *c) { | |
n >>= 2; | |
__m128 bf4 = _mm_set_ps1 (bf); | |
__m128 vmin4 = _mm_set_ps1 (1e20); | |
@@ -1953,15 +1990,24 @@ | |
int fvec_madd_and_argmin (size_t n, const float *a, | |
- float bf, const float *b, float *c) | |
+ float bf, const float *b, float *c) | |
{ | |
if ((n & 3) == 0 && | |
((((long)a) | ((long)b) | ((long)c)) & 15) == 0) | |
return fvec_madd_and_argmin_sse (n, a, bf, b, c); | |
else | |
return fvec_madd_and_argmin_ref (n, a, bf, b, c); | |
} | |
+#else | |
+ | |
+int fvec_madd_and_argmin (size_t n, const float *a, | |
+ float bf, const float *b, float *c) | |
+{ | |
+ return fvec_madd_and_argmin_ref (n, a, bf, b, c); | |
+} | |
+ | |
+#endif | |
const float *fvecs_maybe_subsample ( | |
diff --git a/IndexHNSW.cpp b/IndexHNSW.cpp | |
--- a/IndexHNSW.cpp | |
+++ b/IndexHNSW.cpp | |
@@ -19,7 +19,9 @@ | |
#include <unistd.h> | |
#include <stdint.h> | |
+#ifdef __SSE__ | |
#include <immintrin.h> | |
+#endif | |
#include "utils.h" | |
#include "Heap.h" | |
@@ -1862,6 +1864,7 @@ | |
float operator () (storage_idx_t i) override | |
{ | |
+#ifdef __SSE__ | |
const uint8_t *code = storage.codes.data() + i * storage.code_size; | |
long key = 0; | |
memcpy (&key, code, storage.code_size_1); | |
@@ -1885,6 +1888,9 @@ | |
accu = _mm_hadd_ps (accu, accu); | |
accu = _mm_hadd_ps (accu, accu); | |
return _mm_cvtss_f32 (accu); | |
+#else | |
+ FAISS_THROW_MSG("not implemented for non-x64 platforms"); | |
+#endif | |
} | |
}; | |
@@ -1913,6 +1919,7 @@ | |
long key01 = 0; | |
memcpy (&key01, code, storage.code_size_1); | |
code += storage.code_size_1; | |
+#ifdef __SSE__ | |
// walking pointers | |
const float *qa = q; | |
@@ -1938,6 +1945,9 @@ | |
accu = _mm_hadd_ps (accu, accu); | |
accu = _mm_hadd_ps (accu, accu); | |
return _mm_cvtss_f32 (accu); | |
+#else | |
+ FAISS_THROW_MSG("not implemented for non-x64 platforms"); | |
+#endif | |
} | |
}; | |
@@ -1950,6 +1960,7 @@ | |
dynamic_cast<Index2Layer*>(storage); | |
if (storage2l) { | |
+#ifdef __SSE__ | |
const MultiIndexQuantizer *mi = | |
dynamic_cast<MultiIndexQuantizer*> (storage2l->q1.quantizer); | |
@@ -1964,6 +1975,7 @@ | |
if (fl && storage2l->pq.dsub == 4) { | |
return new DistanceXPQ4(*storage2l); | |
} | |
+#endif | |
} | |
// IVFPQ and cases not handled above |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment