-
-
Save GunnarFarneback/2c2a38b3715279660bfe to your computer and use it in GitHub Desktop.
Implementation of the real matrix square root algorithm by Higham (1987). The code was written at the time of #2246, so parts of it look a bit old-fashioned now. For full performance, several constructions must be devectorized.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    real_sqrtm(T, Q, n, cond)

Square root of `Q*T*Q'` by Higham's (1987) real Schur method.

`T` is expected to be an `n`×`n` real quasi-upper-triangular Schur factor (1×1 and 2×2
diagonal blocks; a nonzero subdiagonal entry `T[i+1,i]` starts a 2×2 block) and `Q` the
matching orthogonal factor, as produced by `schur`. An upper quasi-triangular `R` with
`R*R ≈ T` is built one block column at a time. Each diagonal block is square-rooted
directly (`sqrt` for 1×1 blocks, `sqrtm2x2` for 2×2 blocks). Each off-diagonal block
`X = R[i,j]` then solves

    R[i,i]*X + X*R[j,j] = T[i,j] - Σ_k R[i,k]*R[k,j]

where `k` runs over the indices strictly between blocks `i` and `j`.

If a 1×1 diagonal block is negative (a negative real eigenvalue), the computation is
redone in complex arithmetic with `complex_sqrtm`.

Returns `Q*R*Q'` (as a real matrix when every entry has zero imaginary part). When
`cond` is true, returns the tuple `(X, alpha)` with `alpha = norm(R)^2 / norm(T)`.
"""
function real_sqrtm(T::Matrix, Q::Matrix, n::Int, cond::Bool)
    # Locate the diagonal blocks: `starts[b]` is the first index of block b and
    # `sizes[b]` its order (1 or 2). Integer storage: these are used as indices.
    starts = Int[]
    sizes = Int[]
    p = 1
    while p <= n
        if p < n && T[p+1, p] != 0
            push!(starts, p)
            push!(sizes, 2)
            p += 2
        elseif T[p, p] < 0
            # A negative real eigenvalue: fall back to complex arithmetic.
            # FIXME: Convert real schur to complex directly instead of calling schur again.
            T2, Q2, _ = schur(complex(T))
            return complex_sqrtm(T2, Q * Q2, n, cond)
        else
            push!(starts, p)
            push!(sizes, 1)
            p += 1
        end
    end

    # Float element type so square roots of integer-typed input are not truncated.
    R = zeros(float(eltype(T)), n, n)
    id2 = Matrix{eltype(R)}(I, 2, 2)
    for jb in eachindex(starts)
        if sizes[jb] == 1
            j = starts[jb]
            R[j, j] = sqrt(T[j, j])
            for ib in jb-1:-1:1
                if sizes[ib] == 1
                    i = starts[ib]
                    r = zero(eltype(R))
                    for k in i+1:j-1
                        r += R[i, k] * R[k, j]
                    end
                    # A zero right-hand side leaves R[i,j] == 0; skipping it also
                    # avoids 0/0 when both diagonal roots are zero.
                    if T[i, j] != r
                        R[i, j] = (T[i, j] - r) / (R[i, i] + R[j, j])
                    end
                else
                    i = starts[ib]:starts[ib]+1
                    r = T[i, j]                       # 2×1 column block, as a vector
                    for k in i[end]+1:j-1
                        r -= R[i, k] * R[k, j]
                    end
                    # (R[i,i] + R[j,j]*I) * X = r
                    R[i, j] = (R[i, i] + R[j, j] * I) \ r
                end
            end
        else
            j = starts[jb]:starts[jb]+1
            R[j, j] = sqrtm2x2(T[j, j])
            for ib in jb-1:-1:1
                if sizes[ib] == 1
                    i = starts[ib]
                    r = T[i, j]                       # 1×2 row block, as a vector
                    for k in i+1:j[1]-1
                        r -= R[i, k] * R[k, j]
                    end
                    # X * (R[i,i]*I + R[j,j]) = r, solved through its transpose.
                    # `r` is a plain vector here, so it is used without transposing.
                    R[i, j] = (R[i, i] * I + R[j, j])' \ r
                else
                    i = starts[ib]:starts[ib]+1
                    r = T[i, j]                       # 2×2 block
                    for k in i[end]+1:j[1]-1
                        # Slicing with k:k keeps both factors 2-D, so the product is
                        # the 2×2 outer product (R[i,k] and R[k,j] alone are vectors).
                        r -= R[i, k:k] * R[k:k, j]
                    end
                    # Sylvester equation via vec(A*X + X*B) = (I ⊗ A + Bᵀ ⊗ I) * vec(X).
                    A = kron(id2, R[i, i]) + kron(R[j, j]', id2)
                    R[i, j] = reshape(A \ vec(r), 2, 2)
                end
            end
        end
    end

    X = Q * R * Q'
    X = isreal(X) ? real(X) : X
    if cond
        return X, norm(R)^2 / norm(T)
    else
        return X
    end
end
# Specialized square root for real 2x2 matrices of the form [a b; c a] with b*c < 0.
# Such a block has eigenvalues a ± i*sqrt(-b*c); the root returned has the same shape
# [A B; C A] with A > 0, satisfying A^2 + B*C = a, 2*A*B = b and 2*A*C = c.
# Should give the same result as real(sqrt(R)) (formerly real(sqrtm(R))).
function sqrtm2x2(R::Matrix)
    a = R[1, 1]
    b = R[1, 2]
    c = R[2, 1]
    # Modulus of the eigenvalues; s > |a| because b*c < 0.
    s = sqrt(a * a - b * c)
    # A^2 = (a + s)/2. For a < 0 that sum cancels catastrophically (a = -1e8, b*c = -1
    # gives a + s == 0, hence A == 0 and infinite B, C), so use the equivalent
    # a + s = -b*c / (s - a), whose terms do not cancel.
    A2 = a >= 0 ? 0.5 * (a + s) : -(b * c) / (2 * (s - a))
    A = sqrt(A2)
    B = b / (2 * A)
    C = c / (2 * A)
    return [A B; C A]
end
"""
    complex_sqrtm(T, Q, n, cond)

Square root of `Q*T*Q'` from a complex Schur factorization.

`T` is expected to be the `n`×`n` upper-triangular Schur factor and `Q` the matching
unitary factor. The upper-triangular `R` with `R*R ≈ T` is filled column by column.
First `R[j,j] = sqrt(T[j,j])`. Then each entry above it, from bottom to top, is
`R[i,j] = (T[i,j] - Σ_k R[i,k]*R[k,j]) / (R[i,i] + R[j,j])`. An entry whose
right-hand side is exactly zero is left at zero.

Returns `Q*R*Q'` (as a real matrix when every imaginary part is zero). When `cond` is
true, returns the tuple `(X, norm(R)^2 / norm(T))`.
"""
function complex_sqrtm(T::Matrix, Q::Matrix, n::Int, cond::Bool)
    R = zeros(eltype(T), n, n)
    for col in 1:n
        R[col, col] = sqrt(T[col, col])
        for row in col-1:-1:1
            acc = zero(first(T))
            for k in row+1:col-1
                acc += R[row, k] * R[k, col]
            end
            # Nothing to solve when the right-hand side vanishes.
            T[row, col] == acc && continue
            R[row, col] = (T[row, col] - acc) / (R[row, row] + R[col, col])
        end
    end
    X = Q * R * Q'
    Y = all(imag(X) .== 0) ? real(X) : X
    return cond ? (Y, norm(R)^2 / norm(T)) : Y
end
"""
    rsqrtm(A::StridedMatrix, cond::Bool)

Square root of the square matrix `A` via a Schur decomposition (Higham 1987).

Real-valued input is handled with the real Schur form and `real_sqrtm`. This includes
complex-typed input whose imaginary parts are all zero. Any other input is reduced to
complex Schur form and passed to `complex_sqrtm`. When `cond` is true, the result is
the tuple `(X, alpha)` returned by those functions.

Throws an error ("Dimension Mismatch") if `A` is not square.
"""
function rsqrtm(A::StridedMatrix, cond::Bool)
    m, n = size(A)
    m == n || error("Dimension Mismatch")
    if isreal(A)
        # Factor real(A), not A: a complex-typed A with zero imaginary parts would
        # otherwise yield a complex triangular T, on which real_sqrtm's `T[i,i] < 0`
        # test throws a MethodError.
        T, Q, _ = schur(real(A))
        return real_sqrtm(T, Q, n, cond)
    else
        T, Q, _ = schur(A)
        return complex_sqrtm(T, Q, n, cond)
    end
end
"""
    rsqrtm(A::StridedMatrix)

Square root of `A` without the `alpha` estimate; equivalent to `rsqrtm(A, false)`.
"""
function rsqrtm(A::StridedMatrix)
    return rsqrtm(A, false)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment