Skip to content

Instantly share code, notes, and snippets.

@syoyo
Created July 5, 2020 12:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save syoyo/62c0ec0183e91bff668624eba91f76f4 to your computer and use it in GitHub Desktop.
Save syoyo/62c0ec0183e91bff668624eba91f76f4 to your computer and use it in GitHub Desktop.
_mm_min_ps implementation in ARM NEON
#include <arm_neon.h>
#include <cstdio>
#include <limits>
#include <cassert>
#include <cmath>
#include <cstdint>
// Returns true when `f` is a signaling NaN: std::isnan() reports a NaN and
// bit 22 of its binary32 encoding (the quiet-NaN bit) is clear.
bool check_snan(float f)
{
    const uint32_t bits = *reinterpret_cast<uint32_t *>(&f);
    const bool quiet_bit_set = (bits & 0x00400000) != 0;
    if (quiet_bit_set) {
        // Either a quiet NaN or an ordinary number that happens to have bit 22 set.
        return false;
    }
    return std::isnan(f);
}
// Returns true when `f` is a quiet NaN, i.e. all eight exponent bits AND the
// quiet-NaN bit (bit 22) of its binary32 encoding are set. The sign bit is
// ignored.
bool check_qnan(float f)
{
    const uint32_t qnan_mask = 0x7fc00000; // exp + qNaN bit
    const uint32_t val = *reinterpret_cast<uint32_t *>(&f);
    // Compare against the full mask: a plain "non-zero" test would accept any
    // value that has at least one exponent bit set (1.0f, infinity, sNaN, ...).
    return (val & qnan_mask) == qnan_mask;
}
// Per-lane test of `a` against the bit pattern 0x7fc00000 (all eight exponent
// bits plus the most significant mantissa bit). A lane of the returned mask is
// all ones when every one of those bits is set and zero otherwise. The mask is
// also printed to stdout as a debugging aid.
//
// NOTE(review): the most significant mantissa bit is the one the scalar
// helpers in this file call the qNaN bit, so as written this matches quiet
// NaNs rather than signaling ones - confirm the intended semantics against the
// function name before relying on it.
inline uint32x4_t is_snan(float32x4_t a)
{
    const uint32x4_t pattern = vdupq_n_u32(0x7fc00000);
    const uint32x4_t raw = vreinterpretq_u32_f32(a);
    const uint32x4_t lanes = vceqq_u32(vandq_u32(raw, pattern), pattern);

    alignas(16) uint32_t dump[4];
    vst1q_u32(dump, lanes);
    printf("v_is_snan = %x, %x, %x, %x\n", dump[0], dump[1], dump[2], dump[3]);

    return lanes;
}
// Per-lane NaN test (quiet or signaling). A lane of the returned mask is all
// ones when the exponent field is all ones and at least one mantissa bit is
// set; otherwise it is zero. The sign bit does not take part in the test.
// The mask is also printed to stdout as a debugging aid.
inline uint32x4_t is_nan(float32x4_t a)
{
    const uint32x4_t exp_bits = vdupq_n_u32(0x7f800000);
    const uint32x4_t mantissa_bits = vdupq_n_u32(0x007fffff);
    const uint32x4_t raw = vreinterpretq_u32_f32(a);

    // Exponent field saturated (infinity or NaN).
    const uint32x4_t exp_saturated = vceqq_u32(vandq_u32(raw, exp_bits), exp_bits);
    // Any mantissa bit set; together with a saturated exponent that means NaN.
    const uint32x4_t mantissa_nonzero = vtstq_u32(raw, mantissa_bits);

    const uint32x4_t lanes = vandq_u32(exp_saturated, mantissa_nonzero);

    alignas(16) uint32_t dump[4];
    vst1q_u32(dump, lanes);
    printf("v_is_nan = %x, %x, %x, %x\n", dump[0], dump[1], dump[2], dump[3]);

    return lanes;
}
// Prints the four 32-bit lanes of `v` in hexadecimal to stdout, prefixed by
// `title` and " = ".
void print_u32(uint32x4_t v, const char *title)
{
    alignas(16) uint32_t lanes[4];
    vst1q_u32(lanes, v);
    printf("%s = %x, %x, %x, %x\n", title, lanes[0], lanes[1], lanes[2], lanes[3]);
}
// Prints the four float lanes of `v` with "%f" to stdout, prefixed by `title`
// and " = ".
void print_f32(float32x4_t v, const char *title)
{
    alignas(16) float lanes[4];
    vst1q_f32(lanes, v);
    printf("%s = %f, %f, %f, %f\n", title, lanes[0], lanes[1], lanes[2], lanes[3]);
}
// Lane-wise emulation of x86 _mm_min_ps (MINPS) using ARM NEON.
//
// Reference semantics: https://www.felixcloutier.com/x86/minps
//   - both inputs are (+/-)0.0                 -> return the second operand
//   - either input is NaN (quiet or signaling) -> return the second operand,
//                                                 bit-for-bit unchanged
//   - otherwise                                -> return min(a, b)
//
// `a` is the first operand, `b` the second. Intermediate masks and values are
// printed to stdout as a debugging aid.
inline float32x4_t vmin(float32x4_t a, float32x4_t b)
{
    const float32x4_t vzero = vdupq_n_f32(0.0f);

    // Every NaN in the second operand (not only one NaN flavour) has to be
    // forwarded as-is, so those lanes are taken from `b` directly instead of
    // relying on whatever vminq_f32 produces for NaN inputs.
    const uint32x4_t v_src1_is_nan = is_nan(b);
    // vceqq_f32 treats +0.0 and -0.0 as equal, so the sign needs no masking.
    const uint32x4_t v_both_are_zeros = vandq_u32(vceqq_f32(a, vzero), vceqq_f32(b, vzero));
    const uint32x4_t v_src0_is_nan = is_nan(a);
    const float32x4_t v_min = vminq_f32(a, b);

    print_u32(v_src0_is_nan, "src0 is NaN");
    print_u32(v_src1_is_nan, "src1 is NaN");
    print_u32(v_both_are_zeros, "both src0 and src1 are zero");
    print_f32(v_min, "min(a, b)");
    print_f32(vbslq_f32(v_both_are_zeros, b, v_min), "after both zero handling");

    // All special cases resolve to "take the second operand", so a single
    // combined mask and a single blend are sufficient.
    const uint32x4_t v_require_special_handling =
        vorrq_u32(v_src1_is_nan, vorrq_u32(v_both_are_zeros, v_src0_is_nan));
    print_u32(v_require_special_handling, "require special handling");

    // Use min(a, b) when !(require special handling).
    return vbslq_f32(v_require_special_handling, b, v_min);
}

// Test driver: runs vmin() on ordinary values, on NaN inputs and on signed
// zeros, prints inputs/outputs and asserts the results expected from
// _mm_min_ps.
int main()
{
    alignas(16) float in0[4];
    alignas(16) float in1[4];
    alignas(16) float buf[4];

    const float qnan = std::numeric_limits<float>::quiet_NaN();
    const float snan = std::numeric_limits<float>::signaling_NaN();
    const float inf = std::numeric_limits<float>::infinity();
    const float fmax = std::numeric_limits<float>::max();
    const float eps = std::numeric_limits<float>::epsilon();

    // Raw binary32 encoding of a float stored in memory.
    const auto bits_of = [](const float &f) { return *reinterpret_cast<const uint32_t *>(&f); };
    const auto fill = [](float (&dst)[4], float x0, float x1, float x2, float x3) {
        dst[0] = x0;
        dst[1] = x1;
        dst[2] = x2;
        dst[3] = x3;
    };
    const auto print_floats = [](const char *label, const float (&v)[4]) {
        printf("%s = %f, %f, %f, %f\n", label, v[0], v[1], v[2], v[3]);
    };
    const auto print_hex = [bits_of](const char *label, const float (&v)[4]) {
        printf("%s = %x, %x, %x, %x\n", label, bits_of(v[0]), bits_of(v[1]), bits_of(v[2]), bits_of(v[3]));
    };

    // Show how this toolchain encodes the two NaN flavours.
    buf[0] = qnan;
    buf[1] = snan;
    printf("qnan, snan = %x, %x\n", bits_of(buf[0]), bits_of(buf[1]));

    // --- Case 1: ordinary finite values and infinity ---
    fill(in0, 1.0f, 2.0f, inf, fmax);
    fill(in1, 2.0f, 1.0f, fmax, inf);
    float32x4_t v0 = vld1q_f32(in0);
    float32x4_t v1 = vld1q_f32(in1);
    float32x4_t a = vmin(v0, v1);
    vst1q_f32(buf, a);
    print_floats("in0", in0);
    print_floats("in1", in1);
    print_floats("vmin", buf);
    printf("--------------\n");

    // Self-comparison mask of v0; the mask must be stored before it is printed.
    const uint32x4_t m = vceqq_f32(v0, v0);
    alignas(16) uint32_t mbuf[4];
    vst1q_u32(mbuf, m);
    printf("mbuf = %x, %x, %x, %x\n", mbuf[0], mbuf[1], mbuf[2], mbuf[3]);

    // Experiment: combine vminq_f32 with vmaxnmq_f32 on the same inputs.
    a = vmaxnmq_f32(vminq_f32(v1, v0), v1);
    vst1q_f32(buf, a);
    print_floats("in0", in0);
    print_floats("in1", in1);
    print_floats("vmin", buf);
    print_hex("hex", buf);
    printf("--------------\n");

    // --- Case 2: a NaN in exactly one operand -> second operand is returned ---
    fill(in0, qnan, snan, 1.0f, 2.0f);
    fill(in1, 1.0f, 2.0f, qnan, snan);
    v0 = vld1q_f32(in0);
    v1 = vld1q_f32(in1);
    a = vmin(v0, v1);
    vst1q_f32(buf, a);
    print_floats("in0", in0);
    print_floats("in1", in1);
    print_floats("vmin", buf);
    print_hex("hex", buf);
    assert(std::fabs(buf[0] - 1.0f) < eps); // 1.0
    assert(std::fabs(buf[1] - 2.0f) < eps); // 2.0
    assert(check_qnan(buf[2])); // qNaN
    assert(check_snan(buf[3])); // sNaN, forwarded unchanged from the second operand
    printf("--------------\n");

    // --- Case 3: signed zeros and NaN in both operands ---
    fill(in0, -0.0f, 0.0f, qnan, snan);
    fill(in1, 0.0f, -0.0f, snan, qnan);
    v0 = vld1q_f32(in0);
    v1 = vld1q_f32(in1);
    a = vmin(v0, v1);
    vst1q_f32(buf, a);
    print_floats("in0", in0);
    print_hex("in0 hex", in0);
    print_floats("in1", in1);
    print_hex("in1 hex", in1);
    print_floats("vmin", buf);
    print_hex("hex", buf);
    assert(bits_of(buf[0]) == 0x00000000); // 0.0
    assert(bits_of(buf[1]) == 0x80000000); // -0.0
    assert(check_snan(buf[2])); // sNaN
    assert(check_qnan(buf[3])); // qNaN

    return 0;
}
/* ------------- sse2 -----------------------------------------------*/
#include <xmmintrin.h>
#include <cstdio>
#include <limits>
#include <cstdint>
#include <cmath>
#include <cassert>
// Returns true when `f` is a signaling NaN: std::isnan() reports a NaN and
// bit 22 of its binary32 encoding (the quiet-NaN bit) is clear. The raw bits
// and both intermediate flags are printed to stdout as a debugging aid.
bool check_snan(float f)
{
    const uint32_t raw = *reinterpret_cast<uint32_t *>(&f);
    const bool quiet_bit_set = (raw & 0x00400000) != 0; // qNaN bit
    const bool nan_value = std::isnan(f);
    printf("val = %x, is_nan = %d, bit_qnan = %d\n", raw, nan_value, quiet_bit_set);
    return !quiet_bit_set && nan_value;
}
// Returns true when `f` is a quiet NaN, i.e. all eight exponent bits AND the
// quiet-NaN bit (bit 22) of its binary32 encoding are set. The sign bit is
// ignored.
bool check_qnan(float f)
{
    const uint32_t qnan_mask = 0x7fc00000; // exp + qNaN bit
    const uint32_t val = *reinterpret_cast<uint32_t *>(&f);
    // Compare against the full mask: a plain "non-zero" test would accept any
    // value that has at least one exponent bit set (1.0f, infinity, sNaN, ...).
    return (val & qnan_mask) == qnan_mask;
}

// Reference driver: runs the real _mm_min_ps on ordinary values, on NaN inputs
// and on signed zeros, prints inputs/outputs and asserts the documented MINPS
// results (second operand returned for NaN and for two zeros).
int main()
{
    alignas(16) float in0[4];
    alignas(16) float in1[4];
    alignas(16) float buf[4];

    const float qnan = std::numeric_limits<float>::quiet_NaN();
    const float snan = std::numeric_limits<float>::signaling_NaN();
    const float inf = std::numeric_limits<float>::infinity();
    const float fmax = std::numeric_limits<float>::max();
    const float eps = std::numeric_limits<float>::epsilon();

    // Raw binary32 encoding of a float stored in memory.
    const auto bits_of = [](const float &f) { return *reinterpret_cast<const uint32_t *>(&f); };
    const auto fill = [](float (&dst)[4], float x0, float x1, float x2, float x3) {
        dst[0] = x0;
        dst[1] = x1;
        dst[2] = x2;
        dst[3] = x3;
    };
    const auto print_floats = [](const char *label, const float (&v)[4]) {
        printf("%s = %f, %f, %f, %f\n", label, v[0], v[1], v[2], v[3]);
    };
    const auto print_hex = [bits_of](const char *label, const float (&v)[4]) {
        printf("%s = %x, %x, %x, %x\n", label, bits_of(v[0]), bits_of(v[1]), bits_of(v[2]), bits_of(v[3]));
    };

    // Show how this toolchain encodes the two NaN flavours.
    buf[0] = qnan;
    buf[1] = snan;
    printf("qnan, snan = %x, %x\n", bits_of(buf[0]), bits_of(buf[1]));

    // --- Case 1: ordinary finite values and infinity ---
    fill(in0, 1.0f, 2.0f, inf, fmax);
    fill(in1, 2.0f, 1.0f, fmax, inf);
    __m128 v0 = _mm_load_ps(in0);
    __m128 v1 = _mm_load_ps(in1);
    __m128 a = _mm_min_ps(v0, v1);
    _mm_store_ps(buf, a);
    print_floats("in0", in0);
    print_floats("in1", in1);
    print_floats("vmin", buf);
    printf("--------------\n");

    // --- Case 2: a NaN in exactly one operand -> second operand is returned ---
    fill(in0, qnan, snan, 1.0f, 2.0f);
    fill(in1, 1.0f, 2.0f, qnan, snan);
    v0 = _mm_load_ps(in0);
    v1 = _mm_load_ps(in1);
    a = _mm_min_ps(v0, v1);
    _mm_store_ps(buf, a);
    print_floats("in0", in0);
    print_floats("in1", in1);
    print_floats("vmin", buf);
    print_hex("hex", buf);
    assert(std::fabs(buf[0] - 1.0f) < eps); // 1.0
    assert(std::fabs(buf[1] - 2.0f) < eps); // 2.0
    assert(check_qnan(buf[2])); // qNaN
    assert(check_snan(buf[3])); // sNaN, forwarded unchanged from the second operand
    printf("----------------\n");

    // --- Case 3: signed zeros and NaN in both operands ---
    fill(in0, -0.0f, 0.0f, qnan, snan);
    fill(in1, 0.0f, -0.0f, snan, qnan);
    v0 = _mm_load_ps(in0);
    v1 = _mm_load_ps(in1);
    a = _mm_min_ps(v0, v1);
    _mm_store_ps(buf, a);
    print_floats("in0", in0);
    print_hex("in0 hex", in0);
    print_floats("in1", in1);
    print_hex("in1 hex", in1);
    print_floats("vmin", buf);
    print_hex("hex", buf);
    assert(bits_of(buf[0]) == 0x00000000); // 0.0
    assert(bits_of(buf[1]) == 0x80000000); // -0.0
    assert(check_snan(buf[2])); // sNaN
    assert(check_qnan(buf[3])); // qNaN
    printf("----------------\n");

    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment