Skip to content

Instantly share code, notes, and snippets.

@c42f
Last active August 29, 2015 14:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save c42f/40356c153e90ff5071bc to your computer and use it in GitHub Desktop.
Cholesky decomposition for autodiff
using DualNumbers
# Solve the matrix system
#
#     B = U'*M + M'*U
#
# for the upper triangular matrix M, where B is symmetric and U is upper
# triangular with a nonzero diagonal. This looks a bit like the
# *-Sylvester equation, but has significantly more symmetry.
#
# This is the BLAS fast path for dense (strided) real matrices. For real
# element types transpose and adjoint coincide, so gemv! with 'T' computes
# exactly the products in the equation above. Complex or non-strided inputs
# dispatch to the generic tri_ss_solve! method instead.
#
# Only the upper triangles of U and B are read (B is assumed symmetric), and
# only the upper triangle of M, including the diagonal, is written; the
# strict lower triangle of M is left untouched. Returns M.
function tri_ss_solve!{T<:Base.LinAlg.BlasReal}(M::StridedMatrix{T}, U::StridedMatrix{T},
                                               B::AbstractArray{T,2})
    # This turns out to be a forward substitution algorithm in terms of the
    # upper triangular superoperator UU[M] := U'*M + M'*U
    n = size(U,1)
    # Work buffers must share the element type T: BLAS.gemv! has no
    # mixed-eltype methods (Float64 buffers broke the Float32 case).
    a = zeros(T, n)
    mi = zeros(T, n)
    ui = zeros(T, n)
    unit = one(T)
    # Compute M row by row
    for i = 1:n
        m = n-i+1
        # a[1:m] = B[i,i:n]
        for k = 1:m
            a[k] = B[i,i+k-1]
        end
        if i > 1
            # a[1:m] -= M[1:i-1,i]'*U[1:i-1,i:n] + U[1:i-1,i]'*M[1:i-1,i:n]
            for k = 1:i-1
                # Nonzero off diagonal parts of i'th column of M have been
                # computed in previous iterations.
                mi[k] = M[k,i]
                ui[k] = U[k,i]
            end
            # Call BLAS directly to avoid temporary array copies. The vector
            # arguments are views with exactly the lengths gemv! expects
            # (x: i-1, y: m) rather than the full-length work buffers.
            y = sub(a, 1:m)
            BLAS.gemv!('T', -unit, sub(U,1:i-1,i:n), sub(mi,1:i-1), unit, y)
            BLAS.gemv!('T', -unit, sub(M,1:i-1,i:n), sub(ui,1:i-1), unit, y)
        end
        # Special case for diagonal element: a[1] = 2*U[i,i]*M[i,i]
        M[i,i] = a[1]/(2*U[i,i])
        # Off-diagonals:
        # M[i,i+1:n] = (a[2:m] - M[i,i]*U[i,i+1:n]) / U[i,i]
        for k = 2:m
            M[i,i+k-1] = (a[k] - M[i,i]*U[i,i+k-1]) / U[i,i]
        end
    end
    return M
end
# Solve the matrix system
#
#     B = U'*M + M'*U
#
# for the upper triangular matrix M, for general matrix and element types
# (the fallback when the BLAS fast path does not apply). U must be upper
# triangular with a nonzero diagonal and B symmetric. Only the upper
# triangles of U and B are read, and only the upper triangle of M, including
# the diagonal, is written; the strict lower triangle of M is left untouched.
# Returns M.
#
# Written as scalar loops so it allocates no temporaries and does not depend
# on the shape returned by row slices such as B[i,i:end], which changed
# between Julia versions (1-by-m matrix vs. vector).
function tri_ss_solve!(M, U, B)
    n = size(U,1)
    # Forward substitution, computing M row by row. Row i only needs rows
    # 1:i-1 of M, which earlier iterations have already filled in.
    for i = 1:n
        uii = U[i,i]
        # Diagonal: B[i,i] = sum_{k<i} (M[k,i]'*U[k,i] + U[k,i]'*M[k,i]) + 2*U[i,i]*M[i,i]
        aii = B[i,i]
        for k = 1:i-1
            aii -= conj(M[k,i])*U[k,i] + conj(U[k,i])*M[k,i]
        end
        M[i,i] = aii/(2*uii)
        # Read back so later entries use the value as stored in M's eltype.
        mii = M[i,i]
        # Off-diagonals j > i:
        # B[i,j] = sum_{k<i} (M[k,i]'*U[k,j] + U[k,i]'*M[k,j]) + M[i,i]*U[i,j] + U[i,i]*M[i,j]
        for j = i+1:n
            aij = B[i,j]
            for k = 1:i-1
                aij -= conj(M[k,i])*U[k,j] + conj(U[k,i])*M[k,j]
            end
            M[i,j] = (aij - mii*U[i,j]) / uii
        end
    end
    return M
end
# Factorize the dual matrix A + B*du as (U + M*du)' * (U + M*du), where du is
# the dual unit (du^2 == 0).
#
# A must be symmetric positive definite and B symmetric. Expanding the
# product and matching powers of du gives
#
#     A = U'*U            (the ordinary Cholesky factorization), and
#     B = U'*M + M'*U     (solved for M by tri_ss_solve!).
#
# In the convention here, the returned U and M are upper-triangular dense
# matrices, returned as the tuple (U,M).
function chol_dual(A, B)
    # Convert U from triangular to full to allow for direct indexing and
    # strided (BLAS) access inside tri_ss_solve!.
    U = full(chol(A))
    M = zeros(size(A))
    tri_ss_solve!(M, U, B)
    return (U,M)
end
# Build an array of dual numbers elementwise from a real part A and an
# epsilon part B of the same size, so that result[I] == Dual(A[I], B[I]).
# The result has the same shape as A.
#
# Throws DimensionMismatch if A and B differ in size.
#
# NOTE(review): this adds an array method to DualNumbers.Dual for Base array
# types (type piracy); it is kept because chol_dual_test relies on it.
function DualNumbers.Dual{T<:Real}(A::AbstractArray{T}, B::AbstractArray{T})
    size(A) == size(B) ||
        throw(DimensionMismatch("real part has size $(size(A)) but epsilon part has size $(size(B))"))
    # map preserves the shape of A for any number of dimensions.
    map(Dual, A, B)
end
# Benchmark chol_dual against the builtin chol() on a random n-by-n problem,
# then report how accurately (U + M*du)'*(U + M*du) reconstructs A + B*du.
#
# The residual's real part measures the error in A = U'*U; its epsilon part
# measures the error in B = U'*M + M'*U. Prints its results and returns
# nothing.
function chol_dual_test(n = 1000)
    A = rand(n,n)
    B = rand(n,n)
    # Make A & B symmetric; normalize elements close to one.
    A = A*A' * 4/n
    B = B*B' * 4/n
    # Warm up on a tiny problem so the timings below exclude JIT compilation.
    chol(eye(2))
    chol_dual(eye(2), eye(2))
    println("Time for builtin chol():")
    @time chol(A)
    println("Time for chol_dual():")
    @time U,M = chol_dual(A,B)
    AB = Dual(A,B)
    UM = Dual(U,M)
    # Compute residual
    resid = AB - UM'*UM
    # real() extracts the error in A, epsilon() the error in B.
    maxRealError = maximum(abs(map(real, resid)))
    maxEpsilonError = maximum(abs(map(epsilon, resid)))
    println("Residual of reconstructed matrix A:")
    println("max real error = $maxRealError")
    println("max epsilon error = $maxEpsilonError")
    nothing
end
#chol_dual_test(1000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment