Last active
September 17, 2015 21:23
-
-
Save matthieugomez/7b8b0ef85478f69b1af5 to your computer and use it in GitHub Desktop.
CG on A'A vs CG on Cimmino projection
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
##############################################################################
##
## CG on (Id - AA'/diag(AA')) X = 0
##
##############################################################################
# State needed to apply the operator I - X * D^-1 * X', where D = diag(X'X)
# holds the squared norm of each column of X.
type CimminoProblem
    X::Matrix{Float64}            # matrix whose columns define the projections
    invsumabs2::Vector{Float64}   # 1 / squared column norms of X, length size(X, 2)
    z::Vector{Float64}            # work buffer for D^-1 * X'x, length size(X, 2)
end

# Build a CimminoProblem for `X`, precomputing the inverse squared column norms.
function CimminoProblem(X::Matrix{Float64})
    # z receives X'x and is scaled elementwise by invsumabs2, so it needs one
    # entry per column of X (size(X, 2)), not one per row.
    return CimminoProblem(X, 1./vec(sumabs2(X, 1)), Array(Float64, size(X, 2)))
end

# Overwrite `y` with (I - X * D^-1 * X') * x and return `y`.
# `x` is not modified; the work buffer `cp.z` is overwritten.
# NOTE(review): any x with X'x = 0 is mapped to itself by this operator, while
# cg! stops once X'x is ~0 -- confirm that I - X*D^-1*X' (rather than
# X*D^-1*X') is the operator intended here.
function Base.A_mul_B!(y::Vector{Float64}, cp::CimminoProblem, x::Vector{Float64})
    # z = X'x, computed in place (the mutating At_mul_B! is required here)
    At_mul_B!(cp.z, cp.X, x)
    # z = D^-1 * z
    broadcast!(*, cp.z, cp.z, cp.invsumabs2)
    # y = x - X * z  (matrix-vector product, hence gemv! rather than gemm!)
    copy!(y, x)
    Base.BLAS.gemv!('N', -1.0, cp.X, cp.z, 1.0, y)
    return y
end
# Conjugate-gradient iteration on the operator `A` (applied through
# A_mul_B!), starting from the iterate `x` with residual `r`.
#
# `x` and `r` are updated in place. `A` must expose a matrix field `X`: the
# loop stops as soon as sumabs2(A.X' * x) <= tol^2, or after `maxiter` steps.
# The stopping quantity is printed at every step.
# Returns the tuple (iterations, converged).
function cg!(x, r, A; tol::Real=1e-8, maxiter::Int=100)
    direction = deepcopy(r)
    Adirection = similar(r)
    proj = Array(Float64, size(A.X, 2))
    rsold = sumabs2(r)
    for iter in 1:maxiter
        A_mul_B!(Adirection, A, direction)
        α = rsold / dot(Adirection, direction)
        # x += α * direction ; r -= α * A * direction
        Base.BLAS.axpy!(α, direction, x)
        Base.BLAS.axpy!(-α, Adirection, r)
        rsnew = sumabs2(r)
        # the stopping rule looks at X'x, not at the CG residual r
        error = sumabs2(At_mul_B!(proj, A.X, x))
        @show error
        if error <= tol^2
            return (iter, true)
        end
        # direction = r + (rsnew / rsold) * direction
        scale!(direction, rsnew / rsold)
        Base.BLAS.axpy!(1.0, r, direction)
        rsold = rsnew
    end
    # iteration budget exhausted without meeting the tolerance
    return (maxiter, false)
end
# Run cg! on the Cimmino operator starting from `y`; `y` is overwritten with
# the final iterate. Returns the iteration count reported by cg!.
function residualize!(y, cp::CimminoProblem)
    # initial residual for a zero right-hand side: r = 0 - A * y
    start = A_mul_B!(similar(y), cp, y)
    scale!(start, -1.0)
    outcome = cg!(y, start, cp)
    # sanity check: the least-squares coefficients of y on X must be ~0
    @assert sumabs2(At_mul_B(cp.X, cp.X) \ At_mul_B(cp.X, y)) <= 1e-10
    return outcome[1]
end
##############################################################################
##
## CG on A'A X = A'y
##
##############################################################################
# Diagonally preconditioned conjugate gradient on the normal equations
# A'A x = A'r, tracking only the residual.
#
# `r` is updated in place (r -= α * A * p at each step). The preconditioner
# is the inverse of the squared column norms of `A`. The loop stops as soon
# as sumabs2(A' * r) <= tol^2, or after `maxiter` steps; the stopping
# quantity is printed at every step.
# Returns the tuple (iterations, converged).
function cgls!(r, A; tol::Real=1e-8, maxiter::Int=100)
    grad = Array(Float64, size(A, 2))   # running value of A'r
    work = similar(grad)                # holds A'A*dir, then invdiag .* grad
    check = similar(grad)               # scratch for the stopping test
    Adir = similar(r)
    invdiag = 1./vec(sumabs2(A, 1))
    At_mul_B!(grad, A, r)
    broadcast!(*, work, grad, invdiag)
    dir = copy(work)
    rsold = dot(grad, work)
    for iter in 1:maxiter
        A_mul_B!(Adir, A, dir)
        At_mul_B!(work, A, Adir)
        α = rsold / dot(work, dir)
        # r -= α * A * dir ; grad -= α * A'A * dir
        Base.BLAS.axpy!(-α, Adir, r)
        Base.BLAS.axpy!(-α, work, grad)
        # preconditioned gradient
        broadcast!(*, work, grad, invdiag)
        rsnew = dot(grad, work)
        error = sumabs2(At_mul_B!(check, A, r))
        @show error
        if error <= tol^2
            return (iter, true)
        end
        # dir = invdiag .* grad + (rsnew / rsold) * dir
        scale!(dir, rsnew / rsold)
        Base.BLAS.axpy!(1.0, work, dir)
        rsold = rsnew
    end
    # iteration budget exhausted without meeting the tolerance
    return (maxiter, false)
end
# Run cgls! on `y` against `X`; `y` is overwritten in place by cgls!.
# Returns the iteration count reported by cgls!.
function residualize!(y, X::Matrix{Float64})
    outcome = cgls!(y, X)
    # sanity check: the least-squares coefficients of y on X must be ~0
    @assert sumabs2(At_mul_B(X, X) \ At_mul_B(X, y)) <= 1e-10
    return outcome[1]
end
# Non-mutating counterpart of residualize!: runs it on a deep copy of `y`
# and returns whatever residualize! returns.
function residualize(y, X)
    ycopy = deepcopy(y)
    return residualize!(ycopy, X)
end
##############################################################################
##
## Tests
## For both methods, iteration stops when sumabs2(A' * residual) <= tol^2
## Both methods require the same amount of computation at each iteration
##
##############################################################################
# Run both solvers on random Gaussian problems. Column counts form the outer
# loop and row counts the inner loop, so the (rows, cols) sizes are visited in
# the order (500, 2), (5000, 2), (50000, 2), (500, 10), ..., (50000, 1000).
# X and y are loop-local, so no non-constant globals are left behind.
for ncols in (2, 10, 100, 1000), nrows in (500, 5000, 50000)
    X = randn(nrows, ncols)
    y = randn(nrows)
    # residualize works on a copy, so both solvers see the same y
    residualize(y, CimminoProblem(X))
    residualize(y, X)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment