Skip to content

Instantly share code, notes, and snippets.

@kaja47
Created October 5, 2015 21:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kaja47/045e40ae329042b437fa to your computer and use it in GitHub Desktop.
Save kaja47/045e40ae329042b437fa to your computer and use it in GitHub Desktop.
Jaccard similarity using AVX2 SIMD
#include <errno.h>
#include <immintrin.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
/*
 * intersectionSize: count the values that appear in both a[0..alen) and
 * b[0..blen) with a single merge-style walk over the two arrays.
 *
 * Presumably both arrays are sorted ascending (TODO confirm with callers);
 * the walk advances whichever cursor holds the smaller head, and on equal
 * heads counts one match and advances both. The loop body is branch-free:
 * each comparison contributes 0 or 1 directly.
 *
 * Returns the number of matched pairs found.
 */
int intersectionSize(int* a, const int alen, int* b, const int blen) {
    int i = 0;
    int j = 0;
    int count = 0;
    for (; i < alen && j < blen;) {
        const int x = a[i];
        const int y = b[j];
        count += (x == y);
        i += (x <= y);
        j += (y <= x);
    }
    return count;
}
/*
 * SWAP(a, b): exchange the values of two lvalues of the same type.
 *
 * Uses __typeof__ rather than typeof: the plain keyword is a GNU extension
 * that GCC and Clang disable in strict modes such as -std=c11, while the
 * double-underscore spelling is accepted in every mode. The temporary has an
 * unlikely name so SWAP(temp, x) does not self-initialize. Arguments are
 * evaluated more than once; pass only side-effect-free lvalues.
 */
#define SWAP(a, b) do { __typeof__(a) swap_tmp_ = (a); (a) = (b); (b) = swap_tmp_; } while (0)
/*
 * intersectionSizeAVX: same count as intersectionSize(), but skips through
 * the array with the smaller head up to 8 elements per step using AVX2.
 *
 * Presumably a and b are sorted ascending without duplicates (TODO confirm
 * with callers). Needs AVX2 (_mm256_* intrinsics) and LZCNT (_lzcnt_u32)
 * both at compile time and at run time.
 *
 * Each step first swaps the two cursors so that *a >= *b. It then broadcasts
 * *a, compares it against the next 8 elements of b, and advances b by pos,
 * the number of those elements that are <= *a. Because b[0] <= *a, pos is
 * at least 1. When fewer than 8 elements remain on either side, the scalar
 * merge finishes the job.
 */
int intersectionSizeAVX(int* a, int alen, int* b, int blen) {
    /* One-past-the-end pointers are always valid to form; the previous
       a + alen - 7 was undefined behavior whenever alen < 7. */
    int *a_stop = a + alen;
    int *b_stop = b + blen;
    int size = 0;

    /* Both sides need >= 8 elements left: after a swap either array may be
       the one loaded 8 lanes at a time. */
    while (a_stop - a >= 8 && b_stop - b >= 8) {
        if (*b > *a) {
            /* Keep the larger head in a so that b[0] <= *a below. */
            SWAP(a, b);
            SWAP(a_stop, b_stop);
        }
        __m256i q = _mm256_set1_epi32(*a);
        __m256i bv = _mm256_loadu_si256((const __m256i *)b);
        __m256i cmp = _mm256_cmpgt_epi32(bv, q); /* lanes where b[i] > *a */
        /* movemask_epi8 yields 4 bits per 32-bit lane; after inversion the set
           bits mark lanes with b[i] <= *a, which (b sorted) form a low-order
           prefix of 4 * pos bits, so pos = 8 - lzcnt / 4. */
        uint32_t mask = ~(uint32_t)_mm256_movemask_epi8(cmp);
        uint32_t pos = 8 - _lzcnt_u32(mask) / 4;
        b += pos;
        /* b[-1] is the last element <= *a: the only candidate equal to it. */
        size += (*(b - 1) == *a) ? 1 : 0;
        /* If all 8 lanes were <= *a, later elements of b may still reach *a,
           so keep a where it is. */
        a += (pos < 8) ? 1 : 0;
    }
    size += intersectionSize(a, (int)(a_stop - a), b, (int)(b_stop - b));
    return size;
}
/*
 * Shared tail of jaccard() and jaccardAVX(): intersection size divided by
 * union size, where the union size is alen + blen - is. That formula assumes
 * the inputs hold distinct values (TODO confirm with callers).
 *
 * The union is computed in long long so alen + blen cannot overflow int.
 * Returns 0 when both inputs are empty (union size 0).
 */
static float jaccardFromIntersection(int is, int alen, int blen) {
    long long us = (long long)alen + blen - is;
    return (us == 0) ? 0.0f : (is / (float)us);
}

/* Jaccard similarity of a[0..alen) and b[0..blen) via the scalar merge. */
float jaccard(int* a, const int alen, int* b, const int blen) {
    return jaccardFromIntersection(intersectionSize(a, alen, b, blen), alen, blen);
}

/* Jaccard similarity of a[0..alen) and b[0..blen) via the AVX2 intersection. */
float jaccardAVX(int* a, const int alen, int* b, const int blen) {
    return jaccardFromIntersection(intersectionSizeAVX(a, alen, b, blen), alen, blen);
}
/* Parse a base-10 int from s into *out. Returns 0 on success, -1 on bad input. */
static int parseIntArg(const char *s, int *out) {
    char *end = NULL;
    errno = 0;
    long v = strtol(s, &end, 10);
    if (end == s || *end != '\0' || errno == ERANGE || v < INT_MIN || v > INT_MAX) {
        return -1;
    }
    *out = (int)v;
    return 0;
}

/*
 * Benchmark driver. Usage: prog LEN M
 *
 * Builds a[i] = i and b[i] = i * M for i in [0, LEN), calls
 * intersectionSizeAVX(a, LEN, b, LEN) BENCH_ITERATIONS times and prints the
 * summed result. Exits with EXIT_FAILURE on bad arguments or allocation
 * failure.
 */
int main(int argc, char *argv[]) {
    enum { BENCH_ITERATIONS = 10000 };

    if (argc < 3) {
        fprintf(stderr, "usage: %s LEN M\n", argc > 0 ? argv[0] : "prog");
        return EXIT_FAILURE;
    }
    int len = 0;
    int m = 0;
    if (parseIntArg(argv[1], &len) != 0 || len < 0) {
        fprintf(stderr, "invalid LEN (need a non-negative integer): %s\n", argv[1]);
        return EXIT_FAILURE;
    }
    if (parseIntArg(argv[2], &m) != 0) {
        fprintf(stderr, "invalid M (need an integer): %s\n", argv[2]);
        return EXIT_FAILURE;
    }
    /* i * m is monotonic in i, so checking the last index covers all of b. */
    if (len > 0) {
        long long last = (long long)(len - 1) * m;
        if (last > INT_MAX || last < INT_MIN) {
            fprintf(stderr, "(LEN - 1) * M overflows int\n");
            return EXIT_FAILURE;
        }
    }

    int alen = len;
    int blen = len;
    /* Allocate at least one element: malloc/calloc of 0 bytes may return NULL,
       and the intersection routines do pointer arithmetic on their inputs.
       calloc also checks count * size for overflow. */
    int *a = calloc(alen > 0 ? (size_t)alen : 1, sizeof *a);
    int *b = calloc(blen > 0 ? (size_t)blen : 1, sizeof *b);
    if (a == NULL || b == NULL) {
        fprintf(stderr, "out of memory\n");
        free(a);
        free(b);
        return EXIT_FAILURE;
    }
    for (int i = 0; i < alen; i++) {
        a[i] = i;
    }
    for (int i = 0; i < blen; i++) {
        b[i] = i * m;
    }

    /* long long: BENCH_ITERATIONS * LEN can exceed a 32-bit long. */
    long long sum = 0;
    for (int i = 0; i < BENCH_ITERATIONS; i++) {
        sum += intersectionSizeAVX(a, alen, b, blen);
    }
    printf("%lld\n", sum);

    free(a);
    free(b);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment