@chriselrod
Last active July 24, 2020 09:31
#= Function exp vectorized with AVX-512. KNL and SKX versions.
Copyright (C) 2014-2020 Free Software Foundation, Inc.
This file is part of the GNU C Library.
The GNU C Library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
The GNU C Library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with the GNU C Library; if not, see
<https://www.gnu.org/licenses/>. =#
using SIMDPirates
# Range reduction coefficients (k = 10, i.e. a 2^k = 1024-entry table):
# 2^k/ln(2)
const __dbInvLn2 = reinterpret(Float64, 0x40971547652b82fe)
# right-shifter value = 3*2^51 (= 1.5*2^52); adding it rounds x*2^k/ln(2) to an integer
const __dbShifter = reinterpret(Float64, 0x4338000000000000)
# ln(2)/2^k, high part: the low mantissa bits are zero, so dN*__dbLn2hi is exact for |x| in range
const __dbLn2hi = reinterpret(Float64, 0x3f462e42fec00000)
# ln(2)/2^k, low part: the remaining bits, ln(2)/2^k - __dbLn2hi
const __dbLn2lo = reinterpret(Float64, 0x3d5d1cf79abc9e3b)
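A minimal plain-Julia sketch of the range-reduction step, reusing the four constants above as local values. The function name `reduce_demo` is hypothetical and for illustration only. It shows that after adding the shifter, the low 10 bits of `dM` already hold N mod 2^10 (also for negative N), and that the two-part subtraction leaves |r| within ln(2)/2^11.

```julia
# Hypothetical helper: shifter rounding trick plus two-step (Cody-Waite style) reduction.
function reduce_demo(x::Float64)
    inv_ln2 = reinterpret(Float64, 0x40971547652b82fe)  # 2^10 / ln(2)
    shifter = reinterpret(Float64, 0x4338000000000000)  # 1.5 * 2^52
    ln2_hi  = reinterpret(Float64, 0x3f462e42fec00000)  # ln(2)/2^10, high part
    ln2_lo  = reinterpret(Float64, 0x3d5d1cf79abc9e3b)  # ln(2)/2^10, low part
    # The sum lies in [2^52, 2^53), where Float64 spacing is 1, so it is rounded to an integer
    dM = fma(x, inv_ln2, shifter)
    dN = dM - shifter                 # N = round(x * 2^10 / ln 2), recovered exactly
    r = fma(-dN, ln2_hi, x)           # x - N*ln2_hi
    r = fma(-dN, ln2_lo, r)           # ... - N*ln2_lo
    # Bit pattern of dM is that of the shifter plus N, so its low bits give mod(N, 1024)
    lowbits = reinterpret(UInt64, dM) & 0x3ff
    return Int(dN), r, lowbits
end

N, r, lowbits = reduce_demo(3.7)
```

For x = 3.7 this gives N = 5466 and `lowbits` = 346 (= 5466 - 5*1024).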
# Polynomial coefficients for exp(r) on |r| <= ln(2)/2^(k+1) (k=10, deg=3):
const __dPC0 = reinterpret(Float64, 0x3ff0000000000000)  # 1.0, used for both the constant and linear terms
const __dPC1 = reinterpret(Float64, 0x3fe0000001ebfbe0)  # ≈ 1/2
const __dPC2 = reinterpret(Float64, 0x3fc5555555555556)  # ≈ 1/6
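A quick check, under the assumption that N is rounded to nearest so |r| <= ln(2)/2^11, of how closely this cubic tracks exp(r) when evaluated in the same nested-`fma` order as `gexp`. The names `poly_exp_demo` and `max_relerr` are hypothetical. Note that `__dPC1` is slightly larger than 1/2, presumably tuned to reduce the worst-case error on the reduced interval; the sketch only measures the error and does not reproduce how the coefficients were chosen.

```julia
# Hypothetical name; same coefficients and evaluation order as in gexp.
function poly_exp_demo(r::Float64)
    pc0 = reinterpret(Float64, 0x3ff0000000000000)  # 1.0
    pc1 = reinterpret(Float64, 0x3fe0000001ebfbe0)  # ≈ 1/2
    pc2 = reinterpret(Float64, 0x3fc5555555555556)  # ≈ 1/6
    # pc0 + r*(pc0 + r*(pc1 + r*pc2)) = 1 + r + pc1*r^2 + pc2*r^3
    return fma(fma(fma(pc2, r, pc1), r, pc0), r, pc0)
end

# Worst relative error on a grid over the reduced interval, measured against BigFloat exp
rmax = log(2) / 2048
max_relerr = maximum(range(-rmax, rmax; length = 2001)) do r
    Float64(abs(big(poly_exp_demo(r)) / exp(big(r)) - 1))
end
```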
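# Other constants:
# index mask = 2^k-1
const __lIndexMask = 0x00000000000003ff
# absolute-value mask for the upper 32-bit word of x (used only by the commented-out range check in gexp)
const __iAbsMask = 0x7fffffff
# upper 32 bits of ≈708.4 (about 1022*ln(2)); inputs with larger |x| need special handling, which this port omits
const __iDomainRange = 0x4086232a % Int32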
# J_TABLE[j+1] = 2^(j/2^k) for j = 0, ..., 2^k-1, computed in BigFloat and rounded to Float64
const J_TABLE = Float64[2^(big(j-1)/1024) for j in 1:1024];
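A sketch of how a table entry combines with the integer part of N/2^k: for N = 2^k·m + j, 2^(N/2^k) = 2^m · 2^(j/2^k), and multiplying a normal Float64 by 2^m amounts to adding m to its exponent field. The names `POW2_TABLE_DEMO` and `pow2_over_1024` are hypothetical; the table is built the same way as `J_TABLE`.

```julia
# Built like J_TABLE: entry j+1 holds 2^(j/1024).
const POW2_TABLE_DEMO = Float64[2^(big(j - 1) / 1024) for j in 1:1024]

# Hypothetical helper: 2^(N/1024) via table lookup plus exponent-field addition.
# Only valid while the result stays a normal Float64.
function pow2_over_1024(N::Int)
    j = N & 1023                   # N mod 1024 (two's complement, so fine for N < 0)
    m = N >> 10                    # floor(N / 1024), arithmetic shift
    t = POW2_TABLE_DEMO[j + 1]     # 2^(j/1024); Julia arrays are 1-based
    # Adding m << 52 to the bit pattern adds m to the biased exponent, i.e. multiplies by 2^m
    bits = reinterpret(UInt64, t) + ((m % UInt64) << 52)
    return reinterpret(Float64, bits)
end
```

In `gexp` the same exponent addition is done by `(dMi & (~__lIndexMask)) << 42`, which keeps 2^10·m from the shifted bit pattern and moves it into the exponent field.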
# const J_TABLE_ptr = pointer(J_TABLE)
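# Lane-wise reinterpret of an SVec between 64-bit element types, e.g. reinterpret(UInt64, v::SVec{W,Float64})
@inline function Base.reinterpret(::Type{T1}, x::SVec{W,T2}) where {W, T1 <: Union{Float64,Int64,UInt64}, T2 <: Union{Float64,Int64,UInt64}}
    reinterpret(SVec{W,T1}, x)
end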
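@inline function gexp(x)
    xint = reinterpret(UInt64, x)  # only needed by the commented-out range check
    # xshift = xint >> 32
    # dM = x*2^k/ln(2) + shifter; rounding leaves N = round(x*2^k/ln(2)) in the low mantissa bits
    dM = fma(x, __dbInvLn2, __dbShifter)
    # xshift32 = vconvert(SVec{W,UInt32}, xshift)
    dN = dM - __dbShifter  # N as a Float64
    # iAbsX = xshift32 & __iAbsMask
    # r = x - N*ln(2)/2^k, subtracting the high and low parts separately (vfnmadd(a, b, c) = c - a*b)
    dR = vfnmadd(dN, __dbLn2hi, x)
    dR = vfnmadd(dN, __dbLn2lo, dR)
    # @show dR
    # exp(r) ≈ 1 + r + PC1*r^2 + PC2*r^3, in Horner form
    expr = fma(fma(fma(__dPC2, dR, __dPC1), dR, __dPC0), dR, __dPC0)
    # iRangeMask = reinterpret(SVec{W,Int32}, iAbsX) > __iDomainRange
    dMi = reinterpret(UInt64, dM)
    # Gather 2^(j/2^k) with j = N mod 2^k, used here as a 0-based offset into J_TABLE
    # lIndex = SVec(gather(SIMDPirates.gep(pointer(J_TABLE), dMi & __lIndexMask), Val(false)))
    lIndex = vload(stridedpointer(J_TABLE), (dMi & __lIndexMask,))
    # @show lIndex
    jR = lIndex * expr  # 2^(j/2^k) * exp(r)
    # @show dN dM jR
    # m = floor(N/2^k) moved into the exponent field; the shifter's own bits are shifted out
    lM = (dMi & (~__lIndexMask)) << 42
    # Adding m to the exponent of jR multiplies it by 2^m: exp(x) = 2^m * 2^(j/2^k) * exp(r)
    reti = lM + reinterpret(UInt64, jR)
    reinterpret(Float64, reti)
end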
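For checking the vector code lane by lane, a scalar sketch of the same algorithm in plain Julia may help. It uses only Base, rebuilds the table under a hypothetical name (`SCALAR_EXP_TABLE`), and, like `gexp`, has no range check, so it is only meaningful where exp(x) is a normal Float64 (roughly |x| < 708). The name `gexp_scalar` is also hypothetical.

```julia
# Same contents as J_TABLE: entry j+1 holds 2^(j/1024).
const SCALAR_EXP_TABLE = Float64[2^(big(j - 1) / 1024) for j in 1:1024]

# Hypothetical scalar counterpart of gexp; no handling of large |x|, NaN or Inf.
function gexp_scalar(x::Float64)
    inv_ln2 = reinterpret(Float64, 0x40971547652b82fe)  # 2^10 / ln(2)
    shifter = reinterpret(Float64, 0x4338000000000000)  # 1.5 * 2^52
    ln2_hi  = reinterpret(Float64, 0x3f462e42fec00000)  # ln(2)/2^10, high part
    ln2_lo  = reinterpret(Float64, 0x3d5d1cf79abc9e3b)  # ln(2)/2^10, low part
    pc1     = reinterpret(Float64, 0x3fe0000001ebfbe0)  # ≈ 1/2
    pc2     = reinterpret(Float64, 0x3fc5555555555556)  # ≈ 1/6
    mask    = 0x00000000000003ff                        # 2^10 - 1

    dM = fma(x, inv_ln2, shifter)                       # N in the low mantissa bits
    dN = dM - shifter                                   # N as a Float64
    r  = fma(-dN, ln2_lo, fma(-dN, ln2_hi, x))          # r = x - N*ln(2)/2^10
    p  = fma(fma(fma(pc2, r, pc1), r, 1.0), r, 1.0)     # ≈ exp(r)
    dMi = reinterpret(UInt64, dM)
    t  = SCALAR_EXP_TABLE[(dMi & mask) + 1]             # 2^(j/2^10), j = N mod 2^10
    lM = (dMi & ~mask) << 42                            # floor(N/2^10) in the exponent field
    return reinterpret(Float64, lM + reinterpret(UInt64, t * p))
end
```

Comparing `gexp_scalar` against `Base.exp` on a grid of inputs inside the supported range is a simple way to confirm the constants and bit manipulation before debugging the SIMD version.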
### Example use
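using VectorizationBase, SLEEFPirates
# A random SIMD vector of the native width (e.g. 8 Float64 lanes with AVX-512), standard deviation 10
sx = SVec(ntuple(VectorizationBase.pick_vector_width_val(Float64)) do _
    Core.VecElement(10randn(Float64))
end);
gexp(sx)  # this implementation
exp(sx)   # SLEEFPirates' vectorized exp, for comparison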