Skip to content

Instantly share code, notes, and snippets.

@dpo
Created March 3, 2019 22:27
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dpo/485a811f1705e7338e305ac6f70b22e9 to your computer and use it in GitHub Desktop.
Save dpo/485a811f1705e7338e305ac6f70b22e9 to your computer and use it in GitHub Desktop.
"""
    (x, r, y, z, yC, zC, rNorms_qr, ArNorms_qr, rNorms_lq, tests_LS, tests_LN) =
        usymlqr(A, b, c; kwargs...)

Solve the symmetric saddle-point system

    [ I   A ] [ s ]   [ b ]
    [ A'  0 ] [ t ] = [ c ]

by way of the Saunders-Simon-Yip tridiagonalization using the USYMQR and USYMLQ methods.

The method solves the least-squares problem

    [ I   A ] [ r ]   [ b ]
    [ A'  0 ] [ x ] = [ 0 ]

and the least-norm problem

    [ I   A ] [ y ]   [ 0 ]
    [ A'  0 ] [ z ] = [ c ]

The two solutions are returned separately; by linearity `s = r + y` and `t = x + z`.

# Arguments
- `A`: an `m`-by-`n` operator supporting `size(A)`, `A * v` and `A' * u`.
- `b`: right-hand side of length `m`; must be nonzero.
- `c`: right-hand side of length `n`; must be nonzero.

# Keywords
- `itnlim`: the iteration stops once the iteration count reaches this value.
- `atol_ls`, `rtol_ls`: tolerances of the least-squares problem. The zero-residual
  threshold is `atol_ls + rtol_ls * ‖b‖` and the optimality threshold is
  `atol_ls + rtol_ls * ‖A'b‖`.
- `atol_ln`, `rtol_ln`: tolerances of the least-norm problem; the threshold is
  `atol_ln + rtol_ln * ‖c‖`.
- `sigma`: regularization parameter. Only `sigma == 0` is supported at the moment.
- `conlim`: the iteration stops when the reciprocal of the condition number estimate
  drops to `1 / conlim` or below.
- `verbose`: print one line of iteration statistics per iteration.

# Returns
- `x`: least-squares iterate, and `r = b - A * x`, recomputed explicitly on exit.
- `y`, `z`: least-norm iterates.
- `yC`, `zC`: the "CG point" associated with the last least-norm iterate.
  NOTE(review): these are not safeguarded against a vanishing pivot and may contain
  non-finite values when the process terminates exactly — confirm before relying on them.
- `rNorms_qr`, `ArNorms_qr`: histories of the estimates of `‖b - Ax‖` and `‖A'(b - Ax)‖`.
- `rNorms_lq`: history of the estimates of `‖A'y - c‖`.
- `tests_LS`, `tests_LN`: histories of the stopping-test values (diagnostics).

# Throws
- `ErrorException` if the dimensions of `A`, `b` and `c` mismatch, or if `b` or `c` is zero.
- `ArgumentError` if `sigma != 0`.

# Reference
M. A. Saunders, H. D. Simon and E. L. Yip
Two Conjugate-Gradient-Type Methods for Unsymmetric Linear Equations
SIAM Journal on Numerical Analysis, 25(4), 927-940, 1988.
"""
function usymlqr(A, b::Vector{Float64}, c::Vector{Float64};
                 itnlim::Int=maximum(size(A)),
                 atol_ls::Float64=1.0e-6, rtol_ls::Float64=1.0e-6,
                 atol_ln::Float64=1.0e-6, rtol_ln::Float64=1.0e-6,
                 sigma::Float64=0.0,
                 conlim::Float64=1.0e+8, verbose::Bool=false)
    m, n = size(A)
    if length(b) != m || length(c) != n
        error("USYMLQR: Dimensions mismatch")
    end

    # The rotations handling regularization were never completed (they referenced an
    # unassigned variable and failed at the first iteration). Fail early and clearly.
    # TODO: implement the two extra rotations per iteration required when sigma ≠ 0.
    sigma == 0.0 ||
        throw(ArgumentError("USYMLQR: regularization (sigma ≠ 0) is not implemented."))

    # Exit fast if b or c is zero.
    beta1 = norm(b)
    beta1 > 0.0 || error("USYMLQR: b must be nonzero.")
    gamma1 = norm(c)
    gamma1 > 0.0 || error("USYMLQR: c must be nonzero.")

    iter = 0
    ctol = conlim > 0.0 ? 1 / conlim : 0.0

    ls_zero_resid_tol = atol_ls + rtol_ls * beta1
    ls_optimality_tol = atol_ls + rtol_ls * norm(A' * b)  # FIXME: costs one extra product
    ln_tol = atol_ln + rtol_ln * gamma1
    @debug "" ls_zero_resid_tol ls_optimality_tol ln_tol

    @info @sprintf("USYMLQR with %d rows and %d columns", m, n)

    # Initial SSY vectors.
    u_prev = fill!(similar(b), 0)
    v_prev = fill!(similar(c), 0)
    u = b / beta1
    v = c / gamma1  # u₁, v₁.
    q = A * v
    alpha = dot(u, q)  # alpha₁
    vv = copy(v)  # copy of the current v; becomes v_prev at the next iteration
    beta = beta1
    gamma = gamma1

    # Initial norm estimates.
    Anorm2 = alpha * alpha
    Anorm = abs(alpha)
    # alpha₁ = u₁'Av₁ may be negative. The extreme singular value estimates must be
    # nonnegative, otherwise Acond becomes negative and the ill-conditioning test
    # stops the iteration immediately.
    sigma_min = sigma_max = abs(alpha)
    Acond = 1.0

    # Initial residual of the least-squares problem.
    phibar = beta1
    rNorm_qr = phibar
    rNorms_qr = [rNorm_qr]
    ArNorm_qr = 0.0  # just so it exists at the end of the loop!
    ArNorms_qr = Float64[]

    # Initialization for the QR factorization of T{k+1,k}.
    cs = -1.0
    sn = 0.0
    deltabar = alpha
    lambda = 0.0
    epsilon = 0.0
    eta = 0.0

    if verbose
        @printf("%4s %8s %7s %7s %7s %7s %7s %7s %7s\n",
                "iter", "alpha", "beta", "gamma", "‖A‖", "κ(A)", "‖Ax-b‖", "‖A'r‖", "‖A'y-c‖")
        @printf("%4d %8.1e %7.1e %7.1e %7.1e %7.1e %7.1e ",
                iter, alpha, beta, gamma, Anorm, Acond, rNorm_qr)
    end

    # Initialize x and z and their update directions.
    x = fill!(similar(c), 0)
    xNorm = 0.0
    z = fill!(similar(c), 0)
    wbar = v / deltabar
    w = fill!(similar(v), 0)
    wold = fill!(similar(v), 0)
    @debug "" iter phibar wbar'

    # Quantities related to the update of y.
    etabar = gamma / deltabar
    p = fill!(similar(u), 0)
    pbar = copy(u)
    y = fill!(similar(b), 0)
    yC = etabar * pbar
    zC = -etabar * wbar
    @debug "‖A * zC + yC‖" norm(A * zC + yC)
    yNorm2 = 0.0
    @debug "" y

    # Quantities related to the computation of ‖x‖.
    # TODO

    # Residual of the least-norm problem.
    rNorm_lq = 2 * ln_tol  # just so it exists at the end of the loop!
    rNorms_lq = Float64[]

    # Stopping conditions that apply to both problems.
    tired = iter ≥ itnlim
    ill_cond_lim = 1 / Acond ≤ ctol
    ill_cond_mach = 1.0 + 1 / Acond ≤ 1.0
    ill_cond = ill_cond_mach | ill_cond_lim

    # Stopping conditions related to the least-squares problem. The flags are computed
    # so that they exist for the final report, but the initial point is never accepted.
    test_LS = rNorm_qr / (1.0 + Anorm * xNorm)
    zero_resid_lim_LS = test_LS ≤ ls_zero_resid_tol
    zero_resid_mach_LS = 1.0 + test_LS ≤ 1.0
    zero_resid_LS = zero_resid_mach_LS | zero_resid_lim_LS
    test_LS = ArNorm_qr / (Anorm * max(1.0, rNorm_qr))
    solved_lim_LS = test_LS ≤ ls_optimality_tol
    solved_mach_LS = 1.0 + test_LS ≤ 1.0
    # TODO: check this
    solved_LS = false  # solved_mach_LS | solved_lim_LS | zero_resid_LS

    # Stopping conditions related to the least-norm problem.
    test_LN = rNorm_lq / sqrt(gamma1^2 + Anorm2 * yNorm2)
    solved_lim_LN = test_LN ≤ ln_tol
    solved_mach_LN = 1.0 + test_LN ≤ 1.0
    # TODO: check this
    solved_LN = false  # solved_lim_LN | solved_mach_LN

    solved = solved_LS & solved_LN

    # TODO: remove this when finished
    tests_LS = Float64[]
    tests_LN = Float64[]

    while !(solved | tired | ill_cond)
        iter = iter + 1

        # Continue the tridiagonalization. `beta` still holds the value of the previous
        # iteration when v is formed; it is refreshed right after.
        @. u_prev = u
        @. u = q - alpha * u
        Atuprev = A' * u_prev
        @. v = Atuprev - alpha * v - beta * v_prev
        beta = norm(u)
        if beta > 0
            @. u /= beta
        end
        gamma = norm(v)
        if gamma > 0
            @. v /= gamma
        end

        # Save vectors for the next iteration.
        @. v_prev = vv
        @. vv = v

        # Continue the QR factorization of T{k+1,k}.
        lambdabar = -cs * gamma
        epsilon = sn * gamma
        @debug "" lambdabar epsilon

        # Optimality residual of the least-squares problem at x{k-1}.
        if !solved_LS
            ArNorm_qr = rNorm_qr * sqrt(deltabar^2 + lambdabar^2)
            # The explicit residual costs two products with A; it is an argument of
            # @debug so that it is only evaluated when debug logging is enabled.
            @debug "" ArNorm_qr norm(A' * (b - A * x))
            push!(ArNorms_qr, ArNorm_qr)

            test_LS = ArNorm_qr / (Anorm * max(1.0, rNorm_qr))
            solved_lim_LS = test_LS ≤ ls_optimality_tol
            solved_mach_LS = 1.0 + test_LS ≤ 1.0
            solved_LS = solved_mach_LS | solved_lim_LS

            # TODO: remove this when finished
            push!(tests_LS, test_LS)

            if solved_LS
                @info "solved LS problem with" x
            end
        end
        verbose && @printf("%7.1e ", ArNorm_qr)

        # Continue the QR factorization: rotation annihilating beta.
        delta = sqrt(deltabar^2 + beta^2)
        csold = cs
        snold = sn
        cs = deltabar / delta
        sn = beta / delta

        # Update w (used to update x and z).
        @. wold = w
        @. w = cs * wbar

        if !solved_LS
            # The optimality conditions of the LS problem were not triggered:
            # update x and see if we have a zero residual.
            phi = cs * phibar
            phibar = sn * phibar
            @. x += phi * w
            xNorm = norm(x)  # FIXME

            # Update the least-squares residual.
            rNorm_qr = abs(phibar)
            push!(rNorms_qr, rNorm_qr)

            # Stopping conditions related to the least-squares problem.
            test_LS = rNorm_qr / (1.0 + Anorm * xNorm)
            zero_resid_lim_LS = test_LS ≤ ls_zero_resid_tol
            zero_resid_mach_LS = 1.0 + test_LS ≤ 1.0
            zero_resid_LS = zero_resid_mach_LS | zero_resid_lim_LS
            solved_LS |= zero_resid_LS
            if zero_resid_LS
                @info "solved LS problem to zero residual with" x
            end
        end

        # Continue the tridiagonalization.
        q = A * v
        @. q -= gamma * u_prev
        alpha = dot(u, q)

        # Update norm estimates.
        Anorm2 += alpha * alpha + beta * beta + gamma * gamma
        Anorm = sqrt(Anorm2)

        # Estimate κ₂(A) based on the diagonal of the triangular factor.
        sigma_min = min(delta, sigma_min)
        sigma_max = max(delta, sigma_max)
        Acond = sigma_max / sigma_min

        # Continue the QR factorization of T{k+1,k}.
        lambda = cs * lambdabar + sn * alpha
        deltabar = sn * lambdabar - cs * alpha
        @debug "" lambda deltabar

        if !solved_LN
            etaold = eta
            eta = cs * etabar  # = etak

            # Residual of the least-norm problem at y{k-1}; the explicit value is only
            # evaluated when debug logging is enabled.
            rNorm_lq = sqrt((delta * eta)^2 + (epsilon * etaold)^2)
            @debug "" rNorm_lq norm(A' * y - c)
            push!(rNorms_lq, rNorm_lq)

            # Stopping conditions related to the least-norm problem.
            test_LN = rNorm_lq / sqrt(gamma1^2 + Anorm2 * yNorm2)
            solved_lim_LN = test_LN ≤ ln_tol
            solved_mach_LN = 1.0 + test_LN ≤ 1.0
            solved_LN = solved_lim_LN | solved_mach_LN

            # TODO: remove this when finished
            push!(tests_LN, test_LN)

            if solved_LN
                @info "solved LN problem with" y z
            end
        end

        # wbar defines the next direction w, which the least-squares update of x needs
        # even after the least-norm problem has converged: always update it.
        @. wbar = (v - lambda * w - epsilon * wold) / deltabar
        @debug "" wbar

        if !solved_LN
            # Prepare to update y and z.
            @. p = cs * pbar + sn * u

            # Update y and z.
            @. y += eta * p
            @. z -= eta * w
            yNorm2 += eta * eta
            @debug "" y

            @. pbar = sn * pbar - cs * u
            etabarold = etabar
            etabar = -(lambda * eta + epsilon * etaold) / deltabar  # = etabar{k+1}

            # CG point associated with the current LQ iterate.
            # TODO: use it to stop early when its residual is smaller than that of (y, z).
            @. yC = y + etabar * pbar
            @. zC = z - etabar * wbar
            @debug "" gamma * abs(snold * etaold - csold * etabarold) norm(A' * yC - c)
        end

        verbose && @printf("%7.1e\n", rNorm_lq)
        verbose && @printf("%4d %8.1e %7.1e %7.1e %7.1e %7.1e %7.1e ",
                           iter, alpha, beta, gamma, Anorm, Acond, rNorm_qr)

        # Stopping conditions that apply to both problems.
        tired = iter ≥ itnlim
        ill_cond_lim = 1 / Acond ≤ ctol
        ill_cond_mach = 1.0 + 1 / Acond ≤ 1.0
        ill_cond = ill_cond_mach | ill_cond_lim

        solved = solved_LS & solved_LN
    end
    verbose && @printf("\n")

    @info "final LS status" zero_resid_lim_LS zero_resid_mach_LS solved_lim_LS solved_mach_LS
    @info "final LN status" solved_lim_LN solved_mach_LN
    @info "" solved tired ill_cond

    # At the very end, recover r explicitly.
    r = b - A * x

    return (x, r, y, z, yC, zC, rNorms_qr, ArNorms_qr, rNorms_lq, tests_LS, tests_LN)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment