// Skip to content
//
// Instantly share code, notes, and snippets.
//
// @mfkasim1
// Created February 1, 2023 21:28
// Show Gist options
//   • Save mfkasim1/14190cc7e701d605767b85a33772bd44 to your computer and use it in GitHub Desktop.
// g++ -O2 div-avx.cc -I/home/muhammad/libs/sleef/include/ -mavx -L/home/muhammad/libs/sleef/lib/ -lsleef
#include <immintrin.h>
#include <sleef.h>
#include <chrono>
#include <iostream>
#include <iomanip>
#include <cmath>
#include <complex>
// Stream helper: expands to the manipulators std::fixed and
// std::setprecision(16) followed by std::abs(x), so it is meant to appear on
// the right-hand side of an ostream `<<` chain. The argument is evaluated once.
#define FIXED_FLOAT(x) std::fixed << std::setprecision(16) << std::abs(x)
/// Scalar reference implementation of complex division `lhs / rhs`.
///
/// The quotient is computed by first dividing through by whichever component
/// of `rhs` has the larger magnitude (the ratio-based scheme usually attributed
/// to Smith), which avoids forming `br*br + bi*bi` directly.
///
/// @param lhs numerator  (ar + i*ai); not modified.
/// @param rhs denominator (br + i*bi); not modified.
/// @return the quotient. When both components of `rhs` are zero the result is
///         `{ar / 0, ai / 0}`, i.e. infinities or NaNs per IEEE division.
std::complex<float> elmt_div(const std::complex<float>& lhs, const std::complex<float>& rhs) {
  const float ar = lhs.real();
  const float ai = lhs.imag();
  const float br = rhs.real();
  const float bi = rhs.imag();
  const float abs_br = std::abs(br);
  const float abs_bi = std::abs(bi);

  if (abs_br >= abs_bi) {
    if (abs_br == 0 && abs_bi == 0) {
      // Division by complex zero should yield a complex inf or NaN.
      return {ar / abs_br, ai / abs_bi};
    }
    const float rat = bi / br;
    // The `1.0` literal deliberately keeps the scale factor (and therefore the
    // final products) in double precision; the results are narrowed to float
    // explicitly, because narrowing inside a braced initializer is ill-formed.
    const double scl = 1.0 / (br + bi * rat);
    return {static_cast<float>((ar + ai * rat) * scl),
            static_cast<float>((ai - ar * rat) * scl)};
  }

  // |bi| > |br|: same scheme with the roles of the components swapped.
  const float rat = br / bi;
  const double scl = 1.0 / (bi + br * rat);
  return {static_cast<float>((ar * rat + ai) * scl),
          static_cast<float>((ai * rat - ar) * scl)};
}
/// Element-wise reference division: treats each 256-bit vector as four
/// interleaved (real, imag) complex floats and divides them one at a time
/// with elmt_div.
///
/// @param a packed numerators.
/// @param b packed denominators.
/// @return the four quotients, packed in the same interleaved layout.
__m256 div_map(__m256 a, __m256 b) {
  std::complex<float> lhs[4];
  std::complex<float> rhs[4];
  std::complex<float> quot[4];
  // Spill the vectors with explicit unaligned stores instead of reading the
  // by-value __m256 arguments through a casted std::complex<float> pointer.
  _mm256_storeu_ps(reinterpret_cast<float*>(lhs), a);
  _mm256_storeu_ps(reinterpret_cast<float*>(rhs), b);
  for (int i = 0; i < 4; ++i) {
    quot[i] = elmt_div(lhs[i], rhs[i]);
  }
  return _mm256_loadu_ps(reinterpret_cast<const float*>(quot));
}
/// Squared magnitude of every packed complex number.
///
/// `values` holds interleaved (a, b) pairs; the result holds a*a + b*b in
/// both slots of each pair.
__m256 abs_2_(__m256 values) {
  const __m256 squares = _mm256_mul_ps(values, values);  // a*a  b*b
  // Within each 128-bit half the horizontal add yields [s0, s1, s0, s1],
  // where sN is the sum for the N-th pair of that half.
  const __m256 pair_sums = _mm256_hadd_ps(squares, squares);
  // Reorder each half to [s0, s0, s1, s1] so the sums line up with the pairs.
  return _mm256_permute_ps(pair_sums, _MM_SHUFFLE(3, 1, 2, 0));
}
/// Packed complex division using the textbook formula
///   (a + ib) / (c + id) = ((ac + bd) + i(bc - ad)) / (c^2 + d^2).
///
/// Both vectors hold four interleaved (real, imag) pairs. The denominator is
/// formed by squaring c and d directly (via abs_2_), with no rescaling.
__m256 div(__m256 a, __m256 b) {
  // XOR-ing with this flips the sign of the even (real-slot) lanes only.
  const __m256 flip_even = _mm256_setr_ps(-0.0f, 0.0f, -0.0f, 0.0f, -0.0f, 0.0f, -0.0f, 0.0f);

  const __m256 b_swapped = _mm256_permute_ps(b, _MM_SHUFFLE(2, 3, 0, 1));  //  d   c
  const __m256 b_rotated = _mm256_xor_ps(flip_even, b_swapped);            // -d   c

  const __m256 direct = _mm256_mul_ps(a, b);          //  ac  bd
  const __m256 cross = _mm256_mul_ps(a, b_rotated);   // -ad  bc

  // Horizontal add gives, per 128-bit half, [re0, re1, im0, im1];
  // the permute restores the interleaved [re0, im0, re1, im1] layout.
  const __m256 sums = _mm256_hadd_ps(direct, cross);
  const __m256 numerator = _mm256_permute_ps(sums, _MM_SHUFFLE(3, 1, 2, 0));

  return _mm256_div_ps(numerator, abs_2_(b));
}
/// Magnitude of every packed complex number, computed without squaring the
/// raw components.
///
/// `values` holds interleaved (a, b) pairs representing a + ib; the result
/// holds |a + ib| in both slots of each pair. abs_2_ is deliberately not used
/// here, to prevent overflow/underflow for large/small numbers.
///
/// A pair whose larger component is 0 yields 0, and a pair whose larger
/// component is infinite yields +inf.
__m256 abs_new(__m256 values) {
  const __m256 sign_bit = _mm256_set1_ps(-0.f);
  const __m256 fabs_val = _mm256_andnot_ps(sign_bit, values);    // |a| |b|
  const __m256 fabs_shf = _mm256_permute_ps(fabs_val, 0xB1);     // |b| |a|
  const __m256 fabs_max = _mm256_max_ps(fabs_val, fabs_shf);     // max max
  const __m256 fabs_min = _mm256_min_ps(fabs_val, fabs_shf);     // min min

  // min / max evaluates to NaN for 0/0 and inf/inf. In both situations the
  // magnitude is simply fabs_max, so the ratio is forced to 0 there.
  const __m256 zero = _mm256_setzero_ps();
  const __m256 inf = _mm256_set1_ps(INFINITY);
  const __m256 degenerate = _mm256_or_ps(_mm256_cmp_ps(fabs_max, zero, _CMP_EQ_OQ),
                                         _mm256_cmp_ps(fabs_max, inf, _CMP_EQ_OQ));
  const __m256 t = _mm256_andnot_ps(degenerate, _mm256_div_ps(fabs_min, fabs_max));

  // following: max * sqrt(1 + (min / max)^2)
  const __m256 t2 = _mm256_mul_ps(t, t);
  const __m256 t21 = _mm256_add_ps(t2, _mm256_set1_ps(1.0f));
  const __m256 t21_sqrt = _mm256_sqrt_ps(t21);
  return _mm256_mul_ps(t21_sqrt, fabs_max);
}
/// Packed complex division that normalises both operands by |b| first.
///
/// For (a + ib) / (c + id) with m = |c + id| (from abs_new), the operands are
/// divided by m up front, so the quotient is obtained as
///   ((a/m)(c/m) + (b/m)(d/m)) + i((b/m)(c/m) - (a/m)(d/m))
/// with no separate division by c^2 + d^2 afterwards.
/// Both vectors hold four interleaved (real, imag) pairs.
__m256 div_new(__m256 a, __m256 b) {
  // XOR-ing with this flips the sign of the even (real-slot) lanes only.
  const __m256 flip_even = _mm256_setr_ps(-0.0f, 0.0f, -0.0f, 0.0f, -0.0f, 0.0f, -0.0f, 0.0f);

  const __m256 magnitude = abs_new(b);                    // m    m
  const __m256 numer = _mm256_div_ps(a, magnitude);       // a/m  b/m
  const __m256 denom = _mm256_div_ps(b, magnitude);       // c/m  d/m

  const __m256 denom_swapped = _mm256_permute_ps(denom, _MM_SHUFFLE(2, 3, 0, 1));  //  d/m  c/m
  const __m256 denom_rotated = _mm256_xor_ps(flip_even, denom_swapped);            // -d/m  c/m

  const __m256 direct = _mm256_mul_ps(numer, denom);          //  ac/m^2  bd/m^2
  const __m256 cross = _mm256_mul_ps(numer, denom_rotated);   // -ad/m^2  bc/m^2

  // The horizontal add leaves real and imaginary parts grouped rather than
  // interleaved; the permute puts them back as (real, imag) pairs.
  const __m256 grouped = _mm256_hadd_ps(direct, cross);
  return _mm256_permute_ps(grouped, _MM_SHUFFLE(3, 1, 2, 0));
}
__m256 div_new2(__m256 a, __m256 b) {
auto mask = _mm256_set1_ps(-0.f);
auto fabs_cd = _mm256_andnot_ps(mask, b); // |c| |d|
auto fabs_dc = _mm256_permute_ps(fabs_cd, 0xB1); // |d| |c|
auto scale = _mm256_rcp_ps(_mm256_max_ps(fabs_cd, fabs_dc)); // 1/sc 1/sc
auto a2 = _mm256_mul_ps(a, scale); // a/sc b/sc
auto b2 = _mm256_mul_ps(b, scale); // c/sc d/sc
auto acbd2 = _mm256_mul_ps(a2, b2);
const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
auto dc2 = _mm256_permute_ps(b2, 0xB1); // d/sc c/sc
dc2 = _mm256_xor_ps(sign_mask, dc2); // -d/|c,d| c/sc
auto adbc2 = _mm256_mul_ps(a2, dc2); //-ad/sc^2 bc/sc^2
auto res2 = _mm256_hadd_ps(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2
res2 = _mm256_permute_ps(res2, 0xD8);
// get the denominator
auto denom2 = abs_2_(b2); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
res2 = _mm256_div_ps(res2, denom2);
return res2;
}
int main() {
__m256 b = _mm256_setr_ps(1.2, 1.0, 3.1, 2.1, 5.2, 2.1, 5.5, -1.2);
int N = 1000000;
auto start = std::chrono::system_clock::now();
__m256 d = _mm256_set1_ps(1.02);
for (int i = 0; i < N; ++i) {
d = div(b, d);
}
auto end = std::chrono::system_clock::now();
auto elapsed = end - start;
std::cout << FIXED_FLOAT(((float*)&d)[0] - 1.02) << '\n';
start = std::chrono::system_clock::now();
d = _mm256_set1_ps(1.02);
for (int i = 0; i < N; ++i) {
d = div(b, d);
}
end = std::chrono::system_clock::now();
elapsed = end - start;
std::cout << "time for orig div: " << elapsed.count() << '\n';
std::cout << FIXED_FLOAT(((float*)&d)[0] - 1.02) << '\n';
start = std::chrono::system_clock::now();
d = _mm256_set1_ps(1.02);
for (int i = 0; i < N; ++i) {
d = div_new(b, d);
}
end = std::chrono::system_clock::now();
elapsed = end - start;
std::cout << "time for new div: " << elapsed.count() << '\n';
std::cout << FIXED_FLOAT(((float*)&d)[0] - 1.02) << '\n';
start = std::chrono::system_clock::now();
d = _mm256_set1_ps(1.02);
for (int i = 0; i < N; ++i) {
d = div_new2(b, d);
}
end = std::chrono::system_clock::now();
elapsed = end - start;
std::cout << "time for new div2: " << elapsed.count() << '\n';
std::cout << FIXED_FLOAT(((float*)&d)[0] - 1.02) << '\n';
start = std::chrono::system_clock::now();
d = _mm256_set1_ps(1.02);
for (int i = 0; i < N; ++i) {
d = div_map(b, d);
}
end = std::chrono::system_clock::now();
elapsed = end - start;
std::cout << "time for div_map : " << elapsed.count() << '\n';
std::cout << FIXED_FLOAT(((float*)&d)[0] - 1.02) << '\n';
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment