Last active
February 17, 2019 17:22
-
-
Save YingboMa/1ceb9c01539ece2456bb33e3f3cede3b to your computer and use it in GitHub Desktop.
Sivan Toledo's recursive LU algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra | |
"""
    lurec(A, blocksize=16)

Factorize a copy of `A` with the recursive LU kernel and return the result of
`lurec!`. `A` itself is not modified. The pivot vector is allocated with
`min(size(A)...)` entries of type `LinearAlgebra.BlasInt`.
"""
function lurec(A, blocksize=16)
    work = copy(A)
    npivots = min(size(A)...)
    ipiv = Vector{LinearAlgebra.BlasInt}(undef, npivots)
    return lurec!(work, ipiv, blocksize)
end
"""
    lurec!(A, ipiv, blocksize) -> LU

Overwrite `A` with its recursive, partially pivoted LU factorization and return
it wrapped in an `LU` object.

# Arguments
- `A`: matrix to factorize in place.
- `ipiv`: pivot vector that is overwritten; it must have exactly
  `min(size(A)...)` entries.
- `blocksize`: panels with at most `max(blocksize, 1)` columns are handed to the
  unblocked kernel instead of being split further.

# Throws
- `DimensionMismatch` if `length(ipiv) != min(size(A)...)`.
"""
function lurec!(A::AbstractMatrix{T}, ipiv, blocksize) where T
    info = Ref(zero(LinearAlgebra.BlasInt))
    m, n = size(A)
    mnmin = min(m, n)
    # The kernels index `ipiv` under `@inbounds`, so a wrong length must be
    # rejected here rather than silently reading/writing out of bounds.
    length(ipiv) == mnmin || throw(DimensionMismatch(
        "ipiv has length $(length(ipiv)), expected min(size(A)...) = $mnmin"))
    if n > mnmin
        # Wide matrix: the recursion only partitions the first `mnmin` columns,
        # so factor that square panel first ...
        AL = @view A[:, 1:mnmin]
        reckernel!(AL, m, mnmin, ipiv, info, blocksize)
        # ... then bring the remaining columns up to date: apply the same row
        # interchanges and solve L * U_right = P * A_right for U_right.
        AR = @view A[:, mnmin+1:n]
        apply_permutation!(ipiv, AR)
        ldiv!(LinearAlgebra.UnitLowerTriangular(AL), AR)
    else
        reckernel!(A, m, mnmin, ipiv, info, blocksize)
    end
    return LU{T, typeof(A)}(A, ipiv, info[])
end
"""
    nsplit(T, n)

Return the number of columns assigned to the left panel when a panel of `n`
columns with element type `T` is split. Above a per-type threshold the split
point is `n / 2` rounded to a multiple of a per-type stride; below the
threshold it is simply `n ÷ 2`.
"""
function nsplit(::Type{Float64}, n)
    n < 16 && return n ÷ 2
    return ((n + 8) ÷ 16) * 8
end

function nsplit(::Type{Float32}, n)
    n < 32 && return n ÷ 2
    return ((n + 16) ÷ 32) * 16
end

function nsplit(::Type{ComplexF64}, n)
    n < 8 && return n ÷ 2
    return ((n + 4) ÷ 8) * 4
end

function nsplit(::Type{ComplexF32}, n)
    n < 16 && return n ÷ 2
    return ((n + 8) ÷ 16) * 8
end
# Apply the row interchanges recorded in `P` to `A`, in order: for every index
# `dst` of `P`, row `dst` of `A` is swapped with row `P[dst]` (entries equal to
# their own index are skipped). Returns `nothing`.
Base.@propagate_inbounds function apply_permutation!(P, A)
    cols = axes(A, 2)
    for dst in axes(P, 1)
        src = P[dst]
        if src != dst
            @simd for c in cols
                tmp = A[dst, c]
                A[dst, c] = A[src, c]
                A[src, c] = tmp
            end
        end
    end
    return nothing
end
# Recursive LU kernel. Factorizes the leading `n` columns of the `m`-row panel
# `A` in place, writing row interchanges (relative to the first row of this
# panel) into `ipiv` and the first zero-pivot position into `info`.
#
# The factorization computed at each level is
#
#   [ P1 ] [ A11 A12 ]   [ L11 0 ] [ U11  U12  ]
#   [    ] [         ] = [       ] [           ]
#   [ P2 ] [ A21 A22 ]   [ L21 I ] [ 0    A′22 ]
#
# where A′22 is the Schur complement, which is then factorized recursively.
function reckernel!(A::AbstractMatrix{T}, m, n, ipiv, info, blocksize)::Nothing where T
    @inbounds begin
        if n > max(blocksize, 1)
            # Column split: `nleft` columns go left, the remaining `nright` go right.
            nleft = nsplit(T, n)
            nright = n - nleft
            left = 1:nleft
            right = nleft+1:n
            lower = nleft+1:m

            # Column panels [Aleft Aright] of the current matrix.
            Aleft = view(A, :, left)
            Aright = view(A, :, right)

            # 2x2 block partition
            #   [A11 A12]
            #   [A21 A22]
            A11 = view(A, left, left)
            A12 = view(A, left, right)
            A21 = view(A, lower, left)
            A22 = view(A, lower, right)

            # Pivot entries belonging to the left and right column panels.
            pivleft = view(ipiv, left)
            pivright = view(ipiv, right)

            # Step 1: factorize the left panel, P1 * [A11; A21] = [L11; L21] * U11.
            reckernel!(Aleft, m, nleft, pivleft, info, blocksize)

            # Step 2: replay the left panel's row interchanges on the right panel.
            apply_permutation!(pivleft, Aright)

            # Step 3: A12 = L11 * U12, so U12 = L11 \ A12 (overwrites A12).
            ldiv!(LinearAlgebra.UnitLowerTriangular(A11), A12)

            # Step 4: Schur complement A′22 = A22 - L21 * U12 (overwrites A22).
            BLAS.gemm!('N', 'N', -one(T), A21, A12, one(T), A22)

            # Remember info so a change made by the trailing factorization
            # can be detected afterwards.
            info_before = info[]

            # Step 5: factorize the trailing block, P2 * A′22 = L22 * U22.
            reckernel!(A22, m - nleft, nright, pivright, info, blocksize)

            # Step 6: replay the trailing block's interchanges on L21. The pivots
            # are still relative to the first row of A22 here, which matches
            # A21's row numbering.
            apply_permutation!(pivright, A21)

            # Step 7: shift what the trailing call reported so that it is
            # relative to this level's first row/column.
            if info[] != info_before
                info[] += nleft
            end
            @simd for i in 1:nright
                pivright[i] += nleft
            end
        else
            # Narrow enough: use the unblocked kernel.
            _generic_lufact!(A, ipiv, info)
        end
    end # inbounds
    return nothing
end
#= | |
Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl | |
License is MIT: https://julialang.org/license | |
=# | |
"""
    _generic_lufact!(A, ipiv, info)

Unblocked, right-looking LU factorization of `A` with partial (row) pivoting,
performed in place. One elimination step is carried out per entry of `ipiv`;
`ipiv[k]` receives the row swapped into position `k`. If a pivot column is
entirely zero and `info[]` is still `0`, `info[]` is set to that step index.
Returns `nothing`.
"""
function _generic_lufact!(A::StridedMatrix{T}, ipiv, info) where T
    nrows, ncols = size(A)
    nsteps = length(ipiv)
    @inbounds for col in 1:nsteps
        # Pivot search: first row at or below the diagonal with the largest
        # magnitude in this column.
        pivrow = col
        best = abs(zero(T))
        for row in col:nrows
            mag = abs(A[row, col])
            if mag > best
                pivrow, best = row, mag
            end
        end
        ipiv[col] = pivrow

        if iszero(A[pivrow, col])
            # Zero pivot: only the first occurrence is recorded.
            if info[] == 0
                info[] = col
            end
        else
            if pivrow != col
                # Swap the pivot row into place across every column.
                @simd for j in 1:ncols
                    A[col, j], A[pivrow, j] = A[pivrow, j], A[col, j]
                end
            end
            # Scale the sub-diagonal part of the column to form the multipliers.
            pivinv = inv(A[col, col])
            @simd for row in col+1:nrows
                A[row, col] *= pivinv
            end
        end

        # Rank-1 update of the trailing submatrix.
        for j in col+1:ncols
            @simd for row in col+1:nrows
                A[row, j] -= A[row, col] * A[col, j]
            end
        end
    end
    return nothing
end
#= | |
julia> using BenchmarkTools | |
julia> A = rand(80,80); | |
julia> F = lurec(A); | |
julia> F.L * F.U - A[F.p, :] |> norm | |
1.072510951834373e-14 | |
julia> A = rand(500, 500); A[:, 323] .= 0; | |
julia> lu(A, check=false).info == lurec(A).info | |
true | |
julia> @btime lurec($A); | |
89.448 μs (39 allocations: 52.86 KiB) | |
julia> @btime lu($A); | |
419.979 μs (4 allocations: 50.83 KiB) | |
julia> 419.979/89.448 | |
4.695230748591361 | |
julia> A = rand(200, 200); | |
julia> @btime lurec($A); | |
548.803 μs (109 allocations: 320.47 KiB) | |
julia> @btime lu($A); | |
1.524 ms (4 allocations: 314.38 KiB) | |
julia> 1.524*1000/548.803 | |
2.77695274989386 | |
=# |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment