Skip to content

Instantly share code, notes, and snippets.

@Roger-luo
Last active April 3, 2021 00:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Roger-luo/0748a53b58c55e4187b545632917e54a to your computer and use it in GitHub Desktop.
using YaoLocations
using BenchmarkTools
using BQCESubroutine
using StrideArrays
using BQCESubroutine.Utilities
using LoopVectorization
using ThreadingUtilities
using ArrayInterface
using BQCESubroutine.Utilities: BitSubspace
"""
    subspace_mul_kernel!(S, C_re, C_im, indices, U_re, U_im, s, subspace, Bmax, offset)

Apply the complex matrix `U_re + im*U_im` to the subspace block anchored at
`subspace[s]` of the split-storage state `S`, where `S[1,b,j]`/`S[2,b,j]` hold
the real/imaginary parts for batch column `b` and state index `j`. The batch
dimension (size `Bmax`) is processed in blocks of 8 columns; `C_re`/`C_im` are
per-call scratch buffers of size `(size(U_re,1), 8)`. `offset` shifts the state
index (used when `S` is a view into a larger array — TODO confirm with caller).
"""
@inline function subspace_mul_kernel!(S, C_re, C_im, indices, U_re, U_im, s::Int, subspace, Bmax::Int, offset::Int)
    # FIX: the original read the non-const *global* `T` for `zero(T)`, which is
    # type-unstable and breaks when this code is placed in a module. Derive the
    # element type from the scratch buffer instead.
    T = eltype(C_re)
    k = subspace[s]
    for _b ∈ 0:(Bmax - 1) >>> 3
        b = _b << 3
        bmax = b + 8
        if bmax ≤ Bmax # full block: 8 batch columns, statically sized for @avx
            @avx for n ∈ 1:8, m ∈ axes(U_re, 1)
                C_re_m_n = zero(T)
                C_im_m_n = zero(T)
                for i ∈ axes(U_re, 2)
                    j = k + indices[i] + offset
                    # complex multiply-accumulate on split real/imag storage
                    C_re_m_n += U_re[m, i] * S[1, n + b, j] - U_im[m, i] * S[2, n + b, j]
                    C_im_m_n += U_re[m, i] * S[2, n + b, j] + U_im[m, i] * S[1, n + b, j]
                end
                C_re[m, n] = C_re_m_n
                C_im[m, n] = C_im_m_n
            end
            # write the accumulated block back into S
            @avx for n ∈ 1:8, m ∈ axes(U_re, 1)
                S_m = k + indices[m] + offset
                S[1, n + b, S_m] = C_re[m, n]
                S[2, n + b, S_m] = C_im[m, n]
            end
        else # trailing block with fewer than 8 remaining columns
            Nmax = 8 + Bmax - bmax
            @avx for n ∈ 1:Nmax, m ∈ axes(U_re, 1)
                C_re_m_n = zero(T)
                C_im_m_n = zero(T)
                for i ∈ axes(U_re, 2)
                    j = k + indices[i] + offset
                    C_re_m_n += U_re[m, i] * S[1, n + b, j] - U_im[m, i] * S[2, n + b, j]
                    C_im_m_n += U_re[m, i] * S[2, n + b, j] + U_im[m, i] * S[1, n + b, j]
                end
                C_re[m, n] = C_re_m_n
                C_im[m, n] = C_im_m_n
            end
            @avx for n ∈ 1:Nmax, m ∈ axes(U_re, 1)
                S_m = k + indices[m] + offset
                S[1, n + b, S_m] = C_re[m, n]
                S[2, n + b, S_m] = C_im[m, n]
            end
        end
    end
    return
end
# Callable singleton carrying the element type `T` and the static gate
# dimension `D` in the type domain, so `@cfunction` can build a specialized
# C-callable entry point per (T, D) combination.
struct SubspaceMatrixMul{T, D} end
# Thread entry point: decode the argument tuple previously written into the
# per-thread ThreadingUtilities buffer `p` by `setup_subspace_mm`, rebuild
# array wrappers from the raw pointers, and run the kernel over this thread's
# slice of the subspace. `P` must match the `content` tuple layout exactly.
function (k::SubspaceMatrixMul{T, D})(p::Ptr{UInt}) where {T, D}
P = Tuple{Ptr{T}, Ptr{T}, Ptr{T}, Ptr{Int}, Ptr{BitSubspace}, Tuple{Int, Int}, UnitRange{Int}, Int}
# NOTE(review): the 5*sizeof(UInt) offset must agree with where
# `setup_subspace_mm` stored `content` — verify against
# ThreadingUtilities.store!/load conventions if either side changes.
_, (S_ptr, U_re_ptr, U_im_ptr, indices_ptr, subspace_ptr, S_size, thread_range, offset) =
ThreadingUtilities.load(p, P, 5*sizeof(UInt))
# S is viewed as (2, batch, state): leading static 2 = real/imag planes.
S = StrideArray(PtrArray(S_ptr, (StaticInt{2}(), S_size...)))
U_re = StrideArray(PtrArray(U_re_ptr, (D, D)))
U_im = StrideArray(PtrArray(U_im_ptr, (D, D)))
indices = StrideArray(PtrArray(indices_ptr, (D, )))
subspace = Base.unsafe_load(subspace_ptr)
# Scratch buffers sized D×8 to match the kernel's 8-column batch blocking.
C_re = StrideArray{T}(undef, (D, StaticInt{8}()))
C_im = StrideArray{T}(undef, (D, StaticInt{8}()))
for s in thread_range
subspace_mul_kernel!(S, C_re, C_im, indices, U_re, U_im, s, subspace, S_size[1], offset)
end
return
end
# Build a C function pointer for the SubspaceMatrixMul functor specialized on
# the matrix element type `T` and the static length of `indices`; this is the
# pointer ThreadingUtilities will invoke on a worker thread.
function subspace_mm_ptr(::Matrix{Complex{T}}, indices) where T
D = ArrayInterface.static_length(indices)
sig = SubspaceMatrixMul{T, D}()
# `$sig` closes over the functor instance; returns Ptr{Cvoid} taking Ptr{UInt}.
@cfunction($sig, Cvoid, (Ptr{UInt}, ))
end
# Serialize the kernel arguments into worker thread buffer `p`: first the
# C function pointer, then a tuple of raw pointers and metadata that the
# SubspaceMatrixMul functor will `load` back (the tuple layout here and the
# type `P` in the functor must stay in sync).
# NOTE(review): despite the parameter names, `l` is the FIRST subspace index
# and `f` the LAST of this thread's range — the stored range is `l:f`, and
# callers pass (start, stop) positionally.
function setup_subspace_mm(p::Ptr{UInt}, S::Matrix{Complex{T}}, indices, U_re, U_im, subspace_ref::Ref{BitSubspace}, l::Int, f::Int, offset::Int) where T
# NOTE(review): `D` is computed but never used in this function.
D = StrideArrays.static_length(indices)
S_ptr = Base.unsafe_convert(Ptr{T}, S)
U_re_ptr = pointer(U_re)
U_im_ptr = pointer(U_im)
indices_ptr = Base.unsafe_convert(Ptr{Int}, indices)
subspace_ptr = Base.unsafe_convert(Ptr{BitSubspace}, subspace_ref)
fptr = subspace_mm_ptr(S, indices)
# store the function pointer first; `store!` returns the next free offset
fptr_offset = ThreadingUtilities.store!(p, fptr, sizeof(UInt))
content = (S_ptr, U_re_ptr, U_im_ptr, indices_ptr, subspace_ptr, size(S), l:f, offset)
ThreadingUtilities.store!(p, content, fptr_offset)
return
end
# Queue the subspace multiplication on worker thread `tid`, covering subspace
# indices lo:hi. Thin forwarder around ThreadingUtilities.launch, which calls
# `setup_subspace_mm` with the worker's buffer pointer prepended.
function launch_subspace_mm(tid, S, indices, U_re, U_im, subspace, lo, hi, offset = 0)
    return ThreadingUtilities.launch(
        setup_subspace_mm, tid, S, indices, U_re, U_im, subspace, lo, hi, offset,
    )
end
# Compute the (first, last) item range for worker `tid` when `len*nworkers + rem`
# items are split across workers: every worker gets `len` items and the first
# `rem` workers absorb one extra item each.
function div_thread(tid, len, rem)
    base = (tid - 1) * len
    if tid <= rem
        # the `tid - 1` workers before us each took one extra item
        first_idx = base + tid
        last_idx = base + len + tid
    else
        # all `rem` extras were consumed by earlier workers
        first_idx = base + rem + 1
        last_idx = base + len + rem
    end
    return first_idx, last_idx
end
"""
    threaded_subspace_mm!(S, indices, U, subspace)

Multiply `U` into each subspace block of the state matrix `S` in place, using
the `Threads.nthreads() - 1` ThreadingUtilities worker threads, and return `S`.

Throws `ArgumentError` when Julia was started with a single thread (there are
no worker threads to launch on; the original code hit a `DivideError` here).
"""
function threaded_subspace_mm!(S::Matrix{Complex{T}}, indices, U, subspace) where T
    D = ArrayInterface.static_length(indices)
    # Split U into real/imag planes so the kernel can vectorize on real arrays.
    U_re = StrideArray{T}(undef, (D, D))
    U_im = StrideArray{T}(undef, (D, D))
    @inbounds @simd ivdep for i in eachindex(U)
        U_re[i] = real(U[i])
        U_im[i] = imag(U[i])
    end
    # Ref keeps `subspace` rooted so its pointer stays valid on worker threads.
    subspace_ref = Ref(subspace)
    total = length(subspace)
    # ThreadingUtilities launches work on the non-main threads only.
    nthreads = Threads.nthreads() - 1
    nthreads >= 1 || throw(ArgumentError(
        "threaded_subspace_mm! needs at least 2 Julia threads, got $(Threads.nthreads())"))
    len, rem = divrem(total, nthreads)
    GC.@preserve S U_re U_im subspace_ref begin
        for tid in 1:nthreads
            f, l = div_thread(tid, len, rem)
            launch_subspace_mm(tid, S, indices, U_re, U_im, subspace_ref, f, l)
        end
        # wait for every worker before the preserved buffers can be released
        for tid in 1:nthreads
            ThreadingUtilities.wait(tid)
        end
    end
    return S
end
# --- Script: random 3-qubit gate on a 20-qubit batched state; compare the
# --- threaded implementation against BQCESubroutine's generic reference.
T = Float64 # NOTE(review): non-const global; `subspace_mul_kernel!` above reads this `T` implicitly
n = 20
S = rand(ComplexF64, 100, 1<<n);
U = rand(ComplexF64, 1<<3, 1<<3);
locs = Locations((1, 3, 5))
subspace = bsubspace(n, locs)
comspace = bcomspace(n, locs)
# 1-based offsets of each gate target within a subspace block
indices = StrideArray{Int}(undef, (StaticInt{length(comspace)}(), ))
@simd ivdep for i in eachindex(indices)
indices[i] = comspace[i] + 1
end
S1 = copy(S) # NOTE(review): dead assignment — overwritten on the next line
S1 = threaded_subspace_mm!(copy(S), indices, U, subspace)
S2 = BQCESubroutine.subspace_mul_generic!(copy(S), indices, U, subspace)
S1 ≈ S2 # expected to evaluate to `true`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment