Skip to content

Instantly share code, notes, and snippets.

@rygorous
Created October 19, 2022 19:05
Show Gist options
  • Save rygorous/4d9e9e88cab13c703773dc767a23575f to your computer and use it in GitHub Desktop.
Save rygorous/4d9e9e88cab13c703773dc767a23575f to your computer and use it in GitHub Desktop.
float<->half conversion (SSE2) matching VCVTPS2PH/VCVTPH2PS exactly
#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>
// Float->half conversion with round-to-nearest-even, SSE2+
// leaves half-floats in 32-bit lanes (sign extended)
//
// Converts four IEEE binary32 values to binary16 bit patterns. Each result is
// in the low 16 bits of its 32-bit lane. For negative inputs bits 16..31 are
// all ones, because the sign is arithmetic-shifted down from bit 31. So every
// lane holds the half pattern sign-extended to 32 bits.
//
// Output by input magnitude |f|:
//  - |f| >= 65536.0f (or NaN): 0x7c00 (inf). For NaN it is instead
//    0x7e00 | (float mantissa bits 13..21), a quiet NaN that keeps the top
//    payload bits.
//  - |f| < 2^-14: half subnormal (or zero). It is rounded by an FP add against
//    a magic constant (see c_subnorm_magic).
//  - otherwise: normal half, rounded with integer arithmetic on the bit
//    pattern. Inputs just below 65536 can carry into the exponent and come out
//    as 0x7c00 (inf).
// All three candidate results are computed for every lane and then selected
// with masks, so there are no branches.
// NOTE(review): the subnormal path rounds via _mm_add_ps, so it follows the
// current MXCSR rounding mode. The RTNE claim above assumes the default mode.
// The argument is taken by const reference, so this is C++ (not C).
static inline __m128i F32_to_F16_4x(const __m128 &f)
{
const __m128 mask_sign = _mm_set1_ps(-0.0f); // 0x80000000: just the sign bit
const __m128i c_f16max = _mm_set1_epi32((127 + 16) << 23); // = 65536.0f; all FP32 magnitudes >= this map to inf (or NaN)
const __m128i c_nanbit = _mm_set1_epi32(0x200); // half quiet-NaN bit (top of the 10-bit mantissa)
const __m128i c_nanlobits = _mm_set1_epi32(0x1ff); // remaining 9 half mantissa bits, used for NaN payload
const __m128i c_infty_as_fp16 = _mm_set1_epi32(0x7c00); // half exponent all ones, mantissa 0
const __m128i c_min_normal = _mm_set1_epi32((127 - 14) << 23); // smallest FP32 that yields a normalized FP16 (= 2^-14)
// = 0.5f. Its ULP is 2^-24, the smallest half subnormal, so absf + 0.5f
// rounds absf to a multiple of 2^-24, and the low bits of the sum are the
// half subnormal mantissa.
const __m128i c_subnorm_magic = _mm_set1_epi32(((127 - 15) + (23 - 10) + 1) << 23);
// Two adjustments in one constant:
//  - rebias the exponent from 127 to 15 (subtract 112 << 23);
//  - add 0xfff, one less than half a half-ULP (0x1000) of the 13 bits
//    dropped by the final >> 13.
const __m128i c_normal_bias = _mm_set1_epi32(0xfff - ((127 - 15) << 23)); // adjust exponent and add mantissa rounding
__m128 justsign = _mm_and_ps(f, mask_sign);
__m128 absf = _mm_andnot_ps(mask_sign, f); // f & ~mask_sign
__m128i absf_int = _mm_castps_si128(absf); // the cast is "free" (extra bypass latency, but no thruput hit)
__m128 b_isnan = _mm_cmpunord_ps(absf, absf); // is this a NaN?
// all-ones when the result is (sub)normal, zero when it is inf/NaN.
// A signed integer compare is fine here because the sign bit was cleared.
__m128i b_isregular = _mm_cmpgt_epi32(c_f16max, absf_int); // (sub)normalized or special?
__m128i nan_payload = _mm_and_si128(_mm_srli_epi32(absf_int, 13), c_nanlobits); // payload bits for NaNs
__m128i nan_quiet = _mm_or_si128(nan_payload, c_nanbit); // and set quiet bit
__m128i nanfinal = _mm_and_si128(_mm_castps_si128(b_isnan), nan_quiet); // zero for non-NaN (e.g. inf) lanes
__m128i inf_or_nan = _mm_or_si128(nanfinal, c_infty_as_fp16); // output for specials
// subnormal? (also true for +/-0)
__m128i b_issub = _mm_cmpgt_epi32(c_min_normal, absf_int);
// "result is subnormal" path
__m128 subnorm1 = _mm_add_ps(absf, _mm_castsi128_ps(c_subnorm_magic)); // magic value to round output mantissa
__m128i subnorm2 = _mm_sub_epi32(_mm_castps_si128(subnorm1), c_subnorm_magic); // subtract out bias
// "result is normal" path (computed for all lanes; discarded by the selects below where not applicable)
__m128i mantoddbit = _mm_slli_epi32(absf_int, 31 - 13); // shift bit 13 (mantissa LSB) to sign
__m128i mantodd = _mm_srai_epi32(mantoddbit, 31); // -1 if FP16 mantissa odd, else 0
__m128i round1 = _mm_add_epi32(absf_int, c_normal_bias);
// With the extra +1 for odd LSBs, the total rounding addend is exactly 0x1000,
// so ties go up only when that makes the kept LSB even (round half to even).
__m128i round2 = _mm_sub_epi32(round1, mantodd); // if mantissa LSB odd, bias towards rounding up (RTNE)
__m128i normal = _mm_srli_epi32(round2, 13); // rounded result; a mantissa carry may bump the exponent (up to 0x7c00)
// combine the two non-specials
__m128i nonspecial = _mm_or_si128(_mm_and_si128(subnorm2, b_issub), _mm_andnot_si128(b_issub, normal));
// merge in specials as well
__m128i joined = _mm_or_si128(_mm_and_si128(nonspecial, b_isregular), _mm_andnot_si128(b_isregular, inf_or_nan));
// 0x80000000 -> 0xffff8000: sets the half sign bit (bit 15) and sign-extends into bits 16..31
__m128i sign_shift = _mm_srai_epi32(_mm_castps_si128(justsign), 16);
__m128i result = _mm_or_si128(joined, sign_shift);
return result;
}
// Half->float conversion, SSE2+
// input in 32-bit lanes
//
// Converts four binary16 bit patterns, one per 32-bit lane of h, into
// binary32 values.
// Only the low 16 bits of each lane are used:
//  - bits 0..14 become the exponent/mantissa;
//  - bit 15 becomes the sign;
//  - bits 16..31 are shifted out.
// So zero-extended and sign-extended inputs give the same result.
// Method: move the half exponent/mantissa into float position (<< 13) and
// multiply by 2^112 to rebias the exponent from 15 to 127. This also turns
// half subnormals into correctly normalized floats.
// Inf/NaN inputs have their exponent forced to all ones, and NaNs get the
// float quiet bit set.
// NOTE(review): half subnormals pass through the multiply as float denormal
// inputs, so this presumably assumes DAZ is off in MXCSR — confirm.
// The argument is taken by const reference, so this is C++ (not C).
static inline __m128 F16_to_F32_4x(const __m128i &h)
{
const __m128i mask_nosign = _mm_set1_epi32(0x7fff); // half exponent + mantissa bits
const __m128 magic_mult = _mm_castsi128_ps(_mm_set1_epi32((254 - 15) << 23)); // = 2^112 = 2^(127 - 15)
const __m128i was_infnan = _mm_set1_epi32(0x7bff); // largest finite half magnitude; anything above is inf/NaN
const __m128 exp_infnan = _mm_castsi128_ps(_mm_set1_epi32(255 << 23)); // float exponent all ones
const __m128i was_nan = _mm_set1_epi32(0x7c00); // half inf; anything above is NaN
const __m128i nan_quiet = _mm_set1_epi32(1 << 22); // float quiet-NaN bit
__m128i expmant = _mm_and_si128(mask_nosign, h);
__m128i justsign = _mm_xor_si128(h, expmant); // h with bits 0..14 cleared
__m128i shifted = _mm_slli_epi32(expmant, 13); // half exponent/mantissa at float bit positions
// For inf/NaN the product is finite with an intact mantissa. Its exponent is
// overwritten with all ones by the OR below.
__m128 scaled = _mm_mul_ps(_mm_castsi128_ps(shifted), magic_mult);
__m128i b_wasinfnan = _mm_cmpgt_epi32(expmant, was_infnan);
__m128i sign = _mm_slli_epi32(justsign, 16); // half bit 15 -> float bit 31; bits 16..31 shifted out
__m128 infnanexp = _mm_and_ps(_mm_castsi128_ps(b_wasinfnan), exp_infnan);
__m128i b_wasnan = _mm_cmpgt_epi32(expmant, was_nan);
__m128i nanquiet = _mm_and_si128(b_wasnan, nan_quiet);
__m128 infnandone = _mm_or_ps(infnanexp, _mm_castsi128_ps(nanquiet));
__m128 sign_inf = _mm_or_ps(_mm_castsi128_ps(sign), infnandone);
__m128 result = _mm_or_ps(scaled, sign_inf);
return result;
}
// Count trailing zero bits: returns the index of the lowest set bit of x,
// or 32 when x == 0. This is a portable loop rather than a compiler builtin.
// It is only used to locate the first mismatching lane for debug output,
// so speed does not matter.
static uint32_t rrCtz32(uint32_t x)
{
    if (x == 0)
        return 32;

    // x is not 0, so some bit is set and this loop terminates.
    uint32_t count = 0;
    while ((x & 1) == 0)
    {
        ++count;
        x >>= 1;
    }
    // Return the number of shifts performed, not the shifted value.
    // Returning x made callers compute a huge lane index and read past the
    // end of their 4-element result arrays.
    return count;
}
// Exhaustive check of F32_to_F16_4x against the hardware VCVTPS2PH
// instruction (F16C, imm8 = 0 -> round to nearest even), used as the
// reference. Walks every 32-bit float bit pattern, four per step, and
// prints the top byte of the current pattern as a progress indicator.
// Returns true if all lanes matched. Otherwise it prints the first
// mismatching input and returns false.
static bool test_float_to_half_sse2()
{
    printf("float->half SSE2:\n");

    const __m128i step = _mm_set1_epi32(4);
    __m128i lanes = _mm_setr_epi32(0, 1, 2, 3);

    for (uint32_t first = 0;;)
    {
        // progress output once every 2^24 inputs
        if ((first & 0xffffff) == 0)
            printf("\r%02x", first >> 24);

        const __m128 inputs = _mm_castsi128_ps(lanes);

        // Hardware reference: four halves in the low 64 bits, sign-extended to
        // 32-bit lanes to match the layout F32_to_F16_4x produces.
        const __m128i hw = _mm_cvtepi16_epi32(_mm_cvtps_ph(inputs, 0));
        const __m128i sw = F32_to_F16_4x(inputs);

        const int eqmask = _mm_movemask_epi8(_mm_cmpeq_epi32(hw, sw));
        if (eqmask != 0xffff)
        {
            uint32_t hw_out[4];
            uint32_t sw_out[4];
            _mm_storeu_si128((__m128i *)hw_out, hw);
            _mm_storeu_si128((__m128i *)sw_out, sw);
            // movemask gives 4 bits per 32-bit lane. This relies on rrCtz32
            // returning the trailing-zero count, so /4 is the first bad lane.
            const int bad = rrCtz32(~eqmask) / 4;
            printf("\nmismatch!\n");
            printf("x=0x%08x ref=0x%04x test=0x%04x\n", first + bad, hw_out[bad] & 0xffff, sw_out[bad] & 0xffff);
            return false;
        }

        lanes = _mm_add_epi32(lanes, step);
        first += 4;
        if (first == 0) // wrapped around: every 32-bit pattern has been visited
            break;
    }

    printf("\nall ok\n");
    return true;
}
// Exhaustive check of F16_to_F32_4x against the hardware VCVTPH2PS
// instruction (F16C), used as the reference. All 65536 half bit patterns
// are converted, four per iteration. Results are compared as raw bits, so NaN
// payloads and the sign of zero must match too.
// Returns true if every pattern matched. Otherwise it prints the first
// mismatching input and returns false.
static bool test_half_to_float_sse2()
{
    printf("half->float SSE2:\n");

    const __m128i step = _mm_set1_epi32(4);
    __m128i halves = _mm_setr_epi32(0, 1, 2, 3);

    int base = 0;
    while (base < 0x10000)
    {
        // VCVTPH2PS needs packed 16-bit inputs. Sign-extending each lane's low
        // 16 bits lets the signed-saturating pack reproduce them exactly.
        const __m128i widened = _mm_srai_epi32(_mm_slli_epi32(halves, 16), 16);
        const __m128i packed16 = _mm_packs_epi32(widened, widened);

        const __m128 hw = _mm_cvtph_ps(packed16);
        const __m128 sw = F16_to_F32_4x(halves);

        const __m128i hw_bits = _mm_castps_si128(hw);
        const __m128i sw_bits = _mm_castps_si128(sw);
        const int eqmask = _mm_movemask_epi8(_mm_cmpeq_epi32(hw_bits, sw_bits));
        if (eqmask != 0xffff)
        {
            uint32_t hw_out[4];
            uint32_t sw_out[4];
            _mm_storeu_si128((__m128i *)hw_out, hw_bits);
            _mm_storeu_si128((__m128i *)sw_out, sw_bits);
            // relies on rrCtz32 returning the trailing-zero count (4 mask bits per lane)
            const int bad = rrCtz32(~eqmask) / 4;
            printf("mismatch!\n");
            printf("x=0x%04x ref=0x%08x test=0x%08x\n", base + bad, hw_out[bad], sw_out[bad]);
            return false;
        }

        halves = _mm_add_epi32(halves, step);
        base += 4;
    }

    printf("all ok!\n");
    return true;
}
// Runs the half->float check first, then the much longer float->half check
// (2^30 iterations versus 2^14). The second check is skipped if the first
// fails. The exit status is 0 only when both pass.
int main()
{
    const bool all_passed = test_half_to_float_sse2() && test_float_to_half_sse2();
    // non-zero exit status signals that an error occurred
    return all_passed ? 0 : 1;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment