Skip to content

Instantly share code, notes, and snippets.

@YingboMa
Last active February 17, 2019 17:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YingboMa/1ceb9c01539ece2456bb33e3f3cede3b to your computer and use it in GitHub Desktop.
Save YingboMa/1ceb9c01539ece2456bb33e3f3cede3b to your computer and use it in GitHub Desktop.
Sivan Toledo's recursive LU algorithm
using LinearAlgebra
"""
    lurec(A, blocksize=16) -> LU

Compute the recursive LU factorization of a copy of `A`, leaving `A` itself untouched.
A fresh pivot vector of length `min(size(A)...)` is allocated and handed, together with
the copy, to the in-place `lurec!`.
"""
function lurec(A, blocksize=16)
    work = copy(A)
    pivots = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
    return lurec!(work, pivots, blocksize)
end
"""
    lurec!(A, ipiv, blocksize) -> LU

Factor `A` in place with Toledo's recursive LU algorithm with partial pivoting.

# Arguments
- `A`: the matrix to factor; it is overwritten by its L and U factors.
- `ipiv`: receives the pivot rows; it must have length `min(size(A)...)`.
- `blocksize`: panels with at most `blocksize` columns are factored by the unblocked
  kernel.

# Returns
An `LU` object that wraps `A` and `ipiv` without copying. Its `info` is 0, or the
column of the first zero pivot encountered.

# Throws
`DimensionMismatch` if `length(ipiv) != min(size(A)...)`.
"""
function lurec!(A::AbstractMatrix{T}, ipiv, blocksize) where T
    m, n = size(A)
    mnmin = min(m, n)
    # The kernels index `ipiv` under `@inbounds`, so a wrong length would go unnoticed
    # (or read/write out of bounds). Reject it here at the API boundary.
    length(ipiv) == mnmin || throw(DimensionMismatch(
        "ipiv has length $(length(ipiv)), but min(size(A)...) = $mnmin"))
    info = Ref(zero(LinearAlgebra.BlasInt))
    reckernel!(A, m, mnmin, ipiv, info, blocksize)
    # `LU(factors, ipiv, info)` infers every type parameter. The two-parameter
    # `LU{T,S}(...)` form is deprecated on recent Julia versions.
    return LU(A, ipiv, info[])
end
# Choose `n1`, the number of leading columns placed in the left panel when
# `reckernel!` splits an `n`-column problem.
# - If `n < 2 * half`, return `n ÷ 2`.
# - Otherwise, return `n / 2` rounded to the nearest multiple of `half` (ties round up).
_split_to_tile(n, half) = n < 2 * half ? n ÷ 2 : half * ((n + half) ÷ (2 * half))

# Every `half` below is 64 bytes' worth of elements of the given type
# (presumably chosen for memory alignment; not verified).
nsplit(::Type{Float64}, n) = _split_to_tile(n, 8)
nsplit(::Type{Float32}, n) = _split_to_tile(n, 16)
nsplit(::Type{ComplexF64}, n) = _split_to_tile(n, 4)
nsplit(::Type{ComplexF32}, n) = _split_to_tile(n, 8)
"""
    apply_permutation!(P, A)

Apply the row interchanges recorded in `P` to `A`, in place and in order. For each
index `r` of `P`, row `r` of `A` is swapped with row `P[r]`; entries with
`P[r] == r` are skipped. Returns `nothing`.

Bounds checks are elided when this is called from an `@inbounds` context, so every
`P[r]` must be a valid row index of `A`.
"""
Base.@propagate_inbounds function apply_permutation!(P, A)
    cols = axes(A, 2)
    for r in axes(P, 1)
        s = P[r]
        if s != r
            @simd for c in cols
                tmp = A[r, c]
                A[r, c] = A[s, c]
                A[s, c] = tmp
            end
        end
    end
    return nothing
end
"""
    reckernel!(A, m, n, ipiv, info, blocksize)

Recursively LU-factor, in place and with partial pivoting, the `m × size(A, 2)`
matrix `A`, where `n == min(m, size(A, 2))` is the number of pivots to compute.

- Pivot rows, relative to the rows of `A`, are written to `ipiv[1:n]`.
- If `info[]` is 0 on entry and a zero pivot is met, `info[]` becomes that pivot's
  column, relative to `A`.
- Problems with at most `max(blocksize, 1)` pivots are delegated to
  `_generic_lufact!`.
"""
function reckernel!(A::AbstractMatrix{T}, m, n, ipiv, info, blocksize)::Nothing where T
    @inbounds begin
        if n <= max(blocksize, 1)
            _generic_lufact!(A, ipiv, info)
            return nothing
        end
        n1 = nsplit(T, n)
        n2 = n - n1
        m2 = m - n1
        # Total column count. For a wide matrix (`size(A, 2) > m`) this exceeds `n`.
        # The columns past `n` still belong to U and must be permuted, solved and
        # updated, so the right-hand blocks extend to `ncols`, not to `n`.
        ncols = size(A, 2)
        # ======================================== #
        # Now, our LU process looks like this
        # [ P1 ] [ A11 A12 ]   [ L11 0 ] [ U11 U12  ]
        # [    ] [         ] = [       ] [          ]
        # [ P2 ] [ A21 A22 ]   [ L21 I ] [ 0   A′22 ]
        # ======================================== #
        # Partition the matrix A
        # [AL AR]
        AL = @view A[:, 1:n1]
        AR = @view A[:, n1+1:ncols]
        #  AL  AR
        # [A11 A12]
        # [A21 A22]
        A11 = @view A[1:n1, 1:n1]
        A12 = @view A[1:n1, n1+1:ncols]
        A21 = @view A[n1+1:m, 1:n1]
        A22 = @view A[n1+1:m, n1+1:ncols]
        # [P1]
        # [P2]
        P1 = @view ipiv[1:n1]
        P2 = @view ipiv[n1+1:n]
        # Factor the left panel:
        #   [ A11 ]   [ L11 ]
        # P [     ] = [     ] U11
        #   [ A21 ]   [ L21 ]
        reckernel!(AL, m, n1, P1, info, blocksize)
        # Apply the panel's row interchanges to the right-hand columns:
        # [ A12 ]    [ P1 ] [ A12 ]
        # [     ] <- [    ] [     ]
        # [ A22 ]    [ 0  ] [ A22 ]
        apply_permutation!(P1, AR)
        # A12 = L11 U12  =>  U12 = L11 \ A12
        ldiv!(LinearAlgebra.UnitLowerTriangular(A11), A12)
        # Schur complement: A22 = L21 U12 + A′22, hence A′22 = A22 - L21 U12
        BLAS.gemm!('N', 'N', -one(T), A21, A12, one(T), A22)
        # Remember info so we can tell whether the trailing factorization changed it.
        previnfo = info[]
        # P2 A′22 = L22 U22
        reckernel!(A22, m2, n2, P2, info, blocksize)
        # A21 <- P2 A21
        apply_permutation!(P2, A21)
        # A zero pivot found in A22 was reported relative to A22; shift it to A's columns.
        info[] != previnfo && (info[] += n1)
        # Likewise make P2's row indices relative to A rather than to A22.
        @simd for i in 1:n2
            P2[i] += n1
        end
        return nothing
    end # inbounds
end
#=
Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl
License is MIT: https://julialang.org/license
=#
"""
    _generic_lufact!(A, ipiv, info)

Unblocked, right-looking LU factorization with partial pivoting of the strided
matrix `A`, performed in place. Adapted from `LinearAlgebra.generic_lufact!`.

- Computes `length(ipiv)` pivots. That length must not exceed `min(size(A)...)`,
  because indexing is done under `@inbounds`.
- `ipiv[k]` receives the row that was swapped with row `k`.
- The unit-lower L (its unit diagonal is not stored) and U overwrite `A`.
- If a zero pivot is met while `info[]` is still 0, `info[]` is set to its column.
  The elimination continues either way.
"""
function _generic_lufact!(A::StridedMatrix{T}, ipiv, info) where T
    nrows, ncols = size(A)
    @inbounds for k in 1:length(ipiv)
        # Pivot search: the first row at or below k with the largest |A[r, k]|.
        # Values that do not strictly exceed the running maximum (including NaN)
        # never displace it.
        piv = k
        best = abs(zero(T))
        for r in k:nrows
            val = abs(A[r, k])
            if val > best
                best = val
                piv = r
            end
        end
        ipiv[k] = piv
        if iszero(A[piv, k])
            # Zero pivot: record only the first one; nothing to swap or scale.
            info[] == 0 && (info[] = k)
        else
            if piv != k
                # Swap full rows k and piv.
                @simd for c in 1:ncols
                    A[k, c], A[piv, c] = A[piv, c], A[k, c]
                end
            end
            # Turn the sub-diagonal part of column k into multipliers (the L entries).
            pinv = inv(A[k, k])
            @simd for r in k+1:nrows
                A[r, k] *= pinv
            end
        end
        # Rank-1 update of the trailing submatrix.
        for c in k+1:ncols
            akc = A[k, c]
            @simd for r in k+1:nrows
                A[r, c] -= A[r, k] * akc
            end
        end
    end
    return nothing
end
#=
julia> using BenchmarkTools
julia> A = rand(80,80);
julia> F = lurec(A);
julia> F.L * F.U - A[F.p, :] |> norm
1.072510951834373e-14
julia> A = rand(500, 500); A[:, 323] .= 0;
julia> lu(A, check=false).info == lurec(A).info
true
julia> @btime lurec($A);
89.448 μs (39 allocations: 52.86 KiB)
julia> @btime lu($A);
419.979 μs (4 allocations: 50.83 KiB)
julia> 419.979/89.448
4.695230748591361
julia> A = rand(200, 200);
julia> @btime lurec($A);
548.803 μs (109 allocations: 320.47 KiB)
julia> @btime lu($A);
1.524 ms (4 allocations: 314.38 KiB)
julia> 1.524*1000/548.803
2.77695274989386
=#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment