Skip to content

Instantly share code, notes, and snippets.

@jwscook
Last active March 26, 2024 21:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jwscook/b10d9ac378a059d4a5cf714c73d50761 to your computer and use it in GitHub Desktop.
Save jwscook/b10d9ac378a059d4a5cf714c73d50761 to your computer and use it in GitHub Desktop.
Julia code to solve square or least-squares linear problems Ax = b via Householder QR, where A can be an AbstractArray or a DistributedArrays `DArray`, and b can be a Vector or a SharedArray.
using Random, Distributed, Test
# Make the random test problems reproducible (master process RNG), then start
# 4 local worker processes, each launched with 2 Julia threads.
Random.seed!(0)
addprocs(4; exeflags = "-t 2")
@everywhere begin
using LinearAlgebra, Random
using Distributed, DistributedArrays, SharedArrays
# Size BLAS's thread pool to match the number of Julia threads of this process.
let nthreads = Threads.nthreads()
    LinearAlgebra.BLAS.set_num_threads(nthreads)
end
# Unit-modulus phase factor for the Householder shift α = ‖x‖ * alphafactor(x[1]).
# It points opposite to x[1], so forming x[1] - α adds magnitudes rather than
# cancelling them.
# For real x the zero case must still give a unit factor: `-sign(0) == 0` would make
# α == 0, the scaled column would then not be a valid reflector (squared norm 1,
# not 2), and the back-substitution would divide by α == 0.
alphafactor(x::Real) = signbit(x) ? one(x) : -one(x)
alphafactor(x::Complex) = -exp(im * angle(x))
# Global column indices covered by a matrix's local storage: every column for an
# ordinary matrix, and the locally held column range for a DArray.
# NOTE(review): appears unused in the visible code.
function localcols(m::AbstractMatrix)
    return 1:size(m, 2)
end
function localcols(m::DArray)
    rows_and_cols = DistributedArrays.localindices(m)
    return rows_and_cols[2]
end
# Per-dimension index ranges held locally by `A`; an ordinary array holds all of them.
# `map` over the size tuple keeps the result type (`NTuple{N,UnitRange{Int}}`)
# inferable, unlike `Tuple(generator)`, whose length is only known at run time.
localindexes(A::AbstractArray) = map(d -> 1:d, size(A))
# Indices owned by the calling process, as reported by each array package.
function localindexes(A::DArray)
    return DistributedArrays.localindices(A)
end
# NOTE(review): SharedArrays.localindices appears to return a single linear index
# range, not a per-dimension Tuple like the other methods — confirm callers handle this.
function localindexes(A::SharedArray)
    return SharedArrays.localindices(A)
end
# Global column indices stored by process `n`.
# A non-distributed array is a single block; this file's `procs(::AbstractArray)`
# reports it as owned by process 1, so any other `n` is a caller error.
function columnblocks(m::AbstractArray, n)
    n == 1 || throw(ArgumentError(
        "a non-distributed array has a single block owned by process 1, got process $n"))
    return 1:size(m, 2)
end
# For a DArray, look `n` up in `m.pids` (laid out like `m.indices`) rather than
# assuming the owning pids are consecutive and start at `m.pids[1]`.
function columnblocks(m::DArray, n)
    k = findfirst(==(n), m.pids)
    k === nothing && throw(ArgumentError("process $n holds no block of this DArray"))
    return m.indices[k][2]
end
# Processes holding part of an ordinary (non-distributed) array: only pid 1.
# Returned as a bare Int, which iterates as the single value 1.
# NOTE(review): this adds a method to `Distributed.procs` for `AbstractArray`, a type
# this file does not own (type piracy); a locally owned function would be safer.
function Distributed.procs(::AbstractArray)
    return 1
end
# Storage that the calling process indexes directly when working on `A`: the local
# part of a DArray, and the array itself for a SharedArray or any other array.
function localblock(A::DArray)
    return localpart(A)
end
function localblock(A::SharedArray)
    return A
end
function localblock(A::AbstractArray)
    return A
end
# Wrapper around the part of a (possibly distributed) matrix or vector that the
# calling process stores locally, indexed with *global* column indices.
#   Al       – the locally stored array (built from `localblock`)
#   Δj       – offset subtracted from a global column index to get a local one
#   colrange – global column indices covered by `Al`
struct LocalColumnBlock{T}
Al::T
Δj::Int
colrange::UnitRange{Int}
end
# Wrap the columns of matrix `A` that the calling process stores. Rows must not be
# distributed: the local row range has to be all of 1:size(A, 1) (asserted).
# The column offset is the first local column minus one.
function LocalColumnBlock(A::AbstractMatrix)
    rowrange, colrange = localindexes(A)
    @assert rowrange == 1:size(A, 1)
    return LocalColumnBlock(localblock(A), first(colrange) - 1, colrange)
end
# Wrap the entries of vector `A` that the calling process stores; the vector's
# single dimension plays the role of the column index.
# The index range must come from `A` itself (the wrapped argument).
# NOTE(review): for a SharedArray, `localindexes` appears to return a linear range
# rather than a Tuple, and `localblock` returns the whole array, so a nonzero offset
# would misaddress it — confirm before using this with SharedArrays.
function LocalColumnBlock(A::AbstractVector)
    colrange = localindexes(A)[1]
    Δj = colrange[1] - 1
    return LocalColumnBlock(localblock(A), Δj, colrange)
end
# Index a LocalColumnBlock with *global* column indices (scalars or ranges):
# each one is shifted by `Δj` before addressing the locally stored array `Al`.
# The setters return the assigned value `v`.
# NOTE(review): LocalColumnBlock is not an AbstractArray, so `@views` at call sites
# presumably still routes range indexing to `getindex` here, which copies — confirm.
function Base.setindex!(lcb::LocalColumnBlock, v::Number, i, j)
    lcb.Al[i, j .- lcb.Δj] = v
    return v
end
function Base.setindex!(lcb::LocalColumnBlock, v, i, j)
    lcb.Al[i, j .- lcb.Δj] = v
    return v
end
function Base.setindex!(lcb::LocalColumnBlock, v, j)
    lcb.Al[j .- lcb.Δj] = v
    return v
end
function Base.getindex(lcb::LocalColumnBlock, i, j)
    return lcb.Al[i, j .- lcb.Δj]
end
function Base.getindex(lcb::LocalColumnBlock, j)
    return lcb.Al[j .- lcb.Δj]
end
# In-place Householder QR of `A`.
# Each process in `procs(A)` in turn factors the columns it stores via
# `_householder!`, which also updates the trailing columns held by other processes.
# Returns `(A, α)`:
#   - `A` holds the scaled Householder vectors on and below the diagonal, and R's
#     strictly upper part above it;
#   - `α` holds R's diagonal.
# Both are rebound to the fetched results, so updates made on a remote copy
# (e.g. a plain Vector α) come back to the caller.
function householder!(A, α=zeros(eltype(A), size(A, 2)))
    for owner in procs(A)
        A, α = fetch(@spawnat owner _householder!(A, α))
    end
    return (A, α)
end
# Apply the reflector of column `j` to the trailing columns j+1:n that the
# calling process stores locally:
#     H[j:m, jj] -= v * (v' * H[j:m, jj])    with v = Hj[j:m]
# `Hj` is a local copy of column j (the scaled Householder vector).
# Returns the LocalColumnBlock wrapper.
function _householder_inner!(H, j, Hj)
    m, n = size(H)
    Hl = LocalColumnBlock(H)
    v = view(Hj, j:m)
    # Work on views of the local storage `Hl.Al`: range indexing of the wrapper
    # itself copies (it is not an AbstractArray, so `@views` cannot make it a
    # view), which allocated one slice per updated column.
    @inbounds for jj in intersect(j+1:n, Hl.colrange)
        col = view(Hl.Al, j:m, jj - Hl.Δj)
        s = dot(v, col)  # dot conjugates v: s = v' * col
        col .-= v .* s
    end
    return Hl
end
# Factor, in place, the columns of `H` that the calling process stores.
# For each local column j:
#   1. Overwrite H[j:m, j] with its Householder vector, scaled so that the
#      reflector is I - v v'.
#   2. Store R's diagonal entry in α[j].
#   3. Apply the reflector to every later column, on whichever process holds it.
# Prints accumulated timings and returns `(H, α)`.
function _householder!(H, α)
    nrows = size(H, 1)
    Hl = LocalColumnBlock(H)
    t1 = t2a = t2b = 0.0
    vbuf = zeros(eltype(H), nrows)
    @inbounds @views for j in Hl.colrange
        t1 += @elapsed begin
            colnorm = norm(Hl[j:nrows, j])
            α[j] = colnorm * alphafactor(Hl[j, j])
            scale = 1 / sqrt(colnorm * (colnorm + abs(Hl[j, j])))
            Hl[j, j] -= α[j]
            for i in j:nrows
                Hl[i, j] *= scale
            end
        end
        # Ship a plain copy of column j, so the remote updates below read it
        # from their own memory.
        t2a += @elapsed vbuf .= Hl[:, j]
        t2b += @elapsed for p in procs(H)
            # Processes whose columns all precede j have nothing to update.
            all(<(j), columnblocks(H, p)) && continue
            if p == myid()
                _householder_inner!(H, j, vbuf)
            else
                wait(@spawnat p _householder_inner!(H, j, vbuf))
            end
        end
    end
    @show t1 t2a t2b
    return (H, α)
end
# Apply, in increasing column order, the reflectors of the columns that the calling
# process stores to `b` in place: b[j:m] -= v * (v' * b[j:m]) with v = H[j:m, j].
# `α` is unused here. Returns `nothing`.
function _solve_householder1_inner!(b, H, α)
    nrows, ncols = size(H)
    Hl = LocalColumnBlock(H)
    @inbounds @views for j in Hl.colrange ∩ (1:ncols)
        proj = dot(Hl[j:nrows, j], b[j:nrows])
        for i in j:nrows
            b[i] -= Hl[i, j] * proj
        end
    end
    return nothing
end
# Serial Q'b: overwrite `b` by applying the stored reflectors of columns 1..n in
# order, b[j:m] -= v * (v' * b[j:m]) with v = H[j:m, j]. `α` is unused.
# Returns `b`.
function _solve_householder1!(b::Vector, H, α::Vector)
    nrows, ncols = size(H)
    for j in 1:ncols
        @inbounds @views begin
            v = H[j:nrows, j]
            proj = dot(v, b[j:nrows])
            b[j:nrows] .-= v .* proj
        end
    end
    return b
end
# Distributed Q'b on a shared `b`: each process in `procs(H)` applies the
# reflectors of the columns it stores. Processes run one at a time, in
# `procs(H)` order.
# NOTE(review): this assumes `procs(H)` is ordered like the column blocks, so that
# reflectors are applied in column order.
# Returns `b`, like the serial Vector method.
function _solve_householder1!(b::SharedArray, H, α)
    for p in procs(H)
        # Wait before moving on: the reflectors must be applied sequentially.
        wait(@spawnat p _solve_householder1_inner!(b, H, α))
    end
    return b
end
# Serial back-substitution R x = b, in place in b[1:n].
# R's diagonal is `α`; its strictly upper part is stored in H[i, j] for j > i.
# Entries of `b` past n are left untouched. Returns `b`.
function _solve_householder2!(b::Vector, H, α::Vector)
    ncols = size(H, 2)
    @inbounds for i in reverse(1:ncols)
        for j in (i + 1):ncols
            b[i] -= H[i, j] * b[j]
        end
        b[i] /= α[i]
    end
    return b
end
# Subtract from b[i] the part of Σ_{j>i} R[i, j] * b[j] that comes from columns
# the calling process stores locally (R's upper part lives in H). Returns `b`.
function _solve_householder2_inner!(b, H, i)
    n = size(H, 2)
    Hl = LocalColumnBlock(H)
    js = intersect(i+1:n, Hl.colrange)
    isempty(js) && return b
    # Plain, unconjugated product sum. `dot` would conjugate its first argument,
    # and pre-conjugating a slice to undo that allocated a temporary per row.
    b[i] -= sum(Hl[i, j] * b[j] for j in js)
    return b
end
# Distributed back-substitution R x = b on a shared `b`, in place in b[1:n].
# For each row i, from last to first:
#   1. every process storing columns > i subtracts its part of Σ_{j>i} R[i, j] b[j],
#      one process at a time because they all update the same b[i];
#   2. b[i] is divided by R's diagonal entry α[i].
# Returns `b`.
function _solve_householder2!(b::SharedArray, H, α)
    n = size(H, 2)
    ps = procs(H)
    ps = length(ps) == 1 ? ps : reverse(ps)
    @inbounds for i in n:-1:1
        for p in ps
            # A process whose columns are all ≤ i contributes nothing to row i,
            # so skip the remote round trip (same test as in `_householder!`).
            all(<=(i), columnblocks(H, p)) && continue
            wait(@spawnat p _solve_householder2_inner!(b, H, i))
        end
        b[i] /= α[i]
    end
    return b
end
# Solve A x = b (least squares when A is tall) from the in-place factorization
# `(H, α)` produced by `householder!`:
#   1. overwrite `b` with Q'b;
#   2. back-substitute with R;
#   3. print both phase timings.
# Returns b[1:n].
function solve_householder!(b, H, α)
    ncols = size(H, 2)
    # Phase 1: b ← Q'b.
    t3 = @elapsed _solve_householder1!(b, H, α)
    # Phase 2: back substitution with R in b[1:ncols].
    t4 = @elapsed _solve_householder2!(b, H, α)
    @show t3 t4
    return b[1:ncols]
end
end
# Benchmark and check the solvers. For each element type T and size (m, n), the same
# random problem is solved three ways:
#   - LinearAlgebra's `qr!` without pivoting;
#   - this file's solver on ordinary arrays;
#   - this file's solver on a column-distributed DArray with shared α and b.
# Each solution is checked against the normal equations A'A x = A'b.
@testset "Distributed Householder QR" begin
for T in (Float64, ComplexF64), mn in ((1200, 800), (2400, 2000),)# (4800, 4000))
#for T in (Float64, ComplexF64), mn in ((3, 2), (12, 10), (24, 20), (120, 100))
#for T in (Float64, ComplexF64), mn in ((2, 2), (3, 2), (120, 100),)
m, n = mn
@show T, m, n
A = rand(T, m, n)
b = rand(T, m)
# Reference solution on private copies of A and b.
A1 = deepcopy(Matrix(A))
b1 = deepcopy(Vector(b))
t1 = @elapsed x1 = qr!(A1, NoPivot()) \ b1
A2 = deepcopy(Matrix(A))
b2 = deepcopy(Vector(b))
# Serial path. The first factor/solve runs on deep copies, presumably as a warm-up
# so that the @time'd calls below exclude compilation; `ta` covers both.
ta = @elapsed begin
α2 = zeros(T, n)
H, α = householder!(deepcopy(A2), deepcopy(α2))
x2 = solve_householder!(deepcopy(b2), H, α)
@show "normal timings"
@time H, α = householder!(A2, α2)
@time x2 = solve_householder!(b2, H, α)
end
# distribute A across columns
# One column block per worker (layout (1, nworkers())). The underscored copies
# feed the warm-up run; the others are factored/solved in the timed run.
A3 = DArray(ij->A[ij[1], ij[2]], size(A), workers(), (1, nworkers()))
_A3 = DArray(ij->A[ij[1], ij[2]], size(A), workers(), (1, nworkers()))
α3 = SharedArray(zeros(T,n))#zeros(T,n)#distribute(zeros(T, n))#
_α3 = SharedArray(zeros(T,n))#zeros(T,n)#distribute(zeros(T, n))#
b3 = SharedArray(deepcopy(b))
_b3 = SharedArray(deepcopy(b))
tb = @elapsed begin
H, α = householder!(_A3, _α3)
solve_householder!(_b3, H, α)
@show "distrib timings"
@time H, α = householder!(A3, α3)
@time x3 = solve_householder!(b3, H, α)
end
# Each solution must satisfy the normal equations. NOTE(review): the bound
# sqrt(eps()) is absolute and does not scale with ‖A‖ or ‖x‖, so it may become
# flaky for larger sizes.
@testset "serial, library" begin
@test norm(A' * A * x1 .- A' * b) < sqrt(eps())
end
@testset "serial, this" begin
@test norm(A' * A * x2 .- A' * b) < sqrt(eps())
end
@testset "distrib, this" begin
@test norm(A' * A * x3 .- A' * b) < sqrt(eps())
end
# Total time of each of this file's paths (warm-up included) relative to the library solve.
@show ta / t1
@show tb / t1
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment