Last active
April 3, 2021 00:01
-
-
Save Roger-luo/0748a53b58c55e4187b545632917e54a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using YaoLocations | |
using BenchmarkTools | |
using BQCESubroutine | |
using StrideArrays | |
using BQCESubroutine.Utilities | |
using LoopVectorization | |
using ThreadingUtilities | |
using ArrayInterface | |
using BQCESubroutine.Utilities: BitSubspace | |
"""
    subspace_mul_kernel!(S, C_re, C_im, indices, U_re, U_im, s, subspace, Bmax, offset)

Apply the matrix given by `U_re` (real part) and `U_im` (imaginary part) in place to
the slices of `S` selected by `subspace[s] + indices[i] + offset`.

`S` is indexed as `S[c, row, col]`; `c = 1` and `c = 2` are combined as the real and
imaginary components of a complex multiply-accumulate. Rows `1:Bmax` are processed in
blocks of eight. For each block the products are first accumulated into the scratch
buffers `C_re` / `C_im` and only afterwards written back, so every output column is
computed from the not-yet-overwritten input columns.

# Arguments
- `S`: three-index array that is read and overwritten.
- `C_re`, `C_im`: scratch buffers indexed `[m, n]` with `m ∈ axes(U_re, 1)` and
  `n ∈ 1:8`.
- `indices`: per-column index offsets, indexed over both axes of `U_re`.
- `U_re`, `U_im`: real and imaginary parts of the matrix being applied.
- `s`: position in `subspace` handled by this call.
- `subspace`: indexable collection of base indices.
- `Bmax`: number of rows (second index of `S`) to process.
- `offset`: constant added to every column index of `S`.

Returns `nothing`.
"""
@inline function subspace_mul_kernel!(S, C_re, C_im, indices, U_re, U_im, s::Int, subspace, Bmax::Int, offset::Int)
    # `(Bmax - 1) >>> 3` is a logical shift: for `Bmax == 0` it would turn -1 into a
    # huge positive block count, so bail out before building the range.
    Bmax > 0 || return nothing
    # Accumulator element type comes from the scratch buffer instead of a global `T`.
    T = eltype(C_re)
    k = subspace[s]
    for _b ∈ 0:(Bmax-1) >>> 3
        b = _b << 3
        bmax = b + 8
        if bmax ≤ Bmax # full static block: loop bounds are the literal 1:8
            @avx for n ∈ 1:8, m ∈ axes(U_re, 1)
                C_re_m_n = zero(T)
                C_im_m_n = zero(T)
                for i ∈ axes(U_re, 2)
                    j = k + indices[i] + offset
                    C_re_m_n += U_re[m,i] * S[1,n+b,j] - U_im[m,i] * S[2,n+b,j]
                    C_im_m_n += U_re[m,i] * S[2,n+b,j] + U_im[m,i] * S[1,n+b,j]
                end
                C_re[m,n] = C_re_m_n
                C_im[m,n] = C_im_m_n
            end
            # Write-back happens only after the whole block has been accumulated.
            @avx for n ∈ 1:8, m ∈ axes(U_re, 1)
                S_m = k + indices[m] + offset
                S[1,n+b,S_m] = C_re[m,n]
                S[2,n+b,S_m] = C_im[m,n]
            end
        else # dynamic block: trailing rows when Bmax is not a multiple of 8
            Nmax = 8 + Bmax - bmax
            @avx for n ∈ 1:Nmax, m ∈ axes(U_re, 1)
                C_re_m_n = zero(T)
                C_im_m_n = zero(T)
                for i ∈ axes(U_re, 2)
                    j = k + indices[i] + offset
                    C_re_m_n += U_re[m,i] * S[1,n+b,j] - U_im[m,i] * S[2,n+b,j]
                    C_im_m_n += U_re[m,i] * S[2,n+b,j] + U_im[m,i] * S[1,n+b,j]
                end
                C_re[m,n] = C_re_m_n
                C_im[m,n] = C_im_m_n
            end
            @avx for n ∈ 1:Nmax, m ∈ axes(U_re, 1)
                S_m = k + indices[m] + offset
                S[1,n+b,S_m] = C_re[m,n]
                S[2,n+b,S_m] = C_im[m,n]
            end
        end
    end
    return nothing
end
"""
    SubspaceMatrixMul{T, D}

Field-less callable type. `T` is the element type of the raw data pointers and `D` is
used as the extent of `U_re`, `U_im`, `indices` and of the first axis of the scratch
buffers. Calling an instance with a `Ptr{UInt}` unpacks the arguments stored behind
that pointer and runs `subspace_mul_kernel!` over the stored range.
"""
struct SubspaceMatrixMul{T, D} end

function (::SubspaceMatrixMul{T, D})(p::Ptr{UInt}) where {T, D}
    # Layout of the argument tuple; must mirror the tuple written by the setup routine.
    ArgLayout = Tuple{
        Ptr{T}, Ptr{T}, Ptr{T}, Ptr{Int}, Ptr{BitSubspace},
        Tuple{Int, Int}, UnitRange{Int}, Int
    }
    # NOTE(review): the byte offset is hard-coded here, while the writer uses the offset
    # returned by `ThreadingUtilities.store!` — confirm both agree for the
    # ThreadingUtilities version in use.
    _, unpacked = ThreadingUtilities.load(p, ArgLayout, 5*sizeof(UInt))
    S_ptr, U_re_ptr, U_im_ptr, indices_ptr, subspace_ptr, S_size, thread_range, offset = unpacked

    # Wrap the raw pointers; `S` gets a leading static axis of length 2.
    state = StrideArray(PtrArray(S_ptr, (StaticInt{2}(), S_size[1], S_size[2])))
    mat_re = StrideArray(PtrArray(U_re_ptr, (D, D)))
    mat_im = StrideArray(PtrArray(U_im_ptr, (D, D)))
    idx = StrideArray(PtrArray(indices_ptr, (D, )))
    space = Base.unsafe_load(subspace_ptr)

    # Scratch buffers reused by every kernel call issued from this invocation.
    acc_re = StrideArray{T}(undef, (D, StaticInt{8}()))
    acc_im = StrideArray{T}(undef, (D, StaticInt{8}()))

    nrows = S_size[1]
    for pos in thread_range
        subspace_mul_kernel!(state, acc_re, acc_im, idx, mat_re, mat_im, pos, space, nrows, offset)
    end
    return nothing
end
"""
    subspace_mm_ptr(::Matrix{Complex{T}}, indices)

Build a `SubspaceMatrixMul{T, D}` callable, with `D` the static length of `indices`,
and wrap it with `@cfunction` as a `Cvoid`-returning function of one `Ptr{UInt}`.
"""
function subspace_mm_ptr(::Matrix{Complex{T}}, indices) where {T}
    kernel = SubspaceMatrixMul{T, ArrayInterface.static_length(indices)}()
    # NOTE(review): `\$kernel` makes this a runtime (closure-style) cfunction; confirm
    # the returned object stays alive for as long as the stored pointer is used.
    return @cfunction($kernel, Cvoid, (Ptr{UInt}, ))
end
"""
    setup_subspace_mm(p, S, indices, U_re, U_im, subspace_ref, l, f, offset)

Write everything a `SubspaceMatrixMul` call needs into the memory behind `p`: first the
function pointer from `subspace_mm_ptr(S, indices)` at byte offset `sizeof(UInt)`, then
a tuple of raw pointers to `S`, `U_re`, `U_im`, `indices` and `subspace_ref`, followed
by `size(S)`, the range `l:f` and `offset`.

Only raw pointers are stored, so the caller must keep `S`, `indices`, `U_re`, `U_im`
and `subspace_ref` alive until the launched work has finished.

Returns `nothing`.
"""
function setup_subspace_mm(p::Ptr{UInt}, S::Matrix{Complex{T}}, indices, U_re, U_im, subspace_ref::Ref{BitSubspace}, l::Int, f::Int, offset::Int) where T
    # View the complex matrix through a pointer to its real element type.
    S_ptr = Base.unsafe_convert(Ptr{T}, S)
    U_re_ptr = pointer(U_re)
    U_im_ptr = pointer(U_im)
    indices_ptr = Base.unsafe_convert(Ptr{Int}, indices)
    subspace_ptr = Base.unsafe_convert(Ptr{BitSubspace}, subspace_ref)

    fptr = subspace_mm_ptr(S, indices)
    # `store!` returns the byte offset just past what it wrote; the argument tuple
    # goes right after the function pointer.
    fptr_offset = ThreadingUtilities.store!(p, fptr, sizeof(UInt))
    content = (S_ptr, U_re_ptr, U_im_ptr, indices_ptr, subspace_ptr, size(S), l:f, offset)
    ThreadingUtilities.store!(p, content, fptr_offset)
    return nothing
end
"""
    launch_subspace_mm(tid, S, indices, U_re, U_im, subspace, l, f, offset=0)

Hand `setup_subspace_mm` and its arguments to `ThreadingUtilities.launch` for task
slot `tid`. `offset` defaults to `0`.
"""
function launch_subspace_mm(tid, S, indices, U_re, U_im, subspace, l, f, offset=0)
    return ThreadingUtilities.launch(
        setup_subspace_mm, tid, S, indices, U_re, U_im, subspace, l, f, offset
    )
end
"""
    div_thread(tid, len, rem)

Return the inclusive bounds `(first, last)` of the chunk assigned to worker `tid` when
every worker gets `len` items and the first `rem` workers get one extra item each.
"""
function div_thread(tid, len, rem)
    lo = (tid - 1) * len + 1
    hi = lo + len - 1
    # No leftover items: chunks are plain multiples of `len`.
    rem > 0 || return lo, hi
    if tid <= rem
        # Workers before this one each took one extra item; this one takes one too.
        return lo + (tid - 1), hi + tid
    end
    # All `rem` extra items were handed out to earlier workers.
    return lo + rem, hi + rem
end
"""
    threaded_subspace_mm!(S, indices, U, subspace)

Apply `U` in place to `S` over `subspace`, splitting the positions `1:length(subspace)`
into contiguous chunks and launching one chunk per task slot through
`ThreadingUtilities`. `Threads.nthreads() - 1` task slots are used.

`U` is split into separate real and imaginary `D × D` buffers, where `D` is the static
length of `indices`.

# Throws
- `ErrorException` if Julia was started with a single thread (no task slot to use).
- `DimensionMismatch` if `U` does not have `D * D` elements.

Returns `S`.
"""
function threaded_subspace_mm!(S::Matrix{Complex{T}}, indices, U, subspace) where T
    nthreads = Threads.nthreads() - 1
    # Guard against `divrem(total, 0)` below, which would only raise a bare DivideError.
    nthreads > 0 || error(
        "threaded_subspace_mm! needs at least one worker thread; " *
        "start Julia with two or more threads"
    )

    D = ArrayInterface.static_length(indices)
    U_re = StrideArray{T}(undef, (D, D))
    U_im = StrideArray{T}(undef, (D, D))
    # The copy loop runs under `@inbounds`, so sizes must be validated up front.
    length(U) == length(U_re) || throw(DimensionMismatch(
        "U has $(length(U)) elements but $(length(U_re)) are required"
    ))
    @inbounds @simd ivdep for i in 1:length(U)
        U_re[i] = real(U[i])
        U_im[i] = imag(U[i])
    end

    subspace_ref = Ref(subspace)
    total = length(subspace)
    len, extra = divrem(total, nthreads)
    # Workers only receive raw pointers, so every pointed-to object (including
    # `indices`) must be kept alive until all of them have been waited on.
    GC.@preserve S indices U_re U_im subspace_ref begin
        for tid in 1:nthreads
            f, l = div_thread(tid, len, extra)
            launch_subspace_mm(tid, S, indices, U_re, U_im, subspace_ref, f, l)
        end
        for tid in 1:nthreads
            ThreadingUtilities.wait(tid)
        end
    end
    return S
end
# Driver script: runs the threaded kernel and a reference implementation on copies of
# the same random input and compares the results.
T = Float64  # real element type underlying ComplexF64
n = 20
S = rand(ComplexF64, 100, 1<<n);
U = rand(ComplexF64, 1<<3, 1<<3);
locs = Locations((1, 3, 5))
subspace = bsubspace(n, locs)
comspace = bcomspace(n, locs)

# Statically sized index vector holding the entries of `comspace` shifted by one.
indices = StrideArray{Int}(undef, (StaticInt{length(comspace)}(), ))
@simd ivdep for i in eachindex(indices)
    indices[i] = comspace[i] + 1
end

# Both routines mutate their first argument, so each one gets its own copy of `S`.
S1 = threaded_subspace_mm!(copy(S), indices, U, subspace)
S2 = BQCESubroutine.subspace_mul_generic!(copy(S), indices, U, subspace)
S1 ≈ S2
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment