// Created January 4, 2017 08:41
// Source: GitHub gist kohnakagawa/11a296e2ed1c155a66db1ae08d0c3ae5
// (Gist page text: "Save kohnakagawa/11a296e2ed1c155a66db1ae08d0c3ae5 to your computer and use it in GitHub Desktop.")
// Notice from the gist viewer: this file contains bidirectional Unicode text that may be
// interpreted or compiled differently than what appears below. To review, open the file in
// an editor that reveals hidden Unicode characters.
// Learn more about bidirectional Unicode characters.
#include <x86intrin.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <random>
// Threshold length and its square; the kernels compare x^2 + y^2 + z^2 against SL2.
constexpr double SL = 0.2;
constexpr double SL2 = SL * SL;

// GCC/Clang vector-extension types: 4 x int32 (128 bit), 4 x int64 (256 bit)
// and 4 x double (256 bit).
typedef int32_t v4si __attribute__((vector_size(16)));
typedef int64_t v4di __attribute__((vector_size(32)));
typedef double v4df __attribute__((vector_size(32)));
// Four packed doubles (x, y, z, w). A default-constructed Vec is all zeros.
struct Vec {
  double x = 0.0;
  double y = 0.0;
  double z = 0.0;
  double w = 0.0;

  Vec() = default;

  Vec(const double x_,
      const double y_,
      const double z_,
      const double w_)
      : x(x_), y(y_), z(z_), w(w_) {}
};

// The SIMD kernels in this file step through Vec arrays in strides of four
// doubles and load one whole Vec per 256-bit register, so the struct must not
// carry any padding.
static_assert(sizeof(Vec) == 4 * sizeof(double),
              "Vec must be exactly four packed doubles");
void reference(const Vec* q, | |
const int32_t size, | |
int64_t* pair) { | |
for (int i = 0; i < size; i++) { | |
const auto r2 | |
= q[i].x * q[i].x | |
+ q[i].y * q[i].y | |
+ q[i].z * q[i].z; | |
pair[i] = (r2 <= SL2) ? 0xffffffffffffffff : 0; | |
} | |
} | |
void with_gather(const Vec* q, | |
const int32_t size, | |
int64_t* pair) { | |
v4si vindex = _mm_set_epi32(12, 8, 4, 0); | |
v4df vsl2 = _mm256_set_pd(SL2, SL2, SL2, SL2); | |
for (int i = 0; i < (size / 4) * 4; i += 4) { | |
v4df vx = _mm256_i32gather_pd(&(q[i].x), vindex, 8); | |
v4df vy = _mm256_i32gather_pd(&(q[i].y), vindex, 8); | |
v4df vz = _mm256_i32gather_pd(&(q[i].z), vindex, 8); | |
v4df vr2 = vx * vx + vy * vy + vz * vz; | |
v4di dr2_flag = _mm256_castpd_si256(_mm256_cmp_pd(vr2, vsl2, _CMP_LE_OS)); | |
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&pair[i]), dr2_flag); | |
} | |
} | |
void without_gather(const Vec* q, | |
const int32_t size, | |
int64_t* pair) { | |
v4df vsl2 = _mm256_set_pd(SL2, SL2, SL2, SL2); | |
for (int i = 0; i < (size / 4) * 4; i += 4) { | |
v4df vqa = _mm256_load_pd(reinterpret_cast<const double*>(q + i )); | |
v4df vqb = _mm256_load_pd(reinterpret_cast<const double*>(q + i + 1)); | |
v4df vqc = _mm256_load_pd(reinterpret_cast<const double*>(q + i + 2)); | |
v4df vqd = _mm256_load_pd(reinterpret_cast<const double*>(q + i + 3)); | |
// transpose 4x4 | |
v4df tmp0 = _mm256_unpacklo_pd(vqa, vqb); | |
v4df tmp1 = _mm256_unpackhi_pd(vqa, vqb); | |
v4df tmp2 = _mm256_unpacklo_pd(vqc, vqd); | |
v4df tmp3 = _mm256_unpackhi_pd(vqc, vqd); | |
v4df vx = _mm256_permute2f128_pd(tmp0, tmp2, 0x20); | |
v4df vy = _mm256_permute2f128_pd(tmp1, tmp3, 0x20); | |
v4df vz = _mm256_permute2f128_pd(tmp0, tmp2, 0x31); | |
v4df vr2 = vx * vx + vy * vy + vz * vz; | |
v4di dr2_flag = _mm256_castpd_si256(_mm256_cmp_pd(vr2, vsl2, _CMP_LE_OS)); | |
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&pair[i]), dr2_flag); | |
} | |
} | |
void load_seq(const Vec* q, | |
const int32_t size, | |
int64_t* pair) { | |
v4df vsl2 = _mm256_set_pd(SL2, SL2, SL2, SL2); | |
for (int i = 0; i < (size / 4) * 4; i += 4) { | |
v4df vx = _mm256_set_pd(q[i + 3].x, q[i + 2].x, q[i + 1].x, q[i].x); | |
v4df vy = _mm256_set_pd(q[i + 3].y, q[i + 2].y, q[i + 1].y, q[i].y); | |
v4df vz = _mm256_set_pd(q[i + 3].z, q[i + 2].z, q[i + 1].z, q[i].z); | |
v4df vr2 = vx * vx + vy * vy + vz * vz; | |
v4di dr2_flag = _mm256_castpd_si256(_mm256_cmp_pd(vr2, vsl2, _CMP_LE_OS)); | |
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&pair[i]), dr2_flag); | |
} | |
} | |
// Compares the first `size` entries of `tuned` against `ref` and reports the
// outcome on std::cerr: "Fail." as soon as one entry differs, "Success." when
// none does (including when size <= 0).
void check(const int64_t* ref,
           const int64_t* tuned,
           const int size) {
  bool matched = true;
  // Stop scanning at the first mismatch.
  for (int idx = 0; matched && idx < size; ++idx) {
    matched = (ref[idx] == tuned[idx]);
  }
  std::cerr << (matched ? "Success.\n" : "Fail.\n");
}
const int size = 5000000; | |
Vec q[size]; | |
int64_t pair[size], pair_w_gather[size], pair_wo_gather[size], pair_load_seq[size]; | |
int main() { | |
std::mt19937 mt; | |
std::uniform_real_distribution<> urd; | |
std::generate(q, q + size, | |
[&mt, &urd]() { | |
return Vec(urd(mt), urd(mt), urd(mt), urd(mt)); | |
}); | |
auto beg = std::chrono::system_clock::now(); | |
reference(q, size, pair); | |
auto end = std::chrono::system_clock::now(); | |
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(end - beg).count() << std::endl; | |
beg = std::chrono::system_clock::now(); | |
with_gather(q, size, pair_w_gather); | |
end = std::chrono::system_clock::now(); | |
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(end - beg).count() << std::endl; | |
beg = std::chrono::system_clock::now(); | |
without_gather(q, size, pair_wo_gather); | |
end = std::chrono::system_clock::now(); | |
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(end - beg).count() << std::endl; | |
beg = std::chrono::system_clock::now(); | |
load_seq(q, size, pair_load_seq); | |
end = std::chrono::system_clock::now(); | |
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(end - beg).count() << std::endl; | |
check(pair, pair_w_gather, size); | |
check(pair, pair_wo_gather, size); | |
check(pair, pair_load_seq, size); | |
} |
// (Gist page text: "Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment")