Skip to content

Instantly share code, notes, and snippets.

@Keno
Created June 28, 2019 22:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Keno/0bf23c7d0c75cfdab84580039d54997e to your computer and use it in GitHub Desktop.
Save Keno/0bf23c7d0c75cfdab84580039d54997e to your computer and use it in GitHub Desktop.
module Branchy
# Float32 -> Float16 conversion (round to nearest, ties to even) that derives the
# Float16 base bits and significand shift with branches instead of lookup tables.
using Base: significand_mask, significand_bits, exponent_mask, sign_mask

"""
    compute_base_shift(e) -> (base, sh)

Given the unbiased Float32 exponent `e` (range -127:128), return the unsigned
Float16 bit pattern `base` the result starts from and the right shift `sh` that
turns the 24-bit Float32 significand (implicit bit included) into the Float16
significand field.
"""
@inline function compute_base_shift(e)
    if e < -25          # too small even for a Float16 subnormal: maps to zero
        base = 0x0000
        sh = 25
    elseif e < -14      # below the Float16 normal range: maps to subnormals
        base = 0x0000
        sh = -e - 1
    elseif e <= 15      # Float16 normal range: re-bias the exponent, drop 13 bits
        # Converted to UInt16 so every branch returns the same `base` type.
        base = UInt16((e + 15) << 10)
        sh = 13
    elseif e < 128      # above the Float16 range: maps to Infinity
        base = 0x7C00
        sh = 24
    else                # Infinity and NaN keep the all-ones exponent
        base = 0x7C00
        sh = 13
    end
    return (base, sh)
end

"""
    round(val::Float32) -> Float16

Convert `val` to the nearest Float16, breaking ties toward an even significand.
Values beyond the Float16 range become ±Inf. NaN inputs return a NaN that keeps
the sign and the top 10 payload bits, with the quiet bit (0x0200) forced on.
"""
function round(val::Float32)
    f = reinterpret(UInt32, val)
    # Unbiased exponent. Convert to a signed Int first: on 32-bit hosts
    # `UInt32 - 127` stays unsigned and would wrap for small exponents.
    e = Int((f & exponent_mask(Float32)) >> significand_bits(Float32)) - 127
    (base, sh) = compute_base_shift(e)
    # Stored significand plus the implicit leading `1`. Float32 subnormals have
    # no implicit bit, but they fall in the `e < -25` case, whose shift discards
    # the whole significand, so the extra bit never reaches the result.
    sig = (f & significand_mask(Float32)) | (significand_mask(Float32) + 0x1)
    # The Float16 significand mask drops the implicit bit for normal results;
    # the Float32 sign moves from bit 31 down to bit 15.
    h = (base + (sig >> sh) & significand_mask(Float16)) % UInt16 |
        UInt16((f & sign_mask(Float32)) >> 16)
    if isnan(val)
        # Truncation can clear every payload bit, which would leave the
        # Infinity pattern; setting the quiet bit keeps the result a NaN.
        h |= 0x0200
    elseif (sig >> (sh - 1)) & 1 != 0 && (h & 0x7C00) != 0x7C00
        # The most significant discarded bit is set: round up on a tie with an
        # odd result, or whenever a lower discarded bit is set. Results that
        # already saturated to Infinity are skipped; adding 1 would make a NaN.
        if h & 1 == 1 || (sig & ((1 << (sh - 1)) - 1)) != 0
            h += UInt16(1)
        end
    end
    return reinterpret(Float16, h)
end
end
module Current
# Float32 -> Float16 conversion (round to nearest, ties to even) driven by
# 512-entry base/shift lookup tables indexed by the Float32 sign and exponent.
using Base: significand_mask, significand_bits, exponent_mask, sign_mask

# Table index is `(sign << 8 | biased_exponent) + 1`: entries 1:256 serve
# positive inputs and 257:512 negative ones (same shift, sign bit in the base).
let _basetable = Vector{UInt16}(undef, 512),
    _shifttable = Vector{UInt8}(undef, 512)
    for i = 0:255
        e = i - 127
        if e < -25          # very small numbers map to zero
            base, sh = 0x0000, 25
        elseif e < -14      # small numbers map to denorms
            base, sh = 0x0000, -e - 1
        elseif e <= 15      # normal numbers just lose precision
            base, sh = UInt16((e + 15) << 10), 13
        elseif e < 128      # large numbers map to Infinity
            base, sh = 0x7C00, 24
        else                # Infinity and NaN's stay Infinity and NaN's
            base, sh = 0x7C00, 13
        end
        _basetable[i + 1] = base
        _basetable[(i | 0x100) + 1] = base | 0x8000
        _shifttable[i + 1] = sh
        _shifttable[(i | 0x100) + 1] = sh
    end
    # Frozen into tuples bound to `const` globals.
    global const shifttable = (_shifttable...,)
    global const basetable = (_basetable...,)
end

"""
    round(val::Float32) -> Float16

Convert `val` to the nearest Float16, breaking ties toward an even significand.
Values beyond the Float16 range become ±Inf. NaN inputs return a NaN that keeps
the sign and the top 10 payload bits, with the quiet bit (0x0200) forced on.
"""
function round(val::Float32)
    f = reinterpret(UInt32, val)
    if isnan(val)
        # `(f >> 13) % UInt16` carries the all-ones exponent into bits 10-15 and
        # the top 10 payload bits into bits 0-9; XOR with `t` clears bit 15 for
        # positive NaNs so the sign comes out right. The quiet bit is forced on
        # because a payload confined to the dropped low bits would yield Inf.
        t = 0x8000 ⊻ (0x8000 & ((f >> 0x10) % UInt16))
        return reinterpret(Float16, (t ⊻ ((f >> 0xd) % UInt16)) | 0x0200)
    end
    # Table index: sign and biased exponent (9 bits), plus 1.
    i = ((f & ~significand_mask(Float32)) >> significand_bits(Float32)) + 1
    @inbounds sh = shifttable[i]
    # Stored significand plus the implicit leading `1`. Float32 subnormals hit
    # table entries with shift 25, which discards the whole significand.
    sig = (f & significand_mask(Float32)) | (significand_mask(Float32) + 0x1)
    @inbounds h = (basetable[i] + (sig >> sh) & significand_mask(Float16)) % UInt16
    # Round: the most significant discarded bit is set. Results that already
    # saturated to Infinity are skipped, since adding 1 would produce a NaN.
    if (sig >> (sh - 1)) & 1 != 0 && (h & 0x7C00) != 0x7C00
        # Round half to even, or round up when any lower discarded bit is set.
        if h & 1 == 1 || (sig & ((1 << (sh - 1)) - 1)) != 0
            h += UInt16(1)
        end
    end
    return reinterpret(Float16, h)
end
end
module New
# Float32 -> Float16 conversion (round to nearest, ties to even) that picks,
# per architecture, between lookup tables and branchy base/shift computation.
using Base: significand_mask, significand_bits, exponent_mask, sign_mask

# Lookup tables are used everywhere except x86, where the branchy computation
# is used instead. 32-bit x86 builds report `Sys.ARCH` as `:i386`..`:i686`, so
# those names are matched in addition to `:x86`/`:x86_64`.
const IS_MEMORY_TABLE_FAST =
    !(Sys.ARCH in (:x86, :x86_64) || occursin(r"^i[3-6]86$", String(Sys.ARCH)))

"""
    compute_base_shift(e) -> (base, sh)

Given the unbiased Float32 exponent `e` (range -127:128), return the unsigned
Float16 bit pattern `base` the result starts from and the right shift `sh` that
turns the 24-bit Float32 significand (implicit bit included) into the Float16
significand field.
"""
@inline function compute_base_shift(e)
    if e < -25          # too small even for a Float16 subnormal: maps to zero
        base = 0x0000
        sh = 25
    elseif e < -14      # below the Float16 normal range: maps to subnormals
        base = 0x0000
        sh = -e - 1
    elseif e <= 15      # Float16 normal range: re-bias the exponent, drop 13 bits
        # Converted to UInt16 so every branch returns the same `base` type.
        base = UInt16((e + 15) << 10)
        sh = 13
    elseif e < 128      # above the Float16 range: maps to Infinity
        base = 0x7C00
        sh = 24
    else                # Infinity and NaN keep the all-ones exponent
        base = 0x7C00
        sh = 13
    end
    return (base, sh)
end

"""
    lookup_base_sh(es) -> (base, sh)

`es` holds the Float32 sign in bit 8 and the biased exponent in bits 0-7.
Return the Float16 base bits (sign included) and the significand right shift.
"""
function lookup_base_sh end
if IS_MEMORY_TABLE_FAST
    let _basetable = Vector{UInt16}(undef, 512),
        _shifttable = Vector{UInt8}(undef, 512)
        for i = 0:255
            (base, sh) = compute_base_shift(i - 127)
            _basetable[i + 1] = base
            _basetable[(i | 0x100) + 1] = base | 0x8000
            _shifttable[i + 1] = sh
            _shifttable[(i | 0x100) + 1] = sh
        end
        global const shifttable = (_shifttable...,)
        global const basetable = (_basetable...,)
        global lookup_base_sh
        @inline lookup_base_sh(es) =
            (@inbounds(basetable[es + 1]), @inbounds(shifttable[es + 1]))
    end
else
    function lookup_base_sh(es)
        # Signed exponent: `UInt32 - 127` would stay unsigned (and wrap) on
        # 32-bit hosts, where Int is Int32.
        (base, sh) = compute_base_shift(Int(es & 0xff) - 127)
        # Move the sign from bit 8 of `es` to the Float16 sign bit (bit 15).
        base |= (es & 0x100) << (significand_bits(Float32) - 16)
        return (base, sh)
    end
end

"""
    round(val::Float32) -> Float16

Convert `val` to the nearest Float16, breaking ties toward an even significand.
Values beyond the Float16 range become ±Inf. NaN inputs return a NaN that keeps
the sign and the top 10 payload bits, with the quiet bit (0x0200) forced on.
"""
function round(val::Float32)
    f = reinterpret(UInt32, val)
    # Sign (bit 8) and biased exponent (bits 0-7).
    es = (f & ~significand_mask(Float32)) >> significand_bits(Float32)
    (base, sh) = lookup_base_sh(es)
    # Stored significand plus the implicit leading `1`. Float32 subnormals get
    # shift 25, which discards the whole significand.
    sig = (f & significand_mask(Float32)) | (significand_mask(Float32) + 0x1)
    h = (base + (sig >> sh) & significand_mask(Float16)) % UInt16
    if isnan(val)
        # Truncation can clear every payload bit, which would leave the
        # Infinity pattern; setting the quiet bit keeps the result a NaN.
        h |= 0x0200
    elseif (sig >> (sh - 1)) & 1 != 0 && (h & 0x7C00) != 0x7C00
        # The most significant discarded bit is set: round up on a tie with an
        # odd result, or whenever a lower discarded bit is set. Results that
        # already saturated to Infinity are skipped; adding 1 would make a NaN.
        if h & 1 == 1 || (sig & ((1 << (sh - 1)) - 1)) != 0
            h += UInt16(1)
        end
    end
    return reinterpret(Float16, h)
end
end
# Benchmark driver. `@benchmark` comes from BenchmarkTools.jl, which this
# script already depended on; load it explicitly so the script runs standalone.
using BenchmarkTools
# Every Float16 bit pattern (Inf and NaN included), widened to Float32.
A = Float32.(reinterpret.(Float16, typemin(UInt16):typemax(UInt16)));
# Standard-normal samples.
B = Float32.(randn(65536));
# Random Float16 bit patterns, widened to Float32.
C = Float32.(reinterpret(Float16, rand(UInt16, 65536)));
# `$` interpolation keeps the lookup of the untyped globals A/B/C out of the
# measured time, so only the conversion itself is benchmarked.
@benchmark Current.round.($A)
@benchmark Current.round.($B)
@benchmark Current.round.($C)
@benchmark Branchy.round.($A)
@benchmark Branchy.round.($B)
@benchmark Branchy.round.($C)
@benchmark New.round.($A)
@benchmark New.round.($B)
@benchmark New.round.($C)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment