Skip to content

Instantly share code, notes, and snippets.

@kohnakagawa
Created January 4, 2017 08:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kohnakagawa/11a296e2ed1c155a66db1ae08d0c3ae5 to your computer and use it in GitHub Desktop.
Save kohnakagawa/11a296e2ed1c155a66db1ae08d0c3ae5 to your computer and use it in GitHub Desktop.
#include <iostream>
#include <algorithm>
#include <random>
#include <chrono>
#include <random>
#include <x86intrin.h>
// Cut-off length and its square. The kernels compare squared norms against
// SL2, so no square root is needed.
constexpr double SL = 0.2;
constexpr double SL2 = SL * SL;
// GCC/Clang generic vector types: 4 x int32 (128-bit), 4 x int64 (256-bit)
// and 4 x double (256-bit).
// NOTE(review): converting v4si/v4di to or from __m128i/__m256i involves
// differing element types; GCC needs -flax-vector-conversions for that.
typedef int32_t v4si __attribute__((vector_size(4 * sizeof(int32_t))));
typedef int64_t v4di __attribute__((vector_size(4 * sizeof(int64_t))));
typedef double v4df __attribute__((vector_size(4 * sizeof(double))));
// A point with four double components (x, y, z, w) stored contiguously.
// The SIMD kernels in this file reinterpret a Vec as four consecutive doubles,
// so that layout is pinned by the static_assert below. alignas(32) makes every
// Vec occupy exactly one 32-byte-aligned block, so 256-bit AVX loads of a
// whole Vec are always aligned and never split across a cache line.
struct alignas(32) Vec {
  double x = 0.0;
  double y = 0.0;
  double z = 0.0;
  double w = 0.0;

  // Zero-initialised point (via the default member initialisers above).
  Vec() = default;

  // Point with the given components.
  Vec(const double x_,
      const double y_,
      const double z_,
      const double w_)
      : x(x_), y(y_), z(z_), w(w_) {}
};
static_assert(sizeof(Vec) == 4 * sizeof(double),
              "SIMD kernels assume Vec is exactly four packed doubles");
// Scalar reference kernel.
//   q    : at least `size` points; only x, y and z are read (w is ignored).
//   size : number of points to process; a value <= 0 writes nothing.
//   pair : output with at least `size` entries. pair[i] is set to -1 (all bits
//          set, the same pattern _mm256_cmp_pd yields for "true") when
//          x^2 + y^2 + z^2 <= SL2, and to 0 otherwise.
void reference(const Vec* q,
               const int32_t size,
               int64_t* pair) {
  for (int32_t i = 0; i < size; i++) {
    const double r2 = q[i].x * q[i].x
                    + q[i].y * q[i].y
                    + q[i].z * q[i].z;
    // Write -1 directly. The previous literal 0xffffffffffffffff is an
    // unsigned value outside int64_t's range, and converting it was
    // implementation-defined before C++20.
    pair[i] = (r2 <= SL2) ? -1 : 0;
  }
}
void with_gather(const Vec* q,
const int32_t size,
int64_t* pair) {
v4si vindex = _mm_set_epi32(12, 8, 4, 0);
v4df vsl2 = _mm256_set_pd(SL2, SL2, SL2, SL2);
for (int i = 0; i < (size / 4) * 4; i += 4) {
v4df vx = _mm256_i32gather_pd(&(q[i].x), vindex, 8);
v4df vy = _mm256_i32gather_pd(&(q[i].y), vindex, 8);
v4df vz = _mm256_i32gather_pd(&(q[i].z), vindex, 8);
v4df vr2 = vx * vx + vy * vy + vz * vz;
v4di dr2_flag = _mm256_castpd_si256(_mm256_cmp_pd(vr2, vsl2, _CMP_LE_OS));
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&pair[i]), dr2_flag);
}
}
void without_gather(const Vec* q,
const int32_t size,
int64_t* pair) {
v4df vsl2 = _mm256_set_pd(SL2, SL2, SL2, SL2);
for (int i = 0; i < (size / 4) * 4; i += 4) {
v4df vqa = _mm256_load_pd(reinterpret_cast<const double*>(q + i ));
v4df vqb = _mm256_load_pd(reinterpret_cast<const double*>(q + i + 1));
v4df vqc = _mm256_load_pd(reinterpret_cast<const double*>(q + i + 2));
v4df vqd = _mm256_load_pd(reinterpret_cast<const double*>(q + i + 3));
// transpose 4x4
v4df tmp0 = _mm256_unpacklo_pd(vqa, vqb);
v4df tmp1 = _mm256_unpackhi_pd(vqa, vqb);
v4df tmp2 = _mm256_unpacklo_pd(vqc, vqd);
v4df tmp3 = _mm256_unpackhi_pd(vqc, vqd);
v4df vx = _mm256_permute2f128_pd(tmp0, tmp2, 0x20);
v4df vy = _mm256_permute2f128_pd(tmp1, tmp3, 0x20);
v4df vz = _mm256_permute2f128_pd(tmp0, tmp2, 0x31);
v4df vr2 = vx * vx + vy * vy + vz * vz;
v4di dr2_flag = _mm256_castpd_si256(_mm256_cmp_pd(vr2, vsl2, _CMP_LE_OS));
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&pair[i]), dr2_flag);
}
}
void load_seq(const Vec* q,
const int32_t size,
int64_t* pair) {
v4df vsl2 = _mm256_set_pd(SL2, SL2, SL2, SL2);
for (int i = 0; i < (size / 4) * 4; i += 4) {
v4df vx = _mm256_set_pd(q[i + 3].x, q[i + 2].x, q[i + 1].x, q[i].x);
v4df vy = _mm256_set_pd(q[i + 3].y, q[i + 2].y, q[i + 1].y, q[i].y);
v4df vz = _mm256_set_pd(q[i + 3].z, q[i + 2].z, q[i + 1].z, q[i].z);
v4df vr2 = vx * vx + vy * vy + vz * vz;
v4di dr2_flag = _mm256_castpd_si256(_mm256_cmp_pd(vr2, vsl2, _CMP_LE_OS));
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&pair[i]), dr2_flag);
}
}
// Compares the first `size` entries of `tuned` against `ref` and reports the
// verdict on stderr: "Success.\n" when all of them match, "Fail.\n" otherwise.
// A non-positive size compares nothing and therefore reports success.
void check(const int64_t* ref,
           const int64_t* tuned,
           const int size) {
  const bool identical = (size <= 0) || std::equal(ref, ref + size, tuned);
  std::cerr << (identical ? "Success.\n" : "Fail.\n");
}
// Problem size and benchmark buffers. They have static storage because the
// arrays total about 320 MB (5M * 32 B of points plus 4 * 5M * 8 B of
// results), far too large for the stack.
const int size = 5000000;
// The SIMD kernels read whole Vecs with 256-bit AVX loads. Over-aligning the
// input to 32 bytes makes those loads valid for the aligned _mm256_load_pd
// even if Vec itself only guarantees 8-byte alignment.
alignas(32) Vec q[size];
int64_t pair[size];
int64_t pair_w_gather[size];
int64_t pair_wo_gather[size];
int64_t pair_load_seq[size];
int main() {
std::mt19937 mt;
std::uniform_real_distribution<> urd;
std::generate(q, q + size,
[&mt, &urd]() {
return Vec(urd(mt), urd(mt), urd(mt), urd(mt));
});
auto beg = std::chrono::system_clock::now();
reference(q, size, pair);
auto end = std::chrono::system_clock::now();
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(end - beg).count() << std::endl;
beg = std::chrono::system_clock::now();
with_gather(q, size, pair_w_gather);
end = std::chrono::system_clock::now();
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(end - beg).count() << std::endl;
beg = std::chrono::system_clock::now();
without_gather(q, size, pair_wo_gather);
end = std::chrono::system_clock::now();
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(end - beg).count() << std::endl;
beg = std::chrono::system_clock::now();
load_seq(q, size, pair_load_seq);
end = std::chrono::system_clock::now();
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(end - beg).count() << std::endl;
check(pair, pair_w_gather, size);
check(pair, pair_wo_gather, size);
check(pair, pair_load_seq, size);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment