Last active
March 17, 2021 19:49
-
-
Save stillyslalom/2e7e0a5483a847ba3206f66dc3a740fa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra, LoopVectorization, Test
"""
    f1(M, G, J, H, A, B, ϕ)

Reference (allocating) implementation of the backward sweep over the fourth
index of `ϕ`.

For each step, going from slice `M + 1` down to slice `2`, the interior rows
`2:G` of the source slice of `ϕ` are added to the matching slice of `B`, the
sum is rearranged so that the `G - 1` interior index comes first, the system
`A * X = rhs` is solved for all columns at once, and the solution is written
into rows `2:G` of the preceding slice of `ϕ`.

`ϕ` is modified in place; nothing is returned.
"""
function f1(M, G, J, H, A, B, ϕ)
    # Swap the first two dimensions, leave the third one alone.
    perm = [2 1 3]
    for dst = M:-1:1
        src = dst + 1
        # Bring the G-1 interior index to the front and flatten the rest into columns.
        rhs_ϕ = reshape(permutedims(ϕ[:, 2:G, :, src], perm), G - 1, :)
        rhs_B = reshape(permutedims(B[:, :, :, src], perm), G - 1, :)
        solution = A \ (rhs_ϕ + rhs_B)
        # Undo the flattening and the permutation before storing the result.
        unflattened = reshape(solution, G - 1, 2 * J + 1, H + 1)
        ϕ[:, 2:G, :, dst] = permutedims(unflattened, perm)
    end
end
"""
    f2(M, G, J, H, A, B, ϕ)

Serial, preallocated version of the backward sweep over the fourth index of `ϕ`.

`A` is factorized once with `factorize`. For each step, from slice `M + 1` down
to slice `2`, the values `ϕ[jj, gg + 1, hh, src] + B[jj, gg, hh, src]` are
gathered into a `(G - 1) × ((2J + 1)(H + 1))` work matrix, the factorization is
applied in place with `ldiv!`, and the result is scattered into
`ϕ[jj, gg + 1, hh, src - 1]`.

`ϕ` is modified in place; `nothing` is returned.

# Throws
- `BoundsError` if `ϕ` or `B` is too small for the index ranges implied by
  `M`, `G`, `J` and `H`.
"""
function f2(M, G, J, H, A, B, ϕ)
    jmax = 2 * J + 1
    hmax = H + 1
    # The loops below run under @inbounds, so every index they touch is validated
    # once up front. With M < 1 nothing is accessed and nothing needs checking.
    if M >= 1
        checkbounds(ϕ, 1:jmax, 2:G, 1:hmax, 1:M+1)
        checkbounds(B, 1:jmax, 1:G-1, 1:hmax, 2:M+1)
    end
    tmp = similar(ϕ, G - 1, jmax * hmax)
    Â = factorize(A)
    for mm = 1:M
        mp1 = M + 1 - mm
        mp2 = mp1 + 1
        # Gather the right-hand sides; column index flattens (jj, hh).
        @inbounds for hh = 1:hmax, jj = 1:jmax
            col = jj + (hh - 1) * jmax
            for gg = 1:G-1
                tmp[gg, col] = ϕ[jj, gg+1, hh, mp2] + B[jj, gg, hh, mp2]
            end
        end
        # Solve all columns in place with the shared factorization.
        ldiv!(Â, tmp)
        # Scatter the solutions into the preceding slice of ϕ.
        @inbounds for hh = 1:hmax, jj = 1:jmax
            col = jj + (hh - 1) * jmax
            for gg = 1:G-1
                ϕ[jj, gg+1, hh, mp1] = tmp[gg, col]
            end
        end
    end
    return nothing
end
"""
    f3(M, G, J, H, A, B, ϕ)

Multithreaded variant of the preallocated backward sweep over the fourth index
of `ϕ`.

`A` is factorized once with `factorize`. For each step, from slice `M + 1` down
to slice `2`, the gather of `ϕ[jj, gg + 1, hh, src] + B[jj, gg, hh, src]` into a
shared `(G - 1) × ((2J + 1)(H + 1))` work matrix and the scatter of the solved
columns into `ϕ[jj, gg + 1, hh, src - 1]` are each split over `hh` with
`Threads.@threads`. The in-place solve `ldiv!` between them runs once on the
whole work matrix. Different `hh` values touch disjoint columns of the work
matrix and disjoint entries of `ϕ`.

`ϕ` is modified in place; `nothing` is returned.

# Throws
- `BoundsError` if `ϕ` or `B` is too small for the index ranges implied by
  `M`, `G`, `J` and `H`.
"""
function f3(M, G, J, H, A, B, ϕ)
    jmax = 2 * J + 1
    hmax = H + 1
    # Validate every index the @inbounds loops will touch before starting.
    # With M < 1 nothing is accessed and nothing needs checking.
    if M >= 1
        checkbounds(ϕ, 1:jmax, 2:G, 1:hmax, 1:M+1)
        checkbounds(B, 1:jmax, 1:G-1, 1:hmax, 2:M+1)
    end
    tmp = similar(ϕ, G - 1, jmax * hmax)
    LU = factorize(A)
    for mm = 1:M
        mp1 = M + 1 - mm
        mp2 = mp1 + 1
        # `@inbounds` is placed inside the threaded loop body: `Threads.@threads`
        # moves the body into a closure, which an outer `@inbounds` does not reach.
        Threads.@threads for hh = 1:hmax
            @inbounds for jj = 1:jmax
                col = jj + (hh - 1) * jmax
                for gg = 1:G-1
                    tmp[gg, col] = ϕ[jj, gg+1, hh, mp2] + B[jj, gg, hh, mp2]
                end
            end
        end
        # Solve all columns in place with the shared factorization.
        ldiv!(LU, tmp)
        Threads.@threads for hh = 1:hmax
            @inbounds for jj = 1:jmax
                col = jj + (hh - 1) * jmax
                for gg = 1:G-1
                    ϕ[jj, gg+1, hh, mp1] = tmp[gg, col]
                end
            end
        end
    end
    return nothing
end
"""
    f4(M, G, J, H, A, B, ϕ)

Multithreaded backward sweep over the fourth index of `ϕ` that solves one
length-`G - 1` system per `(jj, hh)` pair and uses LoopVectorization's `@avx`
for the gather and scatter loops.

`ϕ`, `B` and the scratch vectors are reinterpreted as `Float64` arrays with a
leading dimension of length 2, so their elements are expected to be
`Complex{Float64}`. For each step, from slice `M + 1` down to slice `2`, and for
every `(jj, hh)`, the vector `ϕ[jj, 2:G, hh, src] + B[jj, :, hh, src]` is built
in a scratch buffer, solved in place with `ldiv!`, and stored into
`ϕ[jj, 2:G, hh, src - 1]`.

`ϕ` is modified in place; `nothing` is returned.

# Throws
- `BoundsError` if `ϕ` or `B` is too small for the index ranges implied by
  `M`, `G`, `J` and `H`.
"""
function f4(M, G, J, H, A, B, ϕ)
    jmax = 2 * J + 1
    hmax = H + 1
    # The @avx loops index without per-access checks here, so validate once up front.
    # With M < 1 nothing is accessed and nothing needs checking.
    if M >= 1
        checkbounds(ϕ, 1:jmax, 2:G, 1:hmax, 1:M+1)
        checkbounds(B, 1:jmax, 1:G-1, 1:hmax, 2:M+1)
    end
    Â = factorize(A)
    # Split the hh range into contiguous chunks, at most one per available thread.
    # Scratch storage is indexed by chunk rather than by `Threads.threadid()`, so
    # correctness does not depend on how iterations are scheduled onto threads.
    nchunks = clamp(Threads.nthreads(), 1, max(hmax, 1))
    chunks = collect(Iterators.partition(1:hmax, cld(max(hmax, 1), nchunks)))
    Âs = [deepcopy(Â) for _ in chunks]          # one factorization copy per chunk
    tmp = [similar(ϕ, G - 1) for _ in chunks]   # one scratch vector per chunk
    Bf = reinterpret(reshape, Float64, B)
    ϕf = reinterpret(reshape, Float64, ϕ)
    tmpf = [reinterpret(reshape, Float64, t) for t in tmp]
    for mm = 1:M
        mp1 = M + 1 - mm
        mp2 = mp1 + 1
        Threads.@threads for c in eachindex(chunks)
            tmpl = tmp[c]
            tmpfl = tmpf[c]    # Float64 view of the same memory as tmpl
            Âl = Âs[c]
            for hh in chunks[c]
                for jj = 1:jmax
                    # Gather real (index 1) and imaginary (index 2) parts.
                    @avx for gg = 1:G-1
                        tmpfl[1, gg] = ϕf[1, jj, gg+1, hh, mp2] +
                                       Bf[1, jj, gg, hh, mp2]
                        tmpfl[2, gg] = ϕf[2, jj, gg+1, hh, mp2] +
                                       Bf[2, jj, gg, hh, mp2]
                    end
                    # Solve in place on the complex-valued view of the buffer.
                    ldiv!(Âl, tmpl)
                    @avx for gg = 1:G-1
                        ϕf[1, jj, gg+1, hh, mp1] = tmpfl[1, gg]
                        ϕf[2, jj, gg+1, hh, mp1] = tmpfl[2, gg]
                    end
                end
            end
        end
    end
    return nothing
end
"""
    f5(M, G, J, H, A, B, ϕ)

Multithreaded backward sweep over the fourth index of `ϕ` that solves one
length-`G - 1` system per `(jj, hh)` pair using plain loops.

`A` is factorized once with `factorize`. For each step, from slice `M + 1` down
to slice `2`, and for every `(jj, hh)`, the vector
`ϕ[jj, 2:G, hh, src] + B[jj, :, hh, src]` is built in a scratch buffer, solved
in place with `ldiv!`, and stored into `ϕ[jj, 2:G, hh, src - 1]`. The `hh`
range is split across threads.

`ϕ` is modified in place; `nothing` is returned.

# Throws
- `BoundsError` if `ϕ` or `B` is too small for the index ranges implied by
  `M`, `G`, `J` and `H`.
"""
function f5(M, G, J, H, A, B, ϕ)
    jmax = 2 * J + 1
    hmax = H + 1
    # Validate every index the @inbounds loops will touch before starting.
    # With M < 1 nothing is accessed and nothing needs checking.
    if M >= 1
        checkbounds(ϕ, 1:jmax, 2:G, 1:hmax, 1:M+1)
        checkbounds(B, 1:jmax, 1:G-1, 1:hmax, 2:M+1)
    end
    Â = factorize(A)
    # Split the hh range into contiguous chunks, at most one per available thread.
    # Scratch storage is indexed by chunk rather than by `Threads.threadid()`, so
    # correctness does not depend on how iterations are scheduled onto threads.
    nchunks = clamp(Threads.nthreads(), 1, max(hmax, 1))
    chunks = collect(Iterators.partition(1:hmax, cld(max(hmax, 1), nchunks)))
    Âs = [deepcopy(Â) for _ in chunks]          # one factorization copy per chunk
    tmp = [similar(ϕ, G - 1) for _ in chunks]   # one scratch vector per chunk
    for mm = 1:M
        mp1 = M + 1 - mm
        mp2 = mp1 + 1
        Threads.@threads for c in eachindex(chunks)
            tmpl = tmp[c]
            Âl = Âs[c]
            for hh in chunks[c]
                for jj = 1:jmax
                    @inbounds for gg = 1:G-1
                        tmpl[gg] = ϕ[jj, gg+1, hh, mp2] + B[jj, gg, hh, mp2]
                    end
                    ldiv!(Âl, tmpl)
                    @inbounds for gg = 1:G-1
                        ϕ[jj, gg+1, hh, mp1] = tmpl[gg]
                    end
                end
            end
        end
    end
    return nothing
end
# Problem sizes used for the comparison run.
M, G, J, H = 10, 50, 50, 30

# A is tridiagonal but deliberately stored as a dense Matrix: the author observed
# this to be faster than keeping the sparse structure (the reason was not established).
A = Matrix(Tridiagonal(rand(G - 1, G - 1)));
B = rand(ComplexF64, 2 * J + 1, G - 1, H + 1, M + 1);
ϕ1 = rand(ComplexF64, 2 * J + 1, G + 1, H + 1, M + 1);

# Every implementation gets its own copy of the same starting state,
# because each one overwrites its ϕ argument.
ϕ2 = copy(ϕ1)
ϕ3 = copy(ϕ1)
ϕ4 = copy(ϕ1)
ϕ5 = copy(ϕ1)

@testset "Strang splitting optimization" begin
    # Run the reference implementation (f1) and each optimized variant.
    f1(M, G, J, H, A, B, ϕ1)
    f2(M, G, J, H, A, B, ϕ2)
    f3(M, G, J, H, A, B, ϕ3)
    f4(M, G, J, H, A, B, ϕ4)
    f5(M, G, J, H, A, B, ϕ5)

    # Each variant must reproduce the reference result up to floating-point tolerance.
    @test ϕ1 ≈ ϕ2
    @test ϕ1 ≈ ϕ3
    @test ϕ1 ≈ ϕ4
    @test ϕ1 ≈ ϕ5
end
# Test Summary:                 | Pass  Total
# Strang splitting optimization |    4      4
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment