Skip to content

Instantly share code, notes, and snippets.

@stillyslalom
Last active March 17, 2021 19:49
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save stillyslalom/2e7e0a5483a847ba3206f66dc3a740fa to your computer and use it in GitHub Desktop.
Save stillyslalom/2e7e0a5483a847ba3206f66dc3a740fa to your computer and use it in GitHub Desktop.
using LinearAlgebra, LoopVectorization, Test
"""
    f1(M, G, J, H, A, B, ϕ)

Allocating reference implementation of the backward sweep over the fourth
dimension of `ϕ`.

For `mm = 1, …, M` the slab `ϕ[:, 2:G, :, M + 1 - mm]` is overwritten with the
solution `X` of `A * X = ϕ[:, 2:G, :, M + 2 - mm] + B[:, :, :, M + 2 - mm]`,
where each slab is first permuted so that its second dimension (length `G - 1`)
becomes the row dimension and the remaining dimensions are flattened into columns.

The reshape of the solution to `(G - 1, 2J + 1, H + 1)` means the slabs must have
`2J + 1` entries along dimension 1 and `H + 1` along dimension 3.

`ϕ` is mutated in place; the function returns `nothing`.
"""
function f1(M, G, J, H, A, B, ϕ)
    perm = (2, 1, 3)      # swap the first two dimensions (self-inverse permutation)
    nrows = G - 1
    for step in 1:M
        src = M + 2 - step    # slab that is read
        dst = src - 1         # slab that is overwritten
        # Bring the (G - 1)-long axis to the front and flatten everything else
        # into columns, so that one matrix solve handles every column at once.
        ϕcols = reshape(permutedims(ϕ[:, 2:G, :, src], perm), nrows, :)
        Bcols = reshape(permutedims(B[:, :, :, src], perm), nrows, :)
        solved = A \ (ϕcols + Bcols)
        # Undo the flattening and the permutation before storing the result.
        unflattened = reshape(solved, nrows, 2 * J + 1, H + 1)
        ϕ[:, 2:G, :, dst] = permutedims(unflattened, perm)
    end
    return nothing
end
"""
    f2(M, G, J, H, A, B, ϕ)

Backward sweep over the fourth dimension of `ϕ` using one preallocated scratch
matrix and a single factorization of `A`.

For `mm = 1, …, M` (with `mp1 = M + 1 - mm`, `mp2 = mp1 + 1`) every fibre
`ϕ[jj, 2:G, hh, mp1]` is overwritten with the solution `x` of
`A * x = ϕ[jj, 2:G, hh, mp2] + B[jj, 1:G-1, hh, mp2]`.

The loops index `jj ∈ 1:2J+1`, `gg ∈ 1:G-1` and `hh ∈ 1:H+1` under `@inbounds`,
so `ϕ` and `B` must be at least that large along the corresponding dimensions;
this is not checked.

`ϕ` is mutated in place; the function returns `nothing`.
"""
function f2(M, G, J, H, A, B, ϕ)
    jmax = 2 * J + 1
    # Scratch right-hand sides: one column per (jj, hh) pair, G - 1 rows each.
    tmp = similar(ϕ, G - 1, jmax * (H + 1))
    # Factorize once and reuse the factorization for every step. (The assignment
    # target must be present; the later `ldiv!` needs the bound factorization.)
    Afact = factorize(A)
    for mm in 1:M
        mp1 = M + 1 - mm
        mp2 = mp1 + 1
        # Gather the right-hand sides into the columns of `tmp`.
        @inbounds for hh in 1:H+1, jj in 1:jmax
            col = jj + (hh - 1) * jmax
            for gg in 1:G-1
                tmp[gg, col] = ϕ[jj, gg + 1, hh, mp2] + B[jj, gg, hh, mp2]
            end
        end
        # Solve all columns in place.
        ldiv!(Afact, tmp)
        # Scatter the solutions back into the destination slab of ϕ.
        @inbounds for hh in 1:H+1, jj in 1:jmax
            col = jj + (hh - 1) * jmax
            for gg in 1:G-1
                ϕ[jj, gg + 1, hh, mp1] = tmp[gg, col]
            end
        end
    end
    return nothing
end
"""
    f3(M, G, J, H, A, B, ϕ)

Same sweep as the single-scratch-matrix variant, with the gather and scatter
loops over `hh` run through `Threads.@threads`.

Walking the destination slab index down from `M` to `1`, every fibre
`ϕ[jj, 2:G, hh, dst]` is replaced by the solution `x` of
`A * x = ϕ[jj, 2:G, hh, dst + 1] + B[jj, 1:G-1, hh, dst + 1]`. All right-hand
sides of one step are collected into a scratch matrix, solved in place with a
single `ldiv!` call, and copied back.

`ϕ` is mutated in place; the function returns `nothing`.
"""
function f3(M, G, J, H, A, B, ϕ)
    ncols_per_h = 2 * J + 1
    nrows = G - 1
    scratch = similar(ϕ, nrows, ncols_per_h * (H + 1))
    factorization = factorize(A)
    for dst in M:-1:1
        src = dst + 1
        # Gather: each hh owns a disjoint range of scratch columns.
        @inbounds Threads.@threads for hh in 1:H+1
            offset = (hh - 1) * ncols_per_h
            for jj in 1:ncols_per_h
                for gg in 1:nrows
                    scratch[gg, jj + offset] = ϕ[jj, gg + 1, hh, src] + B[jj, gg, hh, src]
                end
            end
        end
        ldiv!(factorization, scratch)
        # Scatter the solved columns back into the destination slab.
        @inbounds Threads.@threads for hh in 1:H+1
            offset = (hh - 1) * ncols_per_h
            for jj in 1:ncols_per_h
                for gg in 1:nrows
                    ϕ[jj, gg + 1, hh, dst] = scratch[gg, jj + offset]
                end
            end
        end
    end
    return nothing
end
"""
    f4(M, G, J, H, A, B, ϕ)

Multithreaded sweep that solves one fibre at a time, using per-task scratch
vectors and `@avx` (LoopVectorization) copy loops over real/imaginary parts.

For `mm = 1, …, M` (with `mp1 = M + 1 - mm`, `mp2 = mp1 + 1`) every fibre
`ϕ[jj, 2:G, hh, mp1]` is overwritten with the solution `x` of
`A * x = ϕ[jj, 2:G, hh, mp2] + B[jj, 1:G-1, hh, mp2]`.

`B`, `ϕ` and the scratch vectors are viewed through
`reinterpret(reshape, Float64, …)`, which adds a leading dimension of length 2
that the loops index as `1` and `2`; the element type of `ϕ` and `B` must
therefore be a pair of `Float64`s (e.g. `ComplexF64`).

The `hh` range `1:H+1` is split into strided chunks, one per spawned task, and
each task owns its own scratch vector and its own copy of the factorization.
Buffers are indexed by chunk number rather than `Threads.threadid()`, because a
task may migrate between threads while it runs.

`ϕ` is mutated in place; the function returns `nothing`.
"""
function f4(M, G, J, H, A, B, ϕ)
    Afact = factorize(A)
    nslabs = H + 1
    ntasks = max(1, min(Threads.nthreads(), nslabs))
    Afacts = [deepcopy(Afact) for _ in 1:ntasks]
    tmp = [similar(ϕ, G - 1) for _ in 1:ntasks]   # task-local storage
    Bf = reinterpret(reshape, Float64, B)
    ϕf = reinterpret(reshape, Float64, ϕ)
    tmpf = reinterpret.(reshape, Float64, tmp)
    jmax = 2 * J + 1
    for mm in 1:M
        mp1 = M + 1 - mm
        mp2 = mp1 + 1
        # All tasks of one step must finish before the next step reads slab mp1.
        @sync for chunk in 1:ntasks
            Threads.@spawn begin
                tmpl = tmp[chunk]
                tmpfl = tmpf[chunk]
                Al = Afacts[chunk]
                # Chunks take hh = chunk, chunk + ntasks, …, so they touch
                # disjoint parts of ϕ.
                for hh in chunk:ntasks:nslabs
                    for jj in 1:jmax
                        @avx for gg = 1:G-1
                            tmpfl[1, gg] = ϕf[1, jj, gg+1, hh, mp2] +
                                           Bf[1, jj, gg, hh, mp2]
                            tmpfl[2, gg] = ϕf[2, jj, gg+1, hh, mp2] +
                                           Bf[2, jj, gg, hh, mp2]
                        end
                        # `tmpfl` aliases `tmpl`, so this solves the vector just filled.
                        ldiv!(Al, tmpl)
                        @avx for gg = 1:G-1
                            ϕf[1, jj, gg + 1, hh, mp1] = tmpfl[1, gg]
                            ϕf[2, jj, gg + 1, hh, mp1] = tmpfl[2, gg]
                        end
                    end
                end
            end
        end
    end
    return nothing
end
"""
    f5(M, G, J, H, A, B, ϕ)

Multithreaded sweep that solves one fibre at a time with per-task scratch
vectors and plain (non-`@avx`) copy loops.

For `mm = 1, …, M` (with `mp1 = M + 1 - mm`, `mp2 = mp1 + 1`) every fibre
`ϕ[jj, 2:G, hh, mp1]` is overwritten with the solution `x` of
`A * x = ϕ[jj, 2:G, hh, mp2] + B[jj, 1:G-1, hh, mp2]`.

The `hh` range `1:H+1` is split into strided chunks, one per spawned task, and
each task owns its own scratch vector and its own copy of the factorization.
Buffers are indexed by chunk number rather than `Threads.threadid()`, because a
task may migrate between threads while it runs.

The copy loops run under `@inbounds` with `jj ∈ 1:2J+1`, `gg ∈ 1:G-1` and
`hh ∈ 1:H+1`; `ϕ` and `B` must be at least that large along the corresponding
dimensions, which is not checked.

`ϕ` is mutated in place; the function returns `nothing`.
"""
function f5(M, G, J, H, A, B, ϕ)
    Afact = factorize(A)
    nslabs = H + 1
    ntasks = max(1, min(Threads.nthreads(), nslabs))
    Afacts = [deepcopy(Afact) for _ in 1:ntasks]
    tmp = [similar(ϕ, G - 1) for _ in 1:ntasks]   # task-local storage
    jmax = 2 * J + 1
    for mm in 1:M
        mp1 = M + 1 - mm
        mp2 = mp1 + 1
        # All tasks of one step must finish before the next step reads slab mp1.
        @sync for chunk in 1:ntasks
            Threads.@spawn begin
                tmpl = tmp[chunk]
                Al = Afacts[chunk]
                # Chunks take hh = chunk, chunk + ntasks, …, so they touch
                # disjoint parts of ϕ.
                for hh in chunk:ntasks:nslabs
                    for jj in 1:jmax
                        @inbounds for gg in 1:G-1
                            tmpl[gg] = ϕ[jj, gg + 1, hh, mp2] + B[jj, gg, hh, mp2]
                        end
                        ldiv!(Al, tmpl)
                        @inbounds for gg in 1:G-1
                            ϕ[jj, gg + 1, hh, mp1] = tmpl[gg]
                        end
                    end
                end
            end
        end
    end
    return nothing
end
# Problem sizes passed to every implementation under test.
M = 10
G = 50
J = 50
H = 30

# A has tridiagonal structure but is materialized as a dense Matrix.
# Author's note: this was faster without sparsity of A, not sure why.
A = Matrix(Tridiagonal(rand(G - 1, G - 1)));
B = rand(ComplexF64, 2 * J + 1, G - 1, H + 1, M + 1);
ϕ1 = rand(ComplexF64, 2 * J + 1, G + 1, H + 1, M + 1);

# Each implementation mutates its ϕ argument, so each gets its own copy of the
# same initial data.
ϕ2 = copy(ϕ1)
ϕ3 = copy(ϕ1)
ϕ4 = copy(ϕ1)
ϕ5 = copy(ϕ1)

@testset "Strang splitting optimization" begin
    # Run every variant on its own array, in the order f1 … f5.
    for (sweep!, ϕ) in ((f1, ϕ1), (f2, ϕ2), (f3, ϕ3), (f4, ϕ4), (f5, ϕ5))
        sweep!(M, G, J, H, A, B, ϕ)
    end
    # f1 is the reference; every other variant must agree with it.
    @test ϕ1 ≈ ϕ2
    @test ϕ1 ≈ ϕ3
    @test ϕ1 ≈ ϕ4
    @test ϕ1 ≈ ϕ5
end
# Test Summary: | Pass Total
# Strang splitting optimization | 4 4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment