Gist rygorous/4d9e9e88cab13c703773dc767a23575f (created October 19, 2022):
float<->half conversion routines matching VCVTPS2PH / VCVTPH2PS exactly,
with exhaustive tests against the hardware instructions as reference.
#include <stdio.h> | |
#include <stdint.h> | |
#include <immintrin.h> | |
// Float->half conversion with round-to-nearest-even, SSE2+.
// Leaves the half-floats in 32-bit lanes (sign extended); intended to match
// VCVTPS2PH (RN mode) bit-exactly, including NaN quieting/payload handling
// (verified exhaustively by test_float_to_half_sse2 below).
static inline __m128i F32_to_F16_4x(const __m128 &f)
{
    const __m128 mask_sign = _mm_set1_ps(-0.0f); // just the sign bit (0x80000000)
    const __m128i c_f16max = _mm_set1_epi32((127 + 16) << 23); // all FP32 values >=this round to +inf
    const __m128i c_nanbit = _mm_set1_epi32(0x200); // quiet bit of the FP16 mantissa (pre-shift position)
    const __m128i c_nanlobits = _mm_set1_epi32(0x1ff); // low payload bits of the FP16 mantissa
    const __m128i c_infty_as_fp16 = _mm_set1_epi32(0x7c00); // FP16 +infinity bit pattern
    const __m128i c_min_normal = _mm_set1_epi32((127 - 14) << 23); // smallest FP32 that yields a normalized FP16
    const __m128i c_subnorm_magic = _mm_set1_epi32(((127 - 15) + (23 - 10) + 1) << 23); // magic for the denormal path
    const __m128i c_normal_bias = _mm_set1_epi32(0xfff - ((127 - 15) << 23)); // adjust exponent and add mantissa rounding

    __m128 justsign = _mm_and_ps(f, mask_sign); // isolate the sign bit
    __m128 absf = _mm_andnot_ps(mask_sign, f); // f & ~mask_sign
    __m128i absf_int = _mm_castps_si128(absf); // the cast is "free" (extra bypass latency, but no thruput hit)
    __m128 b_isnan = _mm_cmpunord_ps(absf, absf); // is this a NaN?
    __m128i b_isregular = _mm_cmpgt_epi32(c_f16max, absf_int); // (sub)normalized or special?
    __m128i nan_payload = _mm_and_si128(_mm_srli_epi32(absf_int, 13), c_nanlobits); // payload bits for NaNs
    __m128i nan_quiet = _mm_or_si128(nan_payload, c_nanbit); // and set quiet bit
    __m128i nanfinal = _mm_and_si128(_mm_castps_si128(b_isnan), nan_quiet); // NaN bits only where input was NaN
    __m128i inf_or_nan = _mm_or_si128(nanfinal, c_infty_as_fp16); // output for specials

    // subnormal?
    __m128i b_issub = _mm_cmpgt_epi32(c_min_normal, absf_int);

    // "result is subnormal" path: the FP add lets the hardware shift the
    // mantissa into place and round it (per the original comment below).
    __m128 subnorm1 = _mm_add_ps(absf, _mm_castsi128_ps(c_subnorm_magic)); // magic value to round output mantissa
    __m128i subnorm2 = _mm_sub_epi32(_mm_castps_si128(subnorm1), c_subnorm_magic); // subtract out bias

    // "result is normal" path
    __m128i mantoddbit = _mm_slli_epi32(absf_int, 31 - 13); // shift bit 13 (mantissa LSB) to sign
    __m128i mantodd = _mm_srai_epi32(mantoddbit, 31); // -1 if FP16 mantissa odd, else 0
    __m128i round1 = _mm_add_epi32(absf_int, c_normal_bias);
    __m128i round2 = _mm_sub_epi32(round1, mantodd); // if mantissa LSB odd, bias towards rounding up (RTNE)
    __m128i normal = _mm_srli_epi32(round2, 13); // rounded result

    // combine the two non-specials (blend via AND/ANDNOT masks)
    __m128i nonspecial = _mm_or_si128(_mm_and_si128(subnorm2, b_issub), _mm_andnot_si128(b_issub, normal));
    // merge in specials as well
    __m128i joined = _mm_or_si128(_mm_and_si128(nonspecial, b_isregular), _mm_andnot_si128(b_isregular, inf_or_nan));

    // arithmetic shift moves the sign from bit 31 to bit 15 and sign-extends
    // the upper half of the lane, giving the "sign extended" output contract.
    __m128i sign_shift = _mm_srai_epi32(_mm_castps_si128(justsign), 16);
    __m128i result = _mm_or_si128(joined, sign_shift);
    return result;
}
// Half->float conversion, SSE2+.
// Input: half-floats in the low 16 bits of each 32-bit lane.
// Intended to match VCVTPH2PS bit-exactly, including NaN quieting
// (verified exhaustively by test_half_to_float_sse2 below).
static inline __m128 F16_to_F32_4x(const __m128i &h)
{
    const __m128i mask_nosign = _mm_set1_epi32(0x7fff); // everything but the FP16 sign bit
    const __m128 magic_mult = _mm_castsi128_ps(_mm_set1_epi32((254 - 15) << 23)); // 2^(127-15): rebias exponent via multiply
    const __m128i was_infnan = _mm_set1_epi32(0x7bff); // largest finite FP16 (sign stripped); above => Inf/NaN
    const __m128 exp_infnan = _mm_castsi128_ps(_mm_set1_epi32(255 << 23)); // all-ones FP32 exponent field
    const __m128i was_nan = _mm_set1_epi32(0x7c00); // FP16 +infinity; above => NaN
    const __m128i nan_quiet = _mm_set1_epi32(1 << 22); // FP32 quiet-NaN bit

    __m128i expmant = _mm_and_si128(mask_nosign, h); // exponent + mantissa, sign stripped
    __m128i justsign = _mm_xor_si128(h, expmant); // just the sign bit (and any sign-extension bits)
    __m128i shifted = _mm_slli_epi32(expmant, 13); // move exp/mantissa into FP32 field positions
    __m128 scaled = _mm_mul_ps(_mm_castsi128_ps(shifted), magic_mult); // multiply rebiases the exponent; also handles denormal inputs
    __m128i b_wasinfnan = _mm_cmpgt_epi32(expmant, was_infnan); // -1 where input was Inf or NaN
    __m128i sign = _mm_slli_epi32(justsign, 16); // sign bit from bit 15 to bit 31
    __m128 infnanexp = _mm_and_ps(_mm_castsi128_ps(b_wasinfnan), exp_infnan); // force FP32 Inf/NaN exponent for specials
    __m128i b_wasnan = _mm_cmpgt_epi32(expmant, was_nan); // -1 where input was NaN (nonzero payload)
    __m128i nanquiet = _mm_and_si128(b_wasnan, nan_quiet); // set the quiet bit on NaNs
    __m128 infnandone = _mm_or_ps(infnanexp, _mm_castsi128_ps(nanquiet));
    __m128 sign_inf = _mm_or_ps(_mm_castsi128_ps(sign), infnandone);
    __m128 result = _mm_or_ps(scaled, sign_inf); // merge sign and special-case bits into the scaled value
    return result;
}
// Counts trailing zero bits of x; returns 32 for x == 0.
// Helper fn, only used for debug output (locating the mismatching lane).
static uint32_t rrCtz32(uint32_t x)
{
    if (x == 0)
        return 32;
    // x is not 0, so this eventually terminates
    uint32_t count = 0;
    while ((x & 1) == 0)
    {
        ++count;
        x >>= 1;
    }
    // BUG FIX: the original returned x (the shifted-down remainder, always
    // odd here) instead of count, which broke the lane index computed as
    // rrCtz32(~match) / 4 in the test drivers.
    return count;
}
// Exhaustively tests F32_to_F16_4x against HW VCVTPS2PH, which is our
// reference, over all 2^32 FP32 bit patterns (4 lanes at a time).
// Returns true iff every input matches; on mismatch, prints the first
// failing input and both results, then returns false.
// Requires F16C (_mm_cvtps_ph) and SSE4.1 (_mm_cvtepi16_epi32).
static bool test_float_to_half_sse2()
{
    printf("float->half SSE2:\n");
    __m128i x = _mm_setr_epi32(0,1,2,3);
    uint32_t base = 0;
    do
    {
        // progress indicator: print the top byte once every 2^24 inputs
        if ((base & 0xffffff) == 0)
            printf("\r%02x", base >> 24);
        __m128 x_flt = _mm_castsi128_ps(x);
        __m128i ref16 = _mm_cvtps_ph(x_flt, 0); // imm 0 = round to nearest even
        __m128i ref32 = _mm_cvtepi16_epi32(ref16); // sign extend 16 -> 32 bits
        __m128i test = F32_to_F16_4x(x_flt);
        __m128i compare = _mm_cmpeq_epi32(ref32, test);
        int match = _mm_movemask_epi8(compare);
        if (match != 0xffff)
        {
            uint32_t ref_res[4];
            uint32_t test_res[4];
            _mm_storeu_si128((__m128i *)ref_res, ref32);
            _mm_storeu_si128((__m128i *)test_res, test);
            // each 32-bit lane contributes 4 bits to the byte movemask
            int lane = rrCtz32(~match) / 4;
            printf("\nmismatch!\n");
            printf("x=0x%08x ref=0x%04x test=0x%04x\n", base + lane, ref_res[lane] & 0xffff, test_res[lane] & 0xffff);
            return false;
        }
        base += 4;
        x = _mm_add_epi32(x, _mm_set1_epi32(4));
    } while (base != 0); // base wraps to 0 exactly after all 2^32 inputs
    printf("\nall ok\n");
    return true;
}
// Exhaustively tests F16_to_F32_4x against HW VCVTPH2PS, which is our
// reference, over all 2^16 FP16 bit patterns (4 lanes at a time).
// Returns true iff every input matches; on mismatch, prints the first
// failing input and both results, then returns false.
// Requires F16C (_mm_cvtph_ps).
static bool test_half_to_float_sse2()
{
    printf("half->float SSE2:\n");
    __m128i x = _mm_setr_epi32(0,1,2,3);
    for (int base = 0; base < 0x10000; base += 4)
    {
        // sign extend as 16-bit values
        __m128i sext16 = _mm_srai_epi32(_mm_slli_epi32(x, 16), 16);
        // pack down to 16 bits (values fit, so saturation never triggers)
        __m128i packed = _mm_packs_epi32(sext16, sext16);
        // convert using HW instrs and SSE2 path
        __m128 ref = _mm_cvtph_ps(packed);
        __m128 test = F16_to_F32_4x(x);
        // compare as ints (bit-exact comparison; NaN != NaN as floats)
        __m128i compare = _mm_cmpeq_epi32(_mm_castps_si128(ref), _mm_castps_si128(test));
        int match = _mm_movemask_epi8(compare);
        if (match != 0xffff)
        {
            uint32_t ref_res[4];
            uint32_t test_res[4];
            _mm_storeu_si128((__m128i *)ref_res, _mm_castps_si128(ref));
            _mm_storeu_si128((__m128i *)test_res, _mm_castps_si128(test));
            // each 32-bit lane contributes 4 bits to the byte movemask
            int lane = rrCtz32(~match) / 4;
            printf("mismatch!\n");
            printf("x=0x%04x ref=0x%08x test=0x%08x\n", base + lane, ref_res[lane], test_res[lane]);
            return false;
        }
        x = _mm_add_epi32(x, _mm_set1_epi32(4));
    }
    printf("all ok!\n");
    return true;
}
int main() | |
{ | |
if (test_half_to_float_sse2() && | |
test_float_to_half_sse2()) | |
{ | |
return 0; | |
} | |
// Error occurred! | |
return 1; | |
} |
(End of gist. GitHub sign-in prompt for commenting omitted.)