Skip to content

Instantly share code, notes, and snippets.

@GunnarFarneback
Created February 18, 2014 22:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save GunnarFarneback/2c2a38b3715279660bfe to your computer and use it in GitHub Desktop.
Save GunnarFarneback/2c2a38b3715279660bfe to your computer and use it in GitHub Desktop.
Implementation of the real matrix square root algorithm by Higham (1987). The code was written at the time of #2246, so parts of it look a bit old-fashioned now. For full performance, several constructions must be devectorized.
"""
    real_sqrtm(T, Q, n, cond)

Compute the principal square root of `Q*T*Q'` from its real Schur
decomposition using Higham's (1987) real-arithmetic block algorithm.

`T` is the `n×n` upper quasi-triangular Schur factor. A nonzero subdiagonal
entry `T[i+1,i]` marks a 2×2 diagonal block; all other diagonal entries are
1×1 blocks. `Q` holds the Schur vectors. The 2×2 blocks are assumed to be in
the form handled by `sqrtm2x2`.

If a 1×1 block holds a negative real eigenvalue, no real principal root
exists. The computation then falls back to `complex_sqrtm` on a complex Schur
form of `T`.

Returns `Q*R*Q'`, converted to a real matrix when its imaginary part is
exactly zero. When `cond` is true it instead returns the tuple
`(root, norm(R)^2 / norm(T))`.
"""
function real_sqrtm(T::Matrix, Q::Matrix, n::Int, cond::Bool)
    R = zeros(eltype(T), n, n)
    # Every nonzero subdiagonal entry joins two diagonal positions into a
    # single 2×2 block, reducing the block count by one.
    num_blocks = n - count(!iszero, diag(T, -1))
    # Integer arrays: their entries are used directly as matrix indices
    # below (Float64 indices are rejected by current Julia).
    blocks = zeros(Int, num_blocks)       # first row/column of each block
    block_sizes = zeros(Int, num_blocks)  # 1 or 2
    negative_real_eigenvalue_found = false
    pos = 1
    for b = 1:num_blocks
        blocks[b] = pos
        if pos < n && T[pos+1, pos] != 0
            block_sizes[b] = 2
            pos += 2
        else
            block_sizes[b] = 1
            # A negative real eigenvalue has no real principal square root.
            if T[pos, pos] < 0
                negative_real_eigenvalue_found = true
                break
            end
            pos += 1
        end
    end
    if negative_real_eigenvalue_found
        # FIXME: Convert real schur to complex directly instead of calling schur again.
        T2, Q2, _ = schur(complex(T))
        return complex_sqrtm(T2, Q * Q2, n, cond)
    end
    id2 = one(zeros(eltype(R), 2, 2))  # 2×2 identity matrix
    # Fill R one block column at a time. Within a column, walk upward so that
    # every R block referenced on the right-hand side is already known.
    for bj = 1:num_blocks
        if block_sizes[bj] == 1
            j = blocks[bj]
            R[j, j] = sqrt(T[j, j])
            for bi = bj-1:-1:1
                if block_sizes[bi] == 1
                    # 1×1 row block, 1×1 column block: (r_ii + r_jj) x = rhs.
                    i = blocks[bi]
                    s = zero(T[1])
                    for k = i+1:j-1
                        s += R[i, k] * R[k, j]
                    end
                    if T[i, j] != s
                        R[i, j] = (T[i, j] - s) / (R[i, i] + R[j, j])
                    end
                else
                    # 2×2 row block, 1×1 column block: (R_ii + r_jj I) x = rhs.
                    ir = blocks[bi]:blocks[bi]+1
                    v = T[ir, j]
                    for k = ir[end]+1:j-1
                        v -= R[ir, k] * R[k, j]
                    end
                    R[ir, j] = (R[ir, ir] + R[j, j] * id2) \ v
                end
            end
        else
            jr = blocks[bj]:blocks[bj]+1
            # Square root of 2x2 block.
            R[jr, jr] = sqrtm2x2(T[jr, jr])
            for bi = bj-1:-1:1
                if block_sizes[bi] == 1
                    # 1×1 row block, 2×2 column block: x (r_ii I + R_jj) = rhs
                    # for a row x. Row slices are returned as Vectors, so
                    # solve the transposed system for x directly.
                    i = blocks[bi]
                    w = T[i, jr]
                    for k = i+1:jr[1]-1
                        w -= R[i, k] * R[k, jr]
                    end
                    R[i, jr] = transpose(R[i, i] * id2 + R[jr, jr]) \ w
                else
                    # Two 2×2 blocks: Sylvester equation R_ii X + X R_jj = rhs,
                    # solved in Kronecker form on the column-major vec(X).
                    ir = blocks[bi]:blocks[bi]+1
                    W = T[ir, jr]
                    for k = ir[end]+1:jr[1]-1
                        # Outer product of a column piece and a row piece of R.
                        W -= R[ir, k] * transpose(R[k, jr])
                    end
                    K = kron(id2, R[ir, ir]) + kron(transpose(R[jr, jr]), id2)
                    R[ir, jr] = reshape(K \ vec(W), 2, 2)
                end
            end
        end
    end
    retmat = Q * R * Q'
    result = all(iszero, imag(retmat)) ? real(retmat) : retmat
    if cond
        alpha = norm(R)^2 / norm(T)
        return result, alpha
    else
        return result
    end
end
"""
    sqrtm2x2(R)

Principal square root of a real 2×2 matrix of the form `[a b; c a]` with
`b*c < 0`. Such a matrix has the complex-conjugate eigenvalues
`a ± i*sqrt(-b*c)`.

Only `R[1,1]`, `R[1,2]` and `R[2,1]` are read; `R[2,2]` is assumed to equal
`R[1,1]`. The result should match `real(sqrtm(R))` (`sqrt(R)` in current
Julia). The element type of `R` is preserved, e.g. `Float32` in gives
`Float32` out.
"""
function sqrtm2x2(R::AbstractMatrix)
    a = R[1,1]
    b = R[1,2]
    c = R[2,1]
    # Real part of the principal root of the eigenvalue λ = a + i*sqrt(-b*c):
    # sqrt((Re λ + |λ|) / 2) with |λ| = sqrt(a^2 - b*c). It is positive when
    # b*c < 0. Dividing by 2 (rather than multiplying by 0.5) keeps eltype.
    A = sqrt((a + sqrt(a * a - b * c)) / 2)
    B = b / (2 * A)
    C = c / (2 * A)
    return [A B; C A]
end
"""
    complex_sqrtm(T, Q, n, cond)

Principal square root of `Q*T*Q'`, given an `n×n` upper triangular Schur
factor `T` and the matching Schur vectors `Q`.

The upper triangular root `R` (with `R^2 == T`) is built column by column:

    R[i,i] = sqrt(T[i,i])
    R[i,j] = (T[i,j] - sum(R[i,k]*R[k,j] for k in i+1:j-1)) / (R[i,i] + R[j,j])

Returns `Q*R*Q'`, converted to a real matrix when its imaginary part is
identically zero. When `cond` is true it instead returns the tuple
`(root, norm(R)^2 / norm(T))`.
"""
function complex_sqrtm(T::Matrix, Q::Matrix, n::Int, cond::Bool)
    R = zeros(eltype(T), n, n)
    for col in 1:n
        R[col, col] = sqrt(T[col, col])
        # Move upward so every R[k, col] with k > row is already available.
        for row in col-1:-1:1
            acc = zero(T[1])
            for k in row+1:col-1
                acc += R[row, k] * R[k, col]
            end
            # Nothing left to solve for: the entry stays exactly zero.
            T[row, col] == acc && continue
            R[row, col] = (T[row, col] - acc) / (R[row, row] + R[col, col])
        end
    end
    full_root = Q * R * Q'
    root = all(iszero, imag(full_root)) ? real(full_root) : full_root
    return cond ? (root, norm(R)^2 / norm(T)) : root
end
"""
    rsqrtm(A::StridedMatrix, cond::Bool)
    rsqrtm(A::StridedMatrix)

Principal square root of the square matrix `A`, computed from a Schur
decomposition.

- A real element type uses the real-arithmetic block algorithm
  (`real_sqrtm`) on the real Schur form.
- A complex element type uses the triangular recurrence (`complex_sqrtm`).

With `cond = true` (default `false`) the result is whatever tuple those
helpers return. Raises an error (`error("Dimension Mismatch")`) if `A` is
not square.
"""
function rsqrtm(A::StridedMatrix, cond::Bool)
    m, n = size(A)
    if m != n
        error("Dimension Mismatch")
    end
    T, Q, _ = schur(A)
    # Dispatch on the element type, not on the values. A complex matrix whose
    # entries happen to be real still has a complex triangular Schur factor,
    # which real_sqrtm cannot process (it compares diagonal entries with `<`).
    if eltype(T) <: Real
        return real_sqrtm(T, Q, n, cond)
    else
        return complex_sqrtm(T, Q, n, cond)
    end
end
rsqrtm(A::StridedMatrix) = rsqrtm(A, false)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment