Skip to content

Instantly share code, notes, and snippets.

@rygorous
Created March 21, 2012 05:20
Show Gist options
  • Save rygorous/2144712 to your computer and use it in GitHub Desktop.
Save rygorous/2144712 to your computer and use it in GitHub Desktop.
half->float variants
// half->float variants.
// by Fabian "ryg" Giesen.
//
// I hereby place this code in the public domain.
//
// half_to_float_fast: table based
// tables could be done in a more compact fashion (in particular, can store tab2 in low word of tab1!)
// but something of a dead end since not very SIMD-friendly. pretty much abandoned at this point.
//
// half_to_float_fast2: use FP adder hardware to deal with denormals.
// now this one has potential! (but needs some polish)
//
// half_to_float_fast3: same, but without bitfields. low number of constants involved.
// looking pretty good. need to check how it comes out in SSE2.
//
// half_to_float_fast4: again the same, but written to be easy to translate to SSE2
// intrinsics.
//
// half_to_float4_SSE2: initial port of half_to_float_fast4 to SSE2. should do a second
// pass + clean-up, but not today. :)
//
// half_to_float4b_SSE2: some tweaks. should generate fewer ops and needs fewer constants;
// about 11% faster on my Sandy Bridge i7 using VS2010. YMMV.
//
// half_to_float_fast5: slightly different approach, turns FP16 denormals into FP32 denormals.
// it's very slick and short but will be slower if denormals actually occur.
//
// half_to_float5_SSE2: SSE2-ified version of "fast5" variant. as said, in the presence of
// denormals, this will be noticeably slower than variants 4/4b. use the included benchmarking
// code to find out by how much :) (the benchmark loop in main times half_to_float4b_SSE2;
// point it at half_to_float5_SSE2 instead). it's kinda serial, which means that even though it
// has far fewer instructions than variants 4/4b, it's not all that much faster even in the best case.
#include <stdio.h>
#include <emmintrin.h>
#include <intrin.h>
// Shorthand for the unsigned bit patterns manipulated below
// (assumed 32 bits wide, as on the MSVC/x86 targets this file is written for).
typedef unsigned uint;
// Bit-level view of an IEEE-754 single-precision float.
// NOTE(review): relies on implementation-defined behavior: LSB-first bitfield
// allocation (so Mantissa occupies bits 0..22 and Sign bit 31, as the integer
// code elsewhere in this file assumes), an anonymous struct member (a compiler
// extension in C++), and type punning between u and f through the union.
union FP32
{
uint u;     // raw bit pattern
float f;    // the same bits viewed as a float
struct
{
uint Mantissa : 23; // fraction bits
uint Exponent : 8;  // biased exponent (bias 127; 255 = Inf/NaN)
uint Sign : 1;      // sign bit
};
};
// Bit-level view of an IEEE-754 half-precision (FP16) value.
// NOTE(review): the struct uses uint bitfields, so sizeof(FP16) is likely 4,
// not 2; only the low 16 bits are meaningful and code here always writes them
// through .u. Same implementation-defined assumptions as FP32 (LSB-first
// bitfields, anonymous struct, union type punning).
union FP16
{
unsigned short u; // raw 16-bit pattern
struct
{
uint Mantissa : 10; // fraction bits
uint Exponent : 5;  // biased exponent (bias 15; 0x1f = Inf/NaN, 0 = zero/denormal)
uint Sign : 1;      // sign bit
};
};
// Reference FP16 -> FP32 conversion (adapted from the ISPC reference code).
// Slow but straightforward; every other variant is checked against it.
// Handles signed zero, denormals (renormalized into FP32 normals), Inf/NaN
// (exponent forced to 255, mantissa bits carried over unchanged) and normals.
static FP32 half_to_float_full(FP16 h)
{
    FP32 out = { 0 };
    out.Sign = h.Sign; // every case keeps the sign

    if (h.Exponent == 0x1f)
    {
        // Inf/NaN: all 10 FP16 mantissa bits are widened exactly, so a zero
        // mantissa stays Inf and a nonzero one (NaN payload) stays NaN.
        out.Exponent = 255;
        out.Mantissa = h.Mantissa << 13;
        return out;
    }

    if (h.Exponent != 0)
    {
        // Normal number: rebias the exponent from 15 to 127.
        out.Exponent = h.Exponent + (127 - 15);
        out.Mantissa = h.Mantissa << 13;
        return out;
    }

    if (h.Mantissa == 0)
        return out; // (signed) zero

    // Denormal: shift the mantissa left until the implicit leading bit (0x400)
    // appears; each shift lowers the resulting exponent by one.
    uint mant = h.Mantissa;
    int shifts = 0;
    while ((mant & 0x400) == 0)
    {
        mant <<= 1;
        shifts++;
    }
    out.Exponent = (127 - 15 + 1) - shifts;
    out.Mantissa = (mant & 0x3ff) << 13;
    return out;
}
// Conversion tables for half_to_float_fast, indexed by one byte of the FP16
// bit pattern and filled in by init_tables():
//   tab1[hi]: FP32 bits of the half whose high byte is hi and low byte is 0.
//   tab2[hi]: amount added to tab1[hi] per unit of the low byte, i.e. how many
//             FP32 ulps one FP16 low-byte step spans (1 << 13 except for the
//             denormal high bytes 0x01..0x03 / 0x81..0x83).
//   tab3[lo]: FP32 bits of the positive half 0x00lo (zero or denormal).
static uint tab1[256], tab2[256], tab3[256];
static void init_tables()
{
FP16 f16;
FP32 f32;
for (int i=0; i < 256; i++)
{
f16.u = i << 8;
f32 = half_to_float_full(f16);
tab1[i] = f32.u;
tab2[i] = 1 << 13;
f16.u = i;
f32 = half_to_float_full(f16);
tab3[i] = f32.u;
}
// Lower exponent end has some denormals
tab2[0x03] = 1 << 14;
tab2[0x02] = 1 << 14;
tab2[0x01] = 1 << 15;
tab2[0x83] = 1 << 14;
tab2[0x82] = 1 << 14;
tab2[0x81] = 1 << 15;
}
// Table-driven FP16 -> FP32 conversion; init_tables() must have run first.
// If any of the exponent / top mantissa bits (0x7f00) are set, the result is
// tab1[hi] + tab2[hi] * lo (tab entries already include the sign). Otherwise
// the half is zero or a small denormal: look it up in tab3 and OR in the sign.
static FP32 half_to_float_fast(FP16 h)
{
    FP32 o;
    uint hi = h.u >> 8;
    uint lo = h.u & 0xff;
    if (h.u & 0x7f00)
    {
        o.u = tab1[hi] + tab2[hi] * lo;
    }
    else
    {
        // 0x8000u keeps the shift in unsigned arithmetic: with an int 0x8000,
        // (0x8000 << 16) overflows int, which is undefined behavior.
        o.u = ((h.u & 0x8000u) << 16) | tab3[lo];
    }
    return o;
}
// FP16 -> FP32 that uses the FP adder to handle zero/denormals; normal and
// Inf/NaN inputs are converted field by field through the bitfields.
static FP32 half_to_float_fast2(FP16 h)
{
    static const FP32 magic = { 126 << 23 }; // 0.5f
    // Zero-initialized: on the normal/Inf/NaN path o is assembled only through
    // bitfield assignments, each a read-modify-write of o's storage, which
    // would otherwise read an uninitialized object.
    FP32 o = { 0 };
    if (h.Exponent == 0) // Zero / Denormal
    {
        // Drop the 10-bit mantissa m into the low bits of 0.5f (whose ulp is
        // 2^-24), giving 0.5 + m*2^-24; the FP subtraction of 0.5f leaves
        // exactly m*2^-24, the value of the FP16 zero/denormal.
        o.u = magic.u + h.Mantissa;
        o.f -= magic.f;
    }
    else
    {
        o.Mantissa = h.Mantissa << 13;
        if (h.Exponent == 0x1f) // Inf/NaN
            o.Exponent = 255;
        else
            o.Exponent = 127 - 15 + h.Exponent; // rebias 15 -> 127
    }
    o.Sign = h.Sign;
    return o;
}
// Same idea as fast2, but with integer masks instead of bitfields. The FP16
// exponent+mantissa field is shifted into FP32 position and then rebiased:
//  - exponent 0 (zero/denormal): add magic (113 << 23), then subtract magic
//    (2^-14) in floating point so the FP adder renormalizes the value;
//  - exponent 31 (Inf/NaN): add 255-31 so the exponent field becomes 255;
//  - otherwise: add 127-15.
static FP32 half_to_float_fast3(FP16 h)
{
    static const FP32 magic = { 113 << 23 };
    static const uint shifted_exp = 0x7c00 << 13; // exponent mask after shift
    FP32 o;

    // mantissa+exponent
    uint shifted = (h.u & 0x7fff) << 13;
    uint exponent = shifted & shifted_exp;

    // exponent cases
    o.u = shifted;
    if (exponent == 0) // Zero / Denormal
    {
        o.u += magic.u;
        o.f -= magic.f;
    }
    else if (exponent == shifted_exp) // Inf/NaN
        o.u += (255 - 31) << 23;
    else
        o.u += (127 - 15) << 23;

    // copy sign bit; 0x8000u keeps the shift unsigned (an int 0x8000 << 16
    // overflows int, which is undefined behavior)
    o.u |= (h.u & 0x8000u) << 16;
    return o;
}
// Same algorithm as fast3, arranged to map directly onto SSE2 ops: the normal
// rebias (127-15) is applied unconditionally, then Inf/NaN gets an extra
// (128-16) so the exponent reaches 255, and zero/denormal gets an extra 1 so
// the bias add totals 113, followed by an FP subtract of 2^-14 to renormalize.
static FP32 half_to_float_fast4(FP16 h)
{
    static const FP32 magic = { 113 << 23 };
    static const uint shifted_exp = 0x7c00 << 13; // exponent mask after shift
    FP32 o;

    o.u = (h.u & 0x7fff) << 13;     // exponent/mantissa bits
    uint exp = shifted_exp & o.u;   // just the exponent
    o.u += (127 - 15) << 23;        // exponent adjust

    // handle exponent special cases
    if (exp == shifted_exp) // Inf/NaN?
        o.u += (128 - 16) << 23;    // extra exp adjust
    else if (exp == 0) // Zero/Denormal?
    {
        o.u += 1 << 23;             // extra exp adjust
        o.f -= magic.f;             // renormalize
    }

    // sign bit; 0x8000u keeps the shift unsigned (an int 0x8000 << 16
    // overflows int, which is undefined behavior)
    o.u |= (h.u & 0x8000u) << 16;
    return o;
}
// Multiply-based FP16 -> FP32. The FP16 exponent/mantissa bits shifted into
// FP32 position, read as a float, equal the FP16 value times 2^-112 (FP16
// denormals become FP32 denormals), so multiplying by magic = 2^112 gives the
// exact result. FP16 Inf/NaN inputs come out >= 2^16 after scaling and get
// their exponent forced to 255, which keeps the mantissa (NaN payload).
// NOTE(review): assumes denormals-are-zero is off (default FP environment);
// otherwise FP16 denormals would flush to zero in the multiply.
static FP32 half_to_float_fast5(FP16 h)
{
    static const FP32 magic = { (254 - 15) << 23 };     // 2^112
    static const FP32 was_infnan = { (127 + 16) << 23 }; // 2^16
    FP32 o;

    o.u = (h.u & 0x7fff) << 13; // exponent/mantissa bits
    o.f *= magic.f;             // exponent adjust
    if (o.f >= was_infnan.f)    // make sure Inf/NaN survive
        o.u |= 255 << 23;

    // sign bit; 0x8000u keeps the shift unsigned (an int 0x8000 << 16
    // overflows int, which is undefined behavior)
    o.u |= (h.u & 0x8000u) << 16;
    return o;
}
// Converts four FP16 values to FP32 at once; SSE2 port of half_to_float_fast4.
// Each 32-bit lane of h holds one FP16 bit pattern in its low 16 bits; the
// upper 16 bits of a lane do not affect the result (they are masked off).
// Returns the four FP32 results, lane for lane.
// Constants are 16-byte-aligned static arrays (Microsoft-specific __declspec
// syntax) read through pointer casts via the CONST macro.
static __m128 half_to_float4_SSE2(__m128i h)
{
#define SSE_CONST4(name, val) static const __declspec(align(16)) uint name[4] = { (val), (val), (val), (val) }
#define CONST(name) *(const __m128i *)&name
SSE_CONST4(mask_nosign, 0x7fff);
SSE_CONST4(mask_justsign, 0x8000);
SSE_CONST4(mask_shifted_exp, 0x7c00 << 13);     // FP16 exponent field, moved to FP32 position
SSE_CONST4(expadjust_normal, (127 - 15) << 23); // rebias 15 -> 127
SSE_CONST4(expadjust_infnan, (128 - 16) << 23); // extra adjust: exponent 31 ends at 255
SSE_CONST4(expadjust_denorm, 1 << 23);          // extra adjust: total bias add of 113
SSE_CONST4(magic_denorm, 113 << 23);            // 2^-14 as an FP32
__m128i mnosign = CONST(mask_nosign);
__m128i expmant = _mm_and_si128(mnosign, h);
__m128i justsign = _mm_and_si128(h, CONST(mask_justsign));
__m128i mshiftexp = CONST(mask_shifted_exp);
__m128i eadjust = CONST(expadjust_normal);
// Move exponent+mantissa into FP32 position and apply the normal rebias.
__m128i shifted = _mm_slli_epi32(expmant, 13);
__m128i adjusted = _mm_add_epi32(eadjust, shifted);
// Classify lanes by exponent field: all ones (Inf/NaN) or zero (zero/denormal).
__m128i justexp = _mm_and_si128(shifted, mshiftexp);
__m128i zero = _mm_setzero_si128();
__m128i b_isinfnan = _mm_cmpeq_epi32(mshiftexp, justexp);
__m128i b_isdenorm = _mm_cmpeq_epi32(zero, justexp);
// Inf/NaN lanes get the extra exponent adjustment.
__m128i adj_infnan = _mm_and_si128(b_isinfnan, CONST(expadjust_infnan));
__m128i adjusted2 = _mm_add_epi32(adjusted, adj_infnan);
// Denormal path, computed for every lane but only kept where b_isdenorm:
// raise the bias add to 113, then subtract 2^-14 in FP to renormalize.
__m128i adj_den = CONST(expadjust_denorm);
__m128i den1 = _mm_add_epi32(adj_den, adjusted2);
__m128 den2 = _mm_sub_ps(_mm_castsi128_ps(den1), *(const __m128 *)&magic_denorm);
// Blend: den2 in zero/denormal lanes, adjusted2 elsewhere (and/andnot/or).
__m128 adjusted3 = _mm_and_ps(den2, _mm_castsi128_ps(b_isdenorm));
__m128 adjusted4 = _mm_andnot_ps(_mm_castsi128_ps(b_isdenorm), _mm_castsi128_ps(adjusted2));
__m128 adjusted5 = _mm_or_ps(adjusted3, adjusted4);
// Move the FP16 sign (bit 15) to the FP32 sign (bit 31).
__m128i sign = _mm_slli_epi32(justsign, 16);
__m128 final = _mm_or_ps(adjusted5, _mm_castsi128_ps(sign));
// ~21 SSE2 ops.
return final;
#undef SSE_CONST4
#undef CONST
}
// Converts four FP16 values to FP32 at once; tweaked version of
// half_to_float4_SSE2 with fewer ops and constants.
// Each 32-bit lane of h holds one FP16 bit pattern in its low 16 bits; the
// upper 16 bits of a lane do not affect the result (masked off, or shifted out
// by the final sign shift). Returns the four FP32 results, lane for lane.
// Classification uses signed compares on the 15-bit exp/mantissa value, which
// is always non-negative, instead of comparing the exponent field for equality.
static __m128 half_to_float4b_SSE2(__m128i h)
{
#define SSE_CONST4(name, val) static const __declspec(align(16)) uint name[4] = { (val), (val), (val), (val) }
#define CONST(name) *(const __m128i *)&name
SSE_CONST4(mask_nosign, 0x7fff);
SSE_CONST4(smallest_normal, 0x0400);            // smallest FP16 exp/mantissa with exponent 1
SSE_CONST4(infinity, 0x7c00);                   // FP16 +Inf exp/mantissa bits
SSE_CONST4(expadjust_normal, (127 - 15) << 23); // rebias 15 -> 127
SSE_CONST4(magic_denorm, 113 << 23);            // 2^-14 as an FP32
__m128i mnosign = CONST(mask_nosign);
__m128i eadjust = CONST(expadjust_normal);
__m128i smallest = CONST(smallest_normal);
__m128i infty = CONST(infinity);
__m128i expmant = _mm_and_si128(mnosign, h);
// h with exp/mantissa bits cleared; only bit 15 survives the later << 16.
__m128i justsign = _mm_xor_si128(h, expmant);
// expmant < 0x7c00: exponent field is not 31.
__m128i b_notinfnan = _mm_cmpgt_epi32(infty, expmant);
// expmant < 0x0400: exponent field is 0 (zero/denormal).
__m128i b_isdenorm = _mm_cmpgt_epi32(smallest, expmant);
__m128i shifted = _mm_slli_epi32(expmant, 13);
// Inf/NaN lanes add the normal rebias a second time: 2*(127-15) = 255-31.
__m128i adj_infnan = _mm_andnot_si128(b_notinfnan, eadjust);
__m128i adjusted = _mm_add_epi32(eadjust, shifted);
// Denormal path, computed for every lane but only kept where b_isdenorm:
// add 113 to the exponent, then subtract 2^-14 in FP to renormalize.
__m128i den1 = _mm_add_epi32(shifted, CONST(magic_denorm));
__m128i adjusted2 = _mm_add_epi32(adjusted, adj_infnan);
__m128 den2 = _mm_sub_ps(_mm_castsi128_ps(den1), *(const __m128 *)&magic_denorm);
// Blend: den2 in zero/denormal lanes, adjusted2 elsewhere (and/andnot/or).
__m128 adjusted3 = _mm_and_ps(den2, _mm_castsi128_ps(b_isdenorm));
__m128 adjusted4 = _mm_andnot_ps(_mm_castsi128_ps(b_isdenorm), _mm_castsi128_ps(adjusted2));
__m128 adjusted5 = _mm_or_ps(adjusted3, adjusted4);
// Move the FP16 sign (bit 15) to the FP32 sign (bit 31).
__m128i sign = _mm_slli_epi32(justsign, 16);
__m128 final = _mm_or_ps(adjusted5, _mm_castsi128_ps(sign));
// ~19 SSE2 ops.
return final;
#undef SSE_CONST4
#undef CONST
}
// Converts four FP16 values to FP32 at once; SSE2 port of half_to_float_fast5
// (multiply by 2^112 instead of integer rebiasing).
// Each 32-bit lane of h holds one FP16 bit pattern in its low 16 bits; the
// upper 16 bits of a lane do not affect the result (masked off, or shifted out
// by the sign shift). Returns the four FP32 results, lane for lane.
// FP16 denormals become FP32 denormal inputs to _mm_mul_ps.
// NOTE(review): assumes MXCSR DAZ is off; with DAZ on, FP16 denormals would
// come out as zero.
static __m128 half_to_float5_SSE2(__m128i h)
{
#define SSE_CONST4(name, val) static const __declspec(align(16)) uint name[4] = { (val), (val), (val), (val) }
#define CONST(name) *(const __m128i *)&name
#define CONSTF(name) *(const __m128 *)&name
SSE_CONST4(mask_nosign, 0x7fff);
SSE_CONST4(magic, (254 - 15) << 23);  // 2^112 as an FP32
SSE_CONST4(was_infnan, 0x7bff);       // largest finite FP16 exp/mantissa
SSE_CONST4(exp_infnan, 255 << 23);    // FP32 exponent field all ones
__m128i mnosign = CONST(mask_nosign);
__m128i expmant = _mm_and_si128(mnosign, h);
// h with exp/mantissa bits cleared; only bit 15 survives the later << 16.
__m128i justsign = _mm_xor_si128(h, expmant);
__m128i expmant2 = expmant; // copy (just here for counting purposes)
// Exp/mantissa in FP32 position read as a float is value * 2^-112; rescale.
__m128i shifted = _mm_slli_epi32(expmant, 13);
__m128 scaled = _mm_mul_ps(_mm_castsi128_ps(shifted), *(const __m128 *)&magic);
// Exponent field was 31 (Inf/NaN): force FP32 exponent to 255, keeping mantissa.
__m128i b_wasinfnan = _mm_cmpgt_epi32(expmant2, CONST(was_infnan));
__m128i sign = _mm_slli_epi32(justsign, 16);
__m128 infnanexp = _mm_and_ps(_mm_castsi128_ps(b_wasinfnan), CONSTF(exp_infnan));
__m128 sign_inf = _mm_or_ps(_mm_castsi128_ps(sign), infnanexp);
__m128 final = _mm_or_ps(scaled, sign_inf);
// ~11 SSE2 ops.
return final;
#undef SSE_CONST4
#undef CONST
#undef CONSTF
}
// make sure we don't get DCE on SSE code: the benchmark loop in main stores
// every converted vector here (16-byte aligned for _mm_store_ps), so the
// compiler cannot discard the conversions being timed.
__declspec(align(16)) float output[1024*4];
int main(int argc, char **argv)
{
FP16 h;
FP32 full, fast, fast2, fast3, fast4, fast5;
init_tables();
for (int i=0; i < 0x10000; i++)
{
h.u = i;
full = half_to_float_full(h);
fast = half_to_float_fast(h);
fast2 = half_to_float_fast2(h);
fast3 = half_to_float_fast3(h);
fast4 = half_to_float_fast4(h);
fast5 = half_to_float_fast5(h);
if (full.u != fast.u || full.u != fast2.u || full.u != fast3.u || full.u != fast4.u || full.u != fast5.u)
{
printf("mismatch! val=%04x full=%08x fast=%08x fast2=%08x fast3=%08x fast4=%08x fast5=%08x\n", i, full.u, fast.u, fast2.u, fast3.u, fast4.u, fast5.u);
return 1;
}
}
for (int i=0; i < 0x10000; i += 4)
{
uint ref[4];
uint ssein[4], sseout[4];
for (int j=0; j < 4; j++)
{
ssein[j] = i + j;
h.u = i + j;
full = half_to_float_full(h);
ref[j] = full.u;
}
__m128i in = _mm_loadu_si128((const __m128i *)ssein);
__m128 out = half_to_float4b_SSE2(in);
_mm_storeu_ps((float *)sseout, out);
for (int j=0; j < 4; j++)
{
if (sseout[j] != ref[j])
{
printf("mismatch! val=%04x full=%08x fast4SSE2=%08x\n", i+j, ref[j], sseout[j]);
return 1;
}
}
}
uint best = ~0u;
int start = 0, end = 0x10000;
for (int runs=0; runs < 15000; runs++)
{
__m128i vals = _mm_set_epi32(start + 3, start + 2, start + 1, start + 0);
__m128i incr = _mm_set1_epi32(4);
uint tstart = (uint) __rdtsc();
for (int i=start; i < end; i += 4)
{
__m128 out = half_to_float4b_SSE2(vals);
_mm_store_ps(&output[i & 1023], out);
vals = _mm_add_epi32(vals, incr);
}
uint time = (uint) __rdtsc() - tstart;
if (time < best)
best = time;
}
printf("best: %d cycles = %.2f / vec\n", best, 4.0f * best / (end - start));
printf("all ok.\n");
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment