Skip to content

Instantly share code, notes, and snippets.

@ScottPJones
Created July 11, 2016 02:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ScottPJones/b37990dc77a4d1d38a94374a26695bbe to your computer and use it in GitHub Desktop.
Save ScottPJones/b37990dc77a4d1d38a94374a26695bbe to your computer and use it in GitHub Desktop.
module ArbFlts
import Base: +, -, /, *, sqrt, inv
using ArbFloats
export ArbFlt, ArbFlt160
export invsqrt
#=
type ArbRef
mid_exp::Int
mid_size::UInt
mid_d1::UInt
mid_d2::UInt
rad_exp::Int
rad_man::UInt
end
=#
# Note: this is currently pretty hard-coded to precision 160, 2 limbs, 64-bit only,
# just for this proof-of-concept, to see if the expected/hoped for performance gains
# could be achieved
# Maximum supported precision (bits) and the number of 64-bit limbs used to hold a
# significand of that size: div(160 + 32, 64) == 3.
const ArbMaxPrec = 160
const MaxLimbs = div(ArbMaxPrec + 32, 64)
const Limb = UInt64                      # one significand limb (64-bit platform only)
const _ArbMax = ArbFloat{ArbMaxPrec}     # full-width Arb ball used for scratch registers
const _VL = Vector{Limb}                 # limb buffer type
# Allocate a fresh, initialized Arb ball at the maximum supported precision.
function _arbinit()
    return ArbFloats.initializer(_ArbMax)
end
# Scratch space for the arithmetic wrappers: Arb-layout balls plus limb buffers that
# back any significand too wide to keep in-line.
immutable ArbRegisters
    # Output registers (created by `_arbinit()`) that receive Arb results.
    R::ArbFloat{ArbMaxPrec}
    S::ArbFloat{ArbMaxPrec}
    # Input registers filled by `_copy!` from compact ArbFlt values; Arb only reads them.
    X::ArbFloat{ArbMaxPrec}
    Y::ArbFloat{ArbMaxPrec}
    Z::ArbFloat{ArbMaxPrec}
    # Extended-significand buffers, paired with X, Y and Z respectively, that those
    # registers point into when a significand needs three limbs.
    XV::Vector{Limb}
    YV::Vector{Limb}
    ZV::Vector{Limb}
end
# The single module-wide register set, shared by every arithmetic wrapper.
# NOTE(review): shared mutable scratch, so the wrappers are presumably not safe to call
# concurrently from several tasks/threads — confirm before using them in parallel code.
const ArbReg = let nlimbs = MaxLimbs
    ArbRegisters(_arbinit(), _arbinit(), _arbinit(), _arbinit(), _arbinit(),
                 _VL(nlimbs), _VL(nlimbs), _VL(nlimbs))
end
# Compact, fixed-size encoding of an Arb ball of precision P with an N-limb tail.
#   rad_val  bit 0      : sign of the midpoint
#            bit 1      : size hint; clear => the significand (or a special code) is in mid_val
#            bits 2..31 : radius mantissa (its exponent is rad_exp)
#   mid_val  size hint clear: an ARB_VAL_* special code when its top bit is clear,
#                             otherwise a 32-bit significand
#            size hint set:   0 when the whole significand is in mid_n, otherwise the
#                             least significant dword of a three-limb significand
#   mid_n    remaining significand limbs, least significant first
immutable ArbFlt{P,N}
    mid_exp::Int32
    rad_exp::Int32
    rad_val::UInt32
    mid_val::UInt32
    mid_n::NTuple{N,UInt64}
end
const ArbFlt160 = ArbFlt{160,2}
# Generate `_copyX!`, `_copyY!` and `_copyZ!`: each loads an ArbFlt{160,2} into the
# matching input register of ArbReg, backed by that register's limb buffer.
for reg in (:X, :Y, :Z)
    fname = Symbol(string("_copy", reg, "!"))
    buf = Symbol(string(reg, "V"))
    @eval @inline ($fname)(x::ArbFlt{160,2}) = _copy!(ArbReg.$reg, ArbReg.$buf, x)
end
# Load the compact value `x` into the Arb-layout register `z` so it can be passed to Arb
# as an input operand. Returns `z`.
#
# `vec` (at least 3 limbs) backs the significand when it needs three limbs. In that case
# `z` points into `vec`, so `vec` must stay alive and unmodified while `z` is in use.
# Every path sets the midpoint (exponent, size, limbs) and the radius (exponent, mantissa),
# so nothing from a previous use of a reused register survives.
function _copy!(z::_ArbMax, vec::Vector{Limb}, x::ArbFlt{160,2})
    if (x.rad_val & UInt(2)) == UInt(0)
        # Size-hint bit clear: mid_val is a special-value code (top bit clear) or a
        # 32-bit significand (top bit set).
        if x.mid_val % Int32 >= 0
            z.mid_d1 = z.mid_d2 = z.mid_size = 0
            if x.mid_val == ARB_VAL_BAL_INF
                # Infinite radius around a zero midpoint.
                z.mid_exp = 0
                z.rad_exp = ARF_EXP_POS_INF
                z.rad_man = 0
            else
                z.mid_exp = fparb_to_arf_special(x)
                # Restore any radius stored with the special midpoint (zero if none).
                z.rad_exp = x.rad_exp
                z.rad_man = (x.rad_val >> 2) % Int
            end
            return z
        end
        # One limb; the 32 significand bits occupy its upper half.
        z.mid_size = 2 | (x.rad_val & 1)            # (1 limb << 1) | sign bit
        z.mid_d1 = (x.mid_val % UInt64) << 32       # parenthesised: precedence-independent
        z.mid_d2 = 0 % UInt64
    elseif x.mid_val == 0 % UInt32
        # Up to 128 bits in-line (64-bit platform): one or two limbs taken from mid_n.
        z.mid_d1 = x.mid_n[1]
        z.mid_d2 = x.mid_n[2]
        z.mid_size = (z.mid_d2 == UInt(0) ? 2 : 4) | (x.rad_val & 1)   # plus sign bit
    else
        # Three limbs, least significant first; mid_val is the upper half of the lowest.
        vec[1] = (x.mid_val % UInt64) << 32
        vec[2] = x.mid_n[1]
        vec[3] = x.mid_n[2]
        z.mid_d1 = 3                                  # number of allocated limbs
        z.mid_d2 = reinterpret(UInt, pointer(vec))    # pointer to the limbs
        z.mid_size = 6 | (x.rad_val & 1)              # (3 limbs << 1) | sign bit
    end
    # Shared by every non-special layout (the 32-bit one included), so the register's
    # radius always reflects `x`.
    z.mid_exp = x.mid_exp
    z.rad_exp = x.rad_exp
    z.rad_man = (x.rad_val >> 2) % Int
    return z
end
# fmpz / arf constants for a 64-bit build. The values presumably mirror FLINT's and
# Arb's C headers — re-check them if those libraries change their representation.
const FLINT_BITS = 64
# The module hard-codes 64-bit limbs, so refuse to load on any other word size.
# (Comparing FLINT_BITS with itself, as before, could never trigger.)
@static if Sys.WORD_SIZE != FLINT_BITS ; error("Unimplimented for 32-bit") ; end
const COEFF_MAX = (1 << (FLINT_BITS - 2)) - 1
const COEFF_MIN = -COEFF_MAX
# Exponent markers stored in the exponent field of a special (size 0) arf midpoint.
const ARF_EXP_ZERO = 0
const ARF_EXP_NAN = COEFF_MIN
const ARF_EXP_POS_INF = (COEFF_MIN + 1)
const ARF_EXP_NEG_INF = (COEFF_MIN + 2)
# Codes kept in `mid_val` of a fixed-precision ArbFlt when the size-hint bit is clear and
# the top bit of mid_val is clear, plus the 32-bit significand used by one().
const ARB_VAL_ZERO    = 0x00000000
const ARB_VAL_NAN     = 0x00000001
const ARB_VAL_POS_INF = 0x00000002
const ARB_VAL_NEG_INF = 0x00000003
const ARB_VAL_BAL_INF = 0x00000004   # radius is infinite (decoded with a zero midpoint)
const ARB_VAL_ONE     = 0x80000000   # only the top significand bit set
# N = 2 means 256 bits total size (136 useful, plus 24 guard)
# N = 6 means 512 bits total size (392 useful, plus 24 guard)
# For our purposes, the exponent value must fit in 32 bits.
# On a 32-bit machine, only 30 bits can be stored in-line before the value must point to an mpz structure.
# On a 64-bit machine, all valid exponents will always be inline, and can be tested as such.
# 01<30 0 bits> or 01<30 1 bits>
# Whether the fmpz word `v` holds a pointer to an mpz (top two bits are 01) rather than
# a small in-line integer.
function is_fmpz_ptr(v::Int)
    toptwo = v >>> (FLINT_BITS - 2)
    return toptwo == 1
end
# Whether the midpoint significand lives behind a pointer: mid_size is (limbs << 1) | sign,
# so more than 5 means more than two limbs.
arf_has_ptr(x::ArbFloat) = 5 < x.mid_size
# Whether the midpoint is a special value (0, NaN, ±Inf), which has no significand limbs.
arf_is_special(x::ArbFloat) = 0 == x.mid_size
# Narrow an Arb exponent (Int64) to the Int32 stored in an ArbFlt.
# Throws InexactError when `v` lies outside the Int32 range.
@inline function get_exp(v::Int64)
    # Inclusive at both ends: typemin(Int32) fits in 32 bits and must be accepted too.
    (typemin(Int32) <= v <= typemax(Int32)) || throw(InexactError())
    return v % Int32
end
# Map a special Arb ball (midpoint with no significand, or infinite radius) to the UInt32
# ARB_VAL_* code stored in ArbFlt.mid_val. An infinite radius takes priority over the
# midpoint. Any marker other than zero/±Inf is treated as NaN.
@inline function arf_to_fparb_special(x)
    x.rad_exp == ARF_EXP_POS_INF && return ARB_VAL_BAL_INF
    # Return the ARB_VAL_* code (UInt32), not the ARF_EXP_* marker (Int), so every branch
    # yields the same type.
    x.mid_exp == ARF_EXP_ZERO && return ARB_VAL_ZERO
    x.mid_exp == ARF_EXP_POS_INF && return ARB_VAL_POS_INF
    x.mid_exp == ARF_EXP_NEG_INF && return ARB_VAL_NEG_INF
    return ARB_VAL_NAN
end
# Map an ARB_VAL_* special code (ZERO, NAN, POS_INF, NEG_INF) to the matching arf
# exponent marker. ARB_VAL_BAL_INF is not handled here (callers check it first); any code
# above 3 raises a BoundsError.
@inline function fparb_to_arf_special(code::Integer)
    expval = (ARF_EXP_ZERO, ARF_EXP_NAN, ARF_EXP_POS_INF, ARF_EXP_NEG_INF)
    return expval[code + 1]
end
# Same mapping, taking the code from the `mid_val` field of a compact ArbFlt.
@inline fparb_to_arf_special(x) = fparb_to_arf_special(x.mid_val)
# Encode the Arb ball `x` (precision 160 only) in the compact ArbFlt{160,2} layout.
#
# - A ball with infinite radius is stored as ARB_VAL_BAL_INF.
# - A special midpoint (0, NaN, ±Inf) is stored as its ARB_VAL_* code with the size-hint
#   bit clear, keeping any finite radius (so e.g. 0 ± r stays an enclosure).
# - Otherwise the significand is stored in mid_val (32 bits), in mid_n (64 or 128 bits),
#   or as three limbs whose lowest 32 bits must be zero.
#
# Throws InexactError (via get_exp) when an exponent does not fit in Int32, and the
# string "too large" when the significand cannot be represented.
function ArbFlt{P}(x::ArbFloat{P})
    P != 160 && error("unsupported precision")
    nolimbs = (0 % UInt64, 0 % UInt64)
    # Infinite radius: the midpoint is irrelevant. Checked before get_exp, which would
    # reject the marker.
    x.rad_exp == ARF_EXP_POS_INF &&
        return ArbFlt{P,2}(0, 0, 0 % UInt32, ARB_VAL_BAL_INF, nolimbs)
    re = get_exp(x.rad_exp)
    # Radius mantissa in bits 2..31, midpoint sign in bit 0 (bit 1, the size hint, is set
    # below only for the wide layouts). Parenthesised so the parse is precedence-independent.
    rm = ((x.rad_man % UInt32) << 2) | ((x.mid_size % UInt32) & 0x00001)
    arf_is_special(x) &&
        return ArbFlt{P,2}(0, re, rm, arf_to_fparb_special(x), nolimbs)
    me = get_exp(x.mid_exp)
    nlimbs = x.mid_size >> 1
    if nlimbs > 2
        # Significand behind a pointer (mid_d2); only three limbs fit this layout.
        nlimbs > 3 && throw("too large")
        ptr = reinterpret(Ptr{UInt64}, x.mid_d2)
        lowlimb = unsafe_load(ptr)
        # Only the upper half of the least significant limb can be kept (in mid_val).
        lowlimb % UInt32 == 0x00000 || throw("too large")
        return ArbFlt{P,2}(me, re, rm | 0x00002, (lowlimb >>> 32) % UInt32,
                           (unsafe_load(ptr, 2), unsafe_load(ptr, 3)))
    elseif nlimbs == 2
        return ArbFlt{P,2}(me, re, rm | 0x00002, 0 % UInt32,
                           (x.mid_d1 % UInt64, x.mid_d2 % UInt64))
    elseif (x.mid_d1 % UInt32) != UInt32(0)
        # One limb with its lower half in use: keep all 64 bits in mid_n.
        return ArbFlt{P,2}(me, re, rm | 0x00002, 0 % UInt32, (x.mid_d1 % UInt64, 0 % UInt64))
    else
        # One limb with a zero lower half: its upper 32 bits go straight into mid_val.
        return ArbFlt{P,2}(me, re, rm, (x.mid_d1 >>> 32) % UInt32, nolimbs)
    end
end
# Expand the compact value `x` into a freshly initialized ArbFloat{P} and return it.
#
# Raises an error for precisions above 416 bits or an N that does not match P. Throws the
# string "Unimplemented, need to allocate pointer" for three-limb significands.
# NOTE(review): in the in-line case only mid_n[1] and mid_n[2] are read, which looks
# insufficient for N = 6 — confirm before enabling larger precisions.
function ArbFloat{P,N}(x::ArbFlt{P,N})
    P > 416 && error("Too large precision")
    N != (P > 160 ? 6 : 2) && error("Invalid $N for precision $P")
    z = ArbFloats.initializer(ArbFloat{P})
    if (x.rad_val & 2) == 0
        # Size-hint bit clear: special code (top bit clear) or 32-bit significand.
        if x.mid_val % Int32 >= 0
            z.mid_d1 = z.mid_d2 = z.mid_size = 0
            if x.mid_val == ARB_VAL_BAL_INF
                # Infinite radius around a zero midpoint.
                z.mid_exp = 0
                z.rad_exp = ARF_EXP_POS_INF
                z.rad_man = 0
            else
                # Pass the ArbFlt itself; the decoder reads its mid_val code.
                z.mid_exp = fparb_to_arf_special(x)
                # Restore any radius stored with the special midpoint (zero if none).
                z.rad_exp = x.rad_exp
                z.rad_man = (x.rad_val >> 2) % Int
            end
            return z
        end
        # One limb; the 32 significand bits occupy its upper half.
        z.mid_size = 2 | (x.rad_val & 1)            # (1 limb << 1) | sign bit
        z.mid_d1 = (x.mid_val % UInt64) << 32
        z.mid_d2 = 0 % UInt64
    elseif x.mid_val == 0 % UInt32
        # Up to 128 bits in-line (64-bit platform).
        z.mid_d1 = x.mid_n[1]
        z.mid_d2 = x.mid_n[2]
        z.mid_size = (z.mid_d2 == UInt(0) ? 2 : 4) | (x.rad_val & 1)   # plus sign bit
    else
        # TODO: a three-limb significand needs a limb buffer owned by `z`, holding
        # (mid_val << 32, mid_n[1], mid_n[2]) with mid_d1 = 3 and mid_size = 6 | sign.
        throw("Unimplemented, need to allocate pointer")
    end
    # Radius and exponent for every non-special layout (the 32-bit one included).
    z.mid_exp = x.mid_exp
    z.rad_exp = x.rad_exp
    z.rad_man = (x.rad_val >> 2) % Int
    return z
end
# Unary operations: load the operand into register X, have Arb write the result into
# register R at precision P, then re-encode R as a compact ArbFlt.
for (fn, arbname) in ((:inv, :arb_inv), (:sqrt, :arb_sqrt), (:invsqrt, :arb_rsqrt))
    @eval function ($fn){P}(x::ArbFlt{P,2})
        src = _copy!(ArbReg.X, ArbReg.XV, x)
        ccall(ArbFloats.@libarb($arbname), Void, (Ptr{ArbFloat}, Ref{ArbFloat}, Int),
              &ArbReg.R, src, P)
        return ArbFlt(ArbReg.R)
    end
end
# Binary operations: load the operands into registers X and Y (in that order), have Arb
# write the result into register R at precision P, then re-encode R as an ArbFlt.
for (fn, arbname) in ((:+, :arb_add), (:-, :arb_sub), (:*, :arb_mul), (:/, :arb_div))
    @eval function ($fn){P}(x::ArbFlt{P,2}, y::ArbFlt{P,2})
        lhs = _copy!(ArbReg.X, ArbReg.XV, x)
        rhs = _copy!(ArbReg.Y, ArbReg.YV, y)
        ccall(ArbFloats.@libarb($arbname), Void,
              (Ptr{ArbFloat}, Ref{ArbFloat}, Ref{ArbFloat}, Int),
              &ArbReg.R, lhs, rhs, P)
        return ArbFlt(ArbReg.R)
    end
end
# Exactly zero: special code ARB_VAL_ZERO with the size-hint bit clear and a zero radius.
# The sign (bit 0 of rad_val) is ignored.
ArbFloats.iszero(x::ArbFlt) = ((x.rad_val >> 1) | x.mid_val) == 0
# Zero radius. The radius mantissa is bits 2..31 of rad_val; bit 1 is only the size hint
# and must not make wide exact values look inexact. ARB_VAL_BAL_INF (size hint clear)
# encodes an infinite radius, so it is never exact.
ArbFloats.isexact(x::ArbFlt) =
    (x.rad_val >> 2) == 0 && !((x.rad_val & 0x2) == 0 && x.mid_val == ARB_VAL_BAL_INF)
# Additive and multiplicative identities, defined only for the N = 2 (256-bit) layout.
# one() uses the 32-bit significand ARB_VAL_ONE with midpoint exponent 1.
function Base.zero{P}(::Type{ArbFlt{P,2}})
    nolimbs = (0 % UInt64, 0 % UInt64)
    return ArbFlt{P,2}(0, 0, 0 % UInt32, ARB_VAL_ZERO, nolimbs)
end
function Base.one{P}(::Type{ArbFlt{P,2}})
    nolimbs = (0 % UInt64, 0 % UInt64)
    return ArbFlt{P,2}(1, 0, 0 % UInt32, ARB_VAL_ONE, nolimbs)
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment