Created
February 1, 2023 21:28
-
-
Save mfkasim1/14190cc7e701d605767b85a33772bd44 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Build with: g++ -O2 div-avx.cc -I/home/muhammad/libs/sleef/include/ -mavx -L/home/muhammad/libs/sleef/lib/ -lsleef
#include <immintrin.h> | |
#include <sleef.h> | |
#include <chrono> | |
#include <iostream> | |
#include <iomanip> | |
#include <cmath> | |
#include <complex> | |
// Stream fragment that prints |x| in fixed notation with 16 digits after the
// decimal point. Meant to be dropped into an ostream `<<` chain, e.g.
// `std::cout << FIXED_FLOAT(err) << '\n'`. Note that std::fixed and the
// precision stay set on the stream afterwards.
#define FIXED_FLOAT(x) std::fixed << std::setprecision(16) << std::abs(x)
/// Scalar reference implementation of complex division lhs / rhs.
///
/// Uses a scaled (Smith-style) formulation: the component of `rhs` with the
/// larger magnitude is the pivot, so the ratio `rat` has magnitude <= 1 and
/// the denominator is formed without squaring both components of `rhs`.
///
/// @param lhs numerator  (a + ib)
/// @param rhs denominator (c + id)
/// @return the quotient; when `rhs` is exactly zero each component of `lhs`
///         is divided by zero, giving inf or NaN per IEEE rules.
std::complex<float> elmt_div(const std::complex<float>& lhs, const std::complex<float>& rhs) {
  const float ar = lhs.real();
  const float ai = lhs.imag();
  const float br = rhs.real();
  const float bi = rhs.imag();
  const float abs_br = std::abs(br);
  const float abs_bi = std::abs(bi);

  if (abs_br < abs_bi) {
    // |d| is the larger component: pivot on the imaginary part.
    const float rat = br / bi;
    // The reciprocal is deliberately taken in double (literal 1.0), so the
    // products below are evaluated in double and rounded once at the end.
    const double scl = 1.0 / (bi + br * rat);
    return {static_cast<float>((ar * rat + ai) * scl),
            static_cast<float>((ai * rat - ar) * scl)};
  }

  if (abs_br == 0 && abs_bi == 0) {
    /* divide by zeros should yield a complex inf or nan */
    return {ar / abs_br, ai / abs_bi};
  }

  // |c| is the larger (or equal) component: pivot on the real part.
  const float rat = bi / br;
  const double scl = 1.0 / (br + bi * rat);
  // Explicit casts: brace-initialising float from double is a narrowing
  // conversion and is ill-formed without them.
  return {static_cast<float>((ar + ai * rat) * scl),
          static_cast<float>((ai - ar * rat) * scl)};
}
__m256 div_map(__m256 a, __m256 b) { | |
std::complex<float> tmp[4]; | |
for (int i = 0; i < 4; ++i) { | |
tmp[i] = elmt_div(((std::complex<float>*)&a)[i], ((std::complex<float>*)&b)[i]); | |
} | |
return _mm256_loadu_ps(reinterpret_cast<const float*>(tmp)); | |
} | |
// Squared magnitude of each packed complex value, duplicated into both the
// real and the imaginary lane of that value: (a, b) -> (a*a+b*b, a*a+b*b).
__m256 abs_2_(__m256 values) {
  const __m256 squares = _mm256_mul_ps(values, values);  // a*a  b*b
  // hadd sums adjacent pairs, but within each 128-bit half it emits the two
  // sums as [s0, s1, s0, s1]; the permute regroups them to [s0, s0, s1, s1].
  const __m256 pair_sums = _mm256_hadd_ps(squares, squares);
  return _mm256_permute_ps(pair_sums, _MM_SHUFFLE(3, 1, 2, 0));  // == 0xD8
}
__m256 div(__m256 a, __m256 b) { | |
const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0); | |
auto ac_bd = _mm256_mul_ps(a, b); //ac bd | |
auto d_c = _mm256_permute_ps(b, 0xB1); //d c | |
d_c = _mm256_xor_ps(sign_mask, d_c); //-d c | |
auto ad_bc = _mm256_mul_ps(a, d_c); //-ad bc | |
auto re_im = _mm256_hadd_ps(ac_bd, ad_bc);//ac + bd bc - ad | |
re_im = _mm256_permute_ps(re_im, 0xD8); | |
return _mm256_div_ps(re_im, abs_2_(b)); | |
} | |
// Magnitude of each packed complex value (a + ib), written to both lanes of
// that value.
//
// Computed as max * sqrt(1 + (min / max)^2) with max = max(|a|, |b|) and
// min = min(|a|, |b|), instead of sqrt(a*a + b*b) as abs_2_ would give, so
// the intermediate squares do not overflow/underflow for large/small inputs.
__m256 abs_new(__m256 values) {
  const __m256 sign_bit = _mm256_set1_ps(-0.f);
  const __m256 fabs_val = _mm256_andnot_ps(sign_bit, values);                     // |a| |b|
  const __m256 fabs_shf = _mm256_permute_ps(fabs_val, _MM_SHUFFLE(2, 3, 0, 1));  // |b| |a|  (0xB1)
  const __m256 fabs_max = _mm256_max_ps(fabs_val, fabs_shf);                      // max max
  const __m256 fabs_min = _mm256_min_ps(fabs_val, fabs_shf);                      // min min

  __m256 t = _mm256_div_ps(fabs_min, fabs_max);  // min / max
  // The ratio is NaN for 0/0 (zero input) and inf/inf (both parts infinite).
  // Zeroing those lanes makes the result collapse to fabs_max itself, i.e.
  // 0 for a zero input and inf for an infinite one. Lanes where fabs_max is
  // already NaN still produce NaN through the final multiply.
  const __m256 t_is_number = _mm256_cmp_ps(t, t, _CMP_ORD_Q);
  t = _mm256_and_ps(t, t_is_number);

  const __m256 t2_plus_1 = _mm256_add_ps(_mm256_mul_ps(t, t), _mm256_set1_ps(1.0f));
  return _mm256_mul_ps(_mm256_sqrt_ps(t2_plus_1), fabs_max);
}
__m256 div_new(__m256 a, __m256 b) { | |
auto b_abs = abs_new(b); // |c,d| |c,d| | |
auto a2 = _mm256_div_ps(a, b_abs); // a/|c,d| b/|c,d| | |
auto b2 = _mm256_div_ps(b, b_abs); // c/|c,d| d/|c,d| | |
auto acbd2 = _mm256_mul_ps(a2, b2); // ac/|c,d|^2 bd/|c,d|^2 | |
const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0); | |
auto dc2 = _mm256_permute_ps(b2, 0xB1); // d/|c,d| c/|c,d| | |
dc2 = _mm256_xor_ps(sign_mask, dc2); // -d/|c,d| c/|c,d| | |
auto adbc2 = _mm256_mul_ps(a2, dc2); //-ad/|c,d|^2 bc/|c,d|^2 | |
auto res2 = _mm256_hadd_ps(acbd2, adbc2); //(ac+bd)/|c,d|^2 (bc-ad)/|c,d|^2 | |
// res2 above is not interleaved, needs permute_ps to make real and imag interleaved | |
res2 = _mm256_permute_ps(res2, 0xD8); | |
return res2; | |
} | |
__m256 div_new2(__m256 a, __m256 b) { | |
auto mask = _mm256_set1_ps(-0.f); | |
auto fabs_cd = _mm256_andnot_ps(mask, b); // |c| |d| | |
auto fabs_dc = _mm256_permute_ps(fabs_cd, 0xB1); // |d| |c| | |
auto scale = _mm256_rcp_ps(_mm256_max_ps(fabs_cd, fabs_dc)); // 1/sc 1/sc | |
auto a2 = _mm256_mul_ps(a, scale); // a/sc b/sc | |
auto b2 = _mm256_mul_ps(b, scale); // c/sc d/sc | |
auto acbd2 = _mm256_mul_ps(a2, b2); | |
const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0); | |
auto dc2 = _mm256_permute_ps(b2, 0xB1); // d/sc c/sc | |
dc2 = _mm256_xor_ps(sign_mask, dc2); // -d/|c,d| c/sc | |
auto adbc2 = _mm256_mul_ps(a2, dc2); //-ad/sc^2 bc/sc^2 | |
auto res2 = _mm256_hadd_ps(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2 | |
res2 = _mm256_permute_ps(res2, 0xD8); | |
// get the denominator | |
auto denom2 = abs_2_(b2); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 | |
res2 = _mm256_div_ps(res2, denom2); | |
return res2; | |
} | |
int main() { | |
__m256 b = _mm256_setr_ps(1.2, 1.0, 3.1, 2.1, 5.2, 2.1, 5.5, -1.2); | |
int N = 1000000; | |
auto start = std::chrono::system_clock::now(); | |
__m256 d = _mm256_set1_ps(1.02); | |
for (int i = 0; i < N; ++i) { | |
d = div(b, d); | |
} | |
auto end = std::chrono::system_clock::now(); | |
auto elapsed = end - start; | |
std::cout << FIXED_FLOAT(((float*)&d)[0] - 1.02) << '\n'; | |
start = std::chrono::system_clock::now(); | |
d = _mm256_set1_ps(1.02); | |
for (int i = 0; i < N; ++i) { | |
d = div(b, d); | |
} | |
end = std::chrono::system_clock::now(); | |
elapsed = end - start; | |
std::cout << "time for orig div: " << elapsed.count() << '\n'; | |
std::cout << FIXED_FLOAT(((float*)&d)[0] - 1.02) << '\n'; | |
start = std::chrono::system_clock::now(); | |
d = _mm256_set1_ps(1.02); | |
for (int i = 0; i < N; ++i) { | |
d = div_new(b, d); | |
} | |
end = std::chrono::system_clock::now(); | |
elapsed = end - start; | |
std::cout << "time for new div: " << elapsed.count() << '\n'; | |
std::cout << FIXED_FLOAT(((float*)&d)[0] - 1.02) << '\n'; | |
start = std::chrono::system_clock::now(); | |
d = _mm256_set1_ps(1.02); | |
for (int i = 0; i < N; ++i) { | |
d = div_new2(b, d); | |
} | |
end = std::chrono::system_clock::now(); | |
elapsed = end - start; | |
std::cout << "time for new div2: " << elapsed.count() << '\n'; | |
std::cout << FIXED_FLOAT(((float*)&d)[0] - 1.02) << '\n'; | |
start = std::chrono::system_clock::now(); | |
d = _mm256_set1_ps(1.02); | |
for (int i = 0; i < N; ++i) { | |
d = div_map(b, d); | |
} | |
end = std::chrono::system_clock::now(); | |
elapsed = end - start; | |
std::cout << "time for div_map : " << elapsed.count() << '\n'; | |
std::cout << FIXED_FLOAT(((float*)&d)[0] - 1.02) << '\n'; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment