Skip to content

Instantly share code, notes, and snippets.

@jverzani
Last active March 7, 2017 03:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jverzani/035114b20c011dfc23ea31d95e788fba to your computer and use it in GitHub Desktop.
Save jverzani/035114b20c011dfc23ea31d95e788fba to your computer and use it in GitHub Desktop.
Julia implementation of AMVW algorithm
module AMVW
## Julia implementation of
## Fast and backward stable computation of roots of polynomials
## https://lirias.kuleuven.be/bitstream/123456789/461961/1/TW654.pdf
## Derived from fortran code https://people.cs.kuleuven.be/~raf.vandebril/homepage/software/companion_qr.php?menu=5
## License is unclear, but hopefully can be MIT licensed
## TODO
## handle the case of non-convergence
## API work
## Get faster! It seems to be 6x slower -- or more -- than just roots(p), and
## about as accurate
## did:
## * check norm in vals! so that rotators have norm \approx 1
## * implement speed up for C_i = B_i in initial bulge chasing
## Utils
## take poly [p0, p1, ..., pn] and return
## [q_m-1, q_m-2, ..., q0], k
## where we trim off k roots at 0, then make p monic, then reverse
"""
    reverse_poly(ps)

Normalize the coefficient vector `ps = [p0, p1, ..., pn]` (lowest degree first).

Trailing zeros of `ps` (vanishing leading coefficients) are dropped. The `k`
lowest-order zero coefficients (i.e. `k` roots at 0) are removed. The rest is
made monic and reversed. Returns `(qs, k)`, where `qs = [q_{m-1}, ..., q_0]`
holds the non-leading coefficients of `x^m + q_{m-1} x^{m-1} + ... + q_0`.

If every coefficient is zero (or `ps` is empty), returns `(zeros(T, 0), length(ps))`.
"""
function reverse_poly(ps::Vector{T}) where {T}
    isnonzero = p -> p != zero(T)
    # On Julia 0.6 `findlast`/`findfirst` return 0 when nothing matches; on
    # Julia >= 0.7 they return `nothing`. Accept both so the guard always works.
    N = findlast(isnonzero, ps)
    (N === nothing || N == 0) && return (zeros(T, 0), length(ps))
    ps = ps[1:N]
    ## first nonzero coefficient: K - 1 roots at zero are factored out
    K = findfirst(isnonzero, ps)
    ps = ps[K:end]
    qs = reverse(ps ./ ps[end])[2:end]
    return qs, K - 1
end
using Compat
## Types
# Supertype of the 2x2 "core transformations" the factorization is built from.
@compat abstract type CoreTransform{T} end
# Core transformations that are plane rotations [c -s; s c]
# (RealRotator is the only concrete subtype in this file).
@compat abstract type Rotator{T} <: CoreTransform{T} end
# Adjoint (= inverse) of a real rotator: [c -s; s c]' == [c s; -s c], i.e. the
# rotator with values (c, -s) at the same index.
# The index vector is copied, not shared: otherwise a later `idx!` on the
# adjoint would silently move the original rotator too (and vice versa).
function Base.ctranspose(r::Rotator)
    c, s = r.xs
    return RealRotator([c, -s], copy(r.i))
end
# A real rotator [c -s; s c] acting on rows/columns i and i+1.
#   xs = [c, s]  values, mutated in place by vals!
#   i  = [i]     one-element index vector, mutated in place by idx!
# The index is superfluous for now, and storing it in a Vector (so that the
# struct itself can stay immutable) is a bit of a hassle, but it might help
# later if twisting is approached. It shouldn't affect speed, but it does
# mean 9N storage, not 6N.
struct RealRotator{T} <: Rotator{T}
    xs::Vector{T}
    i::Vector{Int}
end
# Identity rotator: (c, s) = (1, 0), placeholder index 0.
Base.one(::Type{RealRotator{T}}) where {T} = RealRotator([one(T), zero(T)], zeros(Int, 1))
# N independent identity rotators; every element has its own storage.
# `N` is typed `Integer` so this method is strictly more specific than Base's
# `ones(::Type{T}, dims::Integer...)` instead of ambiguous with it.
Base.ones(S::Type{RealRotator{T}}, N::Integer) where {T} = RealRotator{T}[one(S) for i in 1:N]
## get/set values
# The rotator's value vector [c, s] itself (not a copy), so writes are visible.
function vals(r::RealRotator)
    return r.xs
end
# Store (c, s) in the rotator. To limit round-off drift, the pair is rescaled
# by 1/sqrt(c^2 + s^2) only when c^2 + s^2 differs from 1 by at least
# 100*eps(T); otherwise the square root is skipped (heuristic, paper Sec. 6.3).
# Returns the stored (possibly rescaled) s.
function vals!(r::RealRotator, c::T, s::T) where {T}
    sq = c^2 + s^2
    scale = abs(sq - one(T)) >= 1e2 * eps(T) ? inv(sqrt(sq)) : one(T)
    r.xs[1] = c * scale
    return r.xs[2] = s * scale
end
# Index i of the rotator; it acts on rows/columns i and i+1.
function idx(r::RealRotator)
    return r.i[1]
end
# Move the rotator to index i (in place); returns i.
function idx!(r::RealRotator, i::Int)
    r.i[1] = i
end
# Independent copy: both the value vector and the index vector are duplicated,
# so vals!/idx! on the copy never alter `a`.
Base.copy(a::RealRotator) = RealRotator(copy(a.xs), copy(a.i))
# Overwrite `a` in place with the values and index of `b`; returns the new index.
function Base.copy!(a::RealRotator, b::RealRotator)
    cb, sb = vals(b)
    vals!(a, cb, sb)
    return idx!(a, idx(b))
end
# General (non-rotation) core transform: a 2x2 matrix [a b; c d] at index i.
# Not referenced by the double-shift code in this file.
mutable struct RealTransform{T} <: CoreTransform{T}
    xs::Vector{T} # [a b; c d] -- NOTE(review): flattening order not stated; confirm
    i::Int
end
# Transpose (the adjoint for real entries): swaps xs[2] and xs[3], the two
# off-diagonal entries under either row- or column-major flattening.
Base.ctranspose(r::RealTransform) = RealTransform(r.xs[[1,3,2,4]], r.i)
using Compat
# Mutable bookkeeping for the active window of the iteration.
mutable struct DoubleShiftCounter
    zero_index::Int  # index of the most recently deflated ("P") rotator in Q; 0 if none
    start_index::Int # first index of the active window (set to zero_index + 1 by deflate)
    stop_index::Int  # last index of the active window; the main loop ends when it is <= 0
    it_count::Int    # iteration counter; 0 makes create_bulge use a random shift
end
# Supertype for shift strategies; only the real double shift is implemented here.
@compat abstract type ShiftType{T} end
# Complete state of the real double-shift iteration for a degree-N monic
# polynomial. The matrix is kept factored as A = Q * C' * (B + e1*y')
# (see `full`), with each of Q, C', B stored as N rotators.
struct RealDoubleShift{T} <: ShiftType{T}
    N::Int # degree of the (monic) polynomial
    POLY::Vector{T} # [q_{N-1}, ..., q_0], as produced by reverse_poly
    Q::Vector{RealRotator{T}}
    Ct::Vector{RealRotator{T}} # We use C', not C here
    B::Vector{RealRotator{T}}
    REIGS::Vector{T} # real parts of the computed eigenvalues (roots)
    IEIGS::Vector{T} # imaginary parts
    ## reusable storage
    U::RealRotator{T} # bulge rotators built in create_bulge
    V::RealRotator{T}
    W::RealRotator{T} # misfit kept on the left while chasing
    A::Matrix{T} # for parts of A = QR
    Bk::Matrix{T} # for diagonal block
    R::Matrix{T} # temp storage, sometimes R part of QR
    e1::Vector{T} # eigenvalues e1, e2 of the current 2x2 block, each stored as [re, im]
    e2::Vector{T}
    ctrs::DoubleShiftCounter
end
"""
    RealDoubleShift(ps::Vector{T})

Allocate solver state for the monic polynomial whose non-leading coefficients
are `ps = [q_{N-1}, ..., q_0]` (as returned by `reverse_poly`). All rotators
start as the identity; `init_state` fills in the factorization.

Defined as an explicit constructor: the implicit constructor-to-`convert`
fallback this code used to rely on no longer exists in Julia >= 0.7.
"""
function RealDoubleShift(ps::Vector{T}) where {T}
    N = length(ps)
    RealDoubleShift(N, ps,
                    ones(RealRotator{T}, N),  # Q
                    ones(RealRotator{T}, N),  # Ct
                    ones(RealRotator{T}, N),  # B
                    zeros(T, N), zeros(T, N), # REIGS, IEIGS
                    one(RealRotator{T}), one(RealRotator{T}), one(RealRotator{T}), # U V W
                    zeros(T, 3, 2), zeros(T, 3, 2), zeros(T, 3, 2), # A Bk R
                    zeros(T, 2), zeros(T, 2), # e1, e2
                    DoubleShiftCounter(0, 1, N - 1, 0))
end

# Kept for backward compatibility: convert(RealDoubleShift, ps) builds the same state.
Base.convert(::Type{RealDoubleShift}, ps::Vector) = RealDoubleShift(ps)
## Eigenvalues of the 2x2 matrix state.A[1:2, 1:2], computed by hand so that
## non-LAPACK element types (e.g. BigFloat) work. The results are written to
## state.e1 and state.e2, each as [real part, imaginary part].
function Base.eigvals(state::RealDoubleShift{T}) where {T}
    A = state.A
    b = A[1,1] + A[2,2]                   # trace(A)
    c = A[1,1] * A[2,2] - A[1,2] * A[2,1] # det(A)
    discr = b^2 - 4 * c
    if discr < 0
        # complex-conjugate pair (b ± i*sqrt(-discr)) / 2
        state.e1[1], state.e1[2] = b / 2, sqrt(-discr) / 2
        state.e2[1], state.e2[2] = state.e1[1], -state.e1[2]
    else
        state.e1[2], state.e2[2] = zero(T), zero(T) # real
        sdiscr = sqrt(discr)
        u, v = b + sdiscr, b - sdiscr # twice the two eigenvalues
        if iszero(u) && iszero(v)
            # both eigenvalues are zero; avoid 0/0 below
            u, v = zero(T), zero(T)
        elseif abs(u) > abs(v)
            # Take the larger-magnitude eigenvalue directly and get the other
            # from lambda1 * lambda2 = det, avoiding cancellation. This also
            # handles a single zero eigenvalue correctly (c == 0 then).
            u = u / 2
            v = c / u
        else
            v = v / 2
            u = c / v
        end
        state.e1[1], state.e2[1] = u, v
    end
end
##################################################
#
#
## Diagnostic code
# Diagnostic: dense N x N identity with the block [c -s; s c] of rotator `a`
# placed in rows/columns i:i+1, where i = idx(a). Errors when i >= N.
function as_full(a::RealRotator{T}, N::Int) where {T}
    i = idx(a)
    i < N || error("too big")
    c, s = vals(a)
    M = eye(T, N)
    M[i:i+1, i:i+1] = [c -s; s c]
    return M
end
# Set, in place, every entry of A whose magnitude is <= tol to zero; returns zero(T).
function zero_out!(A::Array{T}, tol=1e-12) where {T}
    for j in eachindex(A)
        if abs(A[j]) <= tol
            A[j] = zero(T)
        end
    end
    return zero(T)
end
## diagnostic
## Assemble the dense (N+1)x(N+1) matrix represented by `state` (for
## diagnostic purposes only). With A = Q * C' * (B + e1*y'):
##   what == :W  returns W = B + e1*y'
##   what == :R  returns R = C' * W
##   otherwise   returns A            (default what = :A)
## y' is chosen so that the last row of C' * (B + e1*y') vanishes. Entries of
## y' below 1e-12 are dropped, and entries of the result with magnitude
## <= 1e-12 are zeroed.
function Base.full(state::RealDoubleShift{T}, what=:A) where {T}
    N = state.N
    Q = as_full(state.Q[1], N+1)
    for i in 2:N
        Q = Q * as_full(state.Q[i], N+1)
    end
    Ct = as_full(state.Ct[1], N+1)
    for i in 2:N
        Ct = as_full(state.Ct[i], N+1) * Ct
    end
    B = as_full(state.B[1], N+1)
    for i in 2:N
        B = B * as_full(state.B[i], N+1)
    end
    e1 = zeros(T, N+1); e1[1] = one(T)
    en1 = zeros(T, N+1); en1[N+1] = one(T)
    # en1' * Ct * (B + e1*yt') == 0  <=>  yt' = -(en1' * Ct * B) / (en1' * Ct * e1)
    rho = en1' * Ct * e1
    yt = vec(-1/rho * en1' * Ct * B)
    yt[abs.(yt) .< 1e-12] = 0 # tidy
    W = (B + e1 * yt')
    zero_out!(W)
    what == :W && return W
    R = Ct * W
    zero_out!(R)
    what == :R && return R
    A = Q * Ct * (B + e1 * yt')
    zero_out!(A)
    return A
end
# Diagnostic: print a one-line picture of the iteration's progress, with one
# character per position 0..N+1: "α" at zero_index, "x" at start_index and
# "Δ" just past stop_index. Later marks overwrite earlier ones.
function show_status(state::RealDoubleShift{T}) where {T}
    ctr = state.ctrs
    marks = fill(".", state.N + 2)
    for (pos, ch) in ((ctr.zero_index + 1, "α"),
                      (ctr.start_index + 1, "x"),
                      (ctr.stop_index + 2, "Δ"))
        marks[pos] = ch
    end
    println(join(marks, ""))
end
##
# The 2x2 matrix [a -b; b a] (a rotation when a^2 + b^2 == 1).
function rotm(a, b)
    return [a -b; b a]
end
##
##################################################
# rotations; find values
# Real Givens
#
# Compute (c, s, r) such that the rotator G = [c -s; s c] satisfies
#
#     G * [a, b] = [r, 0],   c^2 + s^2 = 1,   r = sqrt(a^2 + b^2) >= 0,
#
# i.e. c = a/r and s = -b/r. When a == b == 0 the identity (1, 0, 0) is
# returned. The `donorm` argument is accepted for compatibility and ignored.
#
# XXX seems faster to just return r than not
function givensrot(a::T, b::T, donorm=Val{false}) where {T <: Real}
    if iszero(b)
        # also covers a == 0: return the identity rotation instead of (0, 0)
        iszero(a) && return (one(T), zero(T), zero(T))
        return (sign(a) * one(T), zero(T), abs(a))
    end
    # a == 0: s = -b/r = -sign(b), the same sign convention as the general case
    iszero(a) && return (zero(T), -sign(b) * one(T), abs(b))
    r = hypot(a, b)
    c, s = a / r, b / r
    return (c, -s, r)
end
#### Operations on [,[ terms
## The zero_index and stop_index+1 point at "P" matrices: RealRotator(p, 0)
## with p^2 = 1, i.e. p = +1 or -1 (better name for these matrices?).
##
## pflip handles moving P_i * R_{i+1} -> R'_{i+1} * P_i, with
## R'_{i+1} = P_i * R_{i+1} * P_i. In terms of values this is just
## (c, s) -> (c, sign(p) * s), applied to `a` in place.
function pflip(a::RealRotator{T}, p=one(T)) where {T}
    c, s = vals(a)
    return vals!(a, c, sign(p) * s)
end
# Return the sign p of a "P" rotator RR(1, 0) or RR(-1, 0).
# Errors unless the sine component is within 4*eps(T) of zero.
function getp(a::RealRotator{T}) where {T}
    c, s = vals(a)
    if !(abs(s) <= 4eps(T))
        error("a is not a 'P' matrix")
    end
    return sign(c)
end
## fuse: replace the product a*b of two rotators at the same index by one
## rotator. With dir == Val{:left} the product is stored in `a`; otherwise
## (default Val{:right}) it is stored in `b`. Errors if the indices differ.
function fuse(a::RealRotator{T}, b::RealRotator{T}, dir=Val{:right}) where {T}
    idx(a) == idx(b) || error("can't fuse")
    ca, sa = vals(a)
    cb, sb = vals(b)
    # angle addition: cos/sin of the summed angle
    cnew = ca * cb - sa * sb
    snew = ca * sb + sa * cb
    target = dir == Val{:left} ? a : b
    return vals!(target, cnew, snew)
end
# Turnover: rewrite a product of three rotators whose indices follow an
# up-down-up or down-up-down pattern (Q1 and Q3 at index i, Q2 at index
# j = i +/- 1) as an equivalent product with the opposite pattern. Example for
# j = i + 1 and misfit = Val{:right}:
#
#     Q1    Q3            Q1
#        Q2       =    Q3    Q2
#
# All three rotators are updated in place. `misfit` selects which argument
# ends up moved to index j:
#   Val{:left}  (-->, misfit starts on the left):  Q1 is moved to index j
#   Val{:right} (<--, misfit starts on the right): Q3 is moved to index j (default)
# Errors if the index pattern is not as described.
function turnover{T}(Q1::RealRotator{T}, Q2::RealRotator{T}, Q3::RealRotator{T}, misfit=Val{:right})
    i,j,k = idx(Q1), idx(Q2), idx(Q3)
    (i == k) || error("Need to have a turnover up down up or down up down: have i=$i, j=$j, k=$k")
    abs(j-i) == 1 || error("Need to have |i-j| == 1")
    c1,s1 = vals(Q1)
    c2,s2 = vals(Q2)
    c3,s3 = vals(Q3)
    # initialize c4 and s4
    a = c1*c2*s3 + s1*c3
    b = s2*s3
    # check norm([a,b]) \approx 1
    c4, s4, nrm = givensrot(a,b)#, Val{true})
    # initialize c5 and s5 (rotate away the norm of the previous pair)
    a = c1*c3 - s1*c2*s3
    b = nrm
    # check norm([a,b]) \approx 1
    c5, s5, tmp = givensrot(a,b)
    # second column
    u = -c1*s3 - s1*c2*c3
    v = c1*c2*c3 - s1*s3
    w = s2 * c3
    a = c4*c5*v - s4*c5*w + s5*u
    b = c4*w + s4*v
    c6, s6, tmp = givensrot(a,b)
    ## Val{:left}: move --> (misfit starts on left, Q1 shifts to j);
    ## Val{:right}: move <-- (misfit starts on right, Q3 shifts to j)
    if misfit == Val{:left}
        vals!(Q2, c4, -s4)
        vals!(Q3, c5, -s5)
        vals!(Q1, c6, -s6)
        idx!(Q1, j) # misfit gets shifted
    else
        vals!(Q3, c4, -s4)
        vals!(Q1, c5, -s5)
        vals!(Q2, c6, -s6)
        idx!(Q3, j) # misfit
    end
end
### Related to the decomposition of QR into QC(B + ...), i.e. R = C(B + e1 y^t)
## fill A[k:k+2, k:k+1] k in 2:N
## updates A
##
# We look for r_j,k. Depending on |j-k| there are different amounts of work
# we have wk = (B + e1 y^t) * ek = B*ek + e1 yk; we consider B * ek only B1 ... Bk ek applies
#
# julia> @vars bk1 bk2 bj1 bj2 bi1 bi2
# julia> rotm(bi1, bi2, 1, 4) * rotm(bj1, bj2, 2, 4) * rotm(bk1, bk2, 3, 4) * [0, 0, 1, 0] # B_{k-2} * B_{k-1} * B_k * ek = W
# 4-element Array{SymPy.Sym,1}
# ⎡bi₂⋅bj₂⋅bk₁ ⎤
# ⎢ ⎥
# ⎢-bi₁⋅bj₂⋅bk₁⎥
# ⎢ ⎥
# ⎢ bj₁⋅bk₁ ⎥
# ⎢ ⎥
# ⎣ bk₂ ⎦
# which gives W = [what_{k-2} w_{k-1} w_k w_{k+1}]
# For rkk, we have Ck * W = [rkk, 0]
# @vars ck1 ck2 what w1
# u = rotm(ck1, ck2, 1,2) * [what, w1]
# u[1](what => solve(u[2], what)[1]) |> simplify
# ⎛ 2 2⎞
# -w₁⋅⎝c₁ + c₂ ⎠
# ──────────────── # this is rkk = -w1/c2 = -bk2/ck2
# c₂
#
# For r[k-1, k] we need to do more work. We need [what_{k-1}, w_k, w_{k+1}], where w_k, w_{k+1} found from B values as above.
#
# julia> @vars ck1 ck2 cj1 cj2 what w w1
# (ck1, ck2, cj1, cj2, what, w, w1)
# julia> u = rotm(ck1, ck2, 2, 3) * rotm(cj1, cj2, 1, 3) * [what, w, w1] # C^*_{k} * C^*{k-1} * W = [r_{k-1,k}, r_{k,k}, 0]
# 3-element Array{SymPy.Sym,1}
# ⎡ cj₁⋅ŵ - cj₂⋅w ⎤
# ⎢ ⎥
# ⎢cj₁⋅ck₁⋅w + cj₂⋅ck₁⋅ŵ - ck₂⋅w₁ ⎥
# ⎢ ⎥
# ⎣cj₁⋅ck₂⋅w + cj₂⋅ck₂⋅ŵ + ck₁⋅w₁ ⎦
# julia> u[1](what => solve(u[3], what)[1]) |> simplify
# 2
# cj₁ ⋅w cj₁⋅ck₁⋅w₁
# - ────── - ────────── - cj₂⋅w
# cj₂ cj₂⋅ck₂
#
# For r_{k-2,k} we need to reach back one more step
# C^*_{k} * C^*{k-1} * C^*_{k-2} W = [r_{k-2,k} r_{k-1,k}, r_{k,k}, 0]
#
# julia> @vars ck1 ck2 cj1 cj2 ci1 ci2 what wm1 w w1
# julia> u = rotm(ck1, ck2, 3, 4) * rotm(cj1, cj2, 2, 4) * rotm(ci1, ci2, 1, 4) * [what, wm1, w, w1]
# julia> u[1](what => solve(u[4], what)[1]) |> simplify
# 2
# ci₁ ⋅wm₁ ci₁⋅cj₁⋅w ci₁⋅ck₁⋅w₁
# - ──────── - ───────── - ─────────── - ci₂⋅wm₁
# ci₂ ci₂⋅cj₂ ci₂⋅cj₂⋅ck₂
# Compute the 2x2 diagonal block A[k-1:k, k-1:k] of A = Q*R, with
# R = C' * (B + e1*y') upper triangular, and store it in state.A[1:2, 1:2].
# The needed entries of R come from the closed-form expressions derived in
# the comments above (they use only the rotator values of Ct and B) and are
# left in the scratch matrix state.R:
#   k == 2: state.R[1:2, 1:2] holds r_{1:2, 1:2}
#   k >  2: state.R[1:3, 1:2] holds r_{k-2:k, k-1:k}
# Errors unless 2 <= k <= state.N.
function diagonal_block(state::RealDoubleShift{T}, k) where {T}
    k >= 2 && k <= state.N || error("$k not in [2,n]")
    A = state.A
    R = state.R
    if k == 2
        # here we only need [r11 r12; 0 r22], so only use top part of R
        # diagonal entries r_jj = -w1/c2
        for j in 1:2
            ck1, ck2 = vals(state.Ct[k - (2-j)])
            w1 = vals(state.B[k - (2-j)])[2]
            R[j,j] = - w1 / ck2
        end
        # superdiagonal entry r_12
        cj1, cj2 = vals(state.Ct[k-1])
        ck1, ck2 = vals(state.Ct[k])
        w = vals(state.B[k-1])[1] * vals(state.B[k])[1]
        w1 = vals(state.B[k])[2]
        val = -(cj1^2*w)/cj2 - (cj1 * ck1 * w1) / (cj2 * ck2) - cj2 * w
        R[1,2] = val
        # combine with the first two rotators of Q
        q11, q12 = vals(state.Q[k-1]); q21, q22 = vals(state.Q[k])
        A[1,1] = q11 * R[1,1]
        A[1,2] = q11 * R[1,2] - q12 * q21 * R[2,2]
        A[2,1] = q12 * R[1,1]
        A[2,2] = q11 * q21 * R[2,2] + q12 * R[1,2]
    else ## Need condition on N, as Bn is Bn*Zn
        ## R = R[k-2:k, k-1:k]
        Qk_2, Qk_1, Qk = state.Q[k-2].xs, state.Q[k-1].xs, state.Q[k].xs
        # r_kk: R[2,1] = r_{k-1,k-1}, R[3,2] = r_{k,k}
        for j in 1:2
            K = k - (2-j)
            ck1, ck2 = vals(state.Ct[K])
            w1 = vals(state.B[K])[2]
            R[j+1,j] = - w1 / ck2
        end
        # r_{K-1,K}: R[1,1] = r_{k-2,k-1}, R[2,2] = r_{k-1,k}
        for j in 1:2
            K = k - (2-j)
            cj1, cj2 = vals(state.Ct[K-1])
            ck1, ck2 = vals(state.Ct[K])
            w = vals(state.B[K-1])[1] * vals(state.B[K])[1]
            w1 = vals(state.B[K])[2]
            val = -(cj1^2*w)/cj2 - (cj1 * ck1 * w1) / (cj2 * ck2) - cj2 * w
            R[j,j] = val
        end
        # last is R[1,2] = r_{k-2,k}
        wm1 = -vals(state.B[k-2])[1] * vals(state.B[k-1])[2] * vals(state.B[k])[1]
        w = vals(state.B[k-1])[1] * vals(state.B[k])[1]
        w1 = vals(state.B[k])[2]
        ci1, ci2 = vals(state.Ct[k-2])
        cj1, cj2 = vals(state.Ct[k-1])
        ck1, ck2 = vals(state.Ct[k])
        R[1,2] = -(ci1^2 * wm1/ci2) - (ci1 * cj1 * w) / (ci2 * cj2) - (ci1 * ck1 * w1) / (ci2 * cj2 * ck2) - ci2*wm1
        # A is Q*R, but not all Q contribute
        # This is for k = 5
        # julia> A[4,4]
        # q₃ ₁⋅q₄ ₁⋅r₄₄ + q₃ ₂⋅r₃₄
        # julia> A[4,5]
        # q₃ ₁⋅q₄ ₁⋅r₄₅ - q₃ ₁⋅q₄ ₂⋅q₅ ₁⋅r₅₅ + q₃ ₂⋅r₃₅
        # julia> A[5,4]
        # q₄ ₂⋅r₄₄
        # julia> A[5,5]
        # q₄ ₁⋅q₅ ₁⋅r₅₅ + q₄ ₂⋅r₄₅
        A[1,1] = Qk_2[1] * Qk_1[1] * R[2,1] + Qk_2[2] * R[1,1]
        A[1,2] = Qk_2[1] * Qk_1[1] * R[2,2] - Qk_2[1] * Qk_1[2] * Qk[1] * R[3,2] + Qk_2[2] * R[1,2]
        A[2,1] = Qk_1[2] * R[2,1]
        A[2,2] = Qk_1[1] * Qk[1] * R[3,2] + Qk_1[2] * R[2,2]
    end
end
## Deflation
##
## Scan Q[stop_index], Q[stop_index - 1], ..., Q[start_index] and deflate at
## the first rotator found whose sine has magnitude <= tol. At most one
## deflation per call.
function check_deflation(state::RealDoubleShift{T}, tol = eps(T)) where {T}
    ctr = state.ctrs
    for k in reverse(ctr.start_index:ctr.stop_index)
        if abs(vals(state.Q[k])[2]) <= tol
            deflate(state, k)
            return
        end
    end
end
# Deflate at index k: overwrite Q[k] with the exact "P" rotator
# (sign(c), 0), record k as the new zero_index, begin the active window at
# k + 1, and set it_count to 1 (create_bulge only uses a random shift when
# it_count == 0).
function deflate(state::RealDoubleShift{T}, k) where {T}
    Qk = state.Q[k]
    vals!(Qk, getp(Qk), zero(T))
    ctr = state.ctrs
    ctr.zero_index = k
    ctr.start_index = k + 1
    ctr.it_count = 1
end
## Bulge chasing
#
# Set the rotators U (index start_index) and V (index start_index + 1) that
# start one double-shift step.
# - First iteration (it_count == 0): random shift. t is uniform in [0, 2pi),
#   U = (cos t, sin t) and V = (cos t, -sin t).
# - Otherwise: the shifts l1, l2 are the eigenvalues of the 2x2 block at
#   stop_index + 1. c1, c2, c3 are the first three entries of
#   (A - l1)(A - l2) e1, built from the leading entries t_ij of A at the top
#   of the active window. V and U are the Givens rotations (via givensrot)
#   of [c2, c3] and then [c1, norm([c2, c3])].
function create_bulge{T}(state::RealDoubleShift{T})
    if state.ctrs.it_count == 0
        t = rand() * 2pi
        re1, ie1 = cos(t), sin(t)
        re2, ie2 = re1, -ie1
        vals!(state.U, re1, ie1); idx!(state.U, state.ctrs.start_index)
        vals!(state.V, re2, ie2); idx!(state.V, state.ctrs.start_index + 1)
    else
        # compute (A-rho1) * (A - rho2) * e_1
        # get first columns of A
        Bk = state.Bk
        # find e1, e2 (the shifts) from the trailing 2x2 block
        diagonal_block(state, state.ctrs.stop_index+1)
        eigvals(state)
        l1r, l1i = state.e1
        l2r, l2i = state.e2
        # find first part of A[1:3, 1:2] (relative to start_index)
        diagonal_block(state, state.ctrs.start_index+1)
        Bk[1:2, 1:2] = state.A[1:2, 1:2]
        # NOTE(review): when reached via bulge_step this appears always true
        # (stop_index - zero_index >= 2 and start_index == zero_index + 1 imply
        # start_index + 2 <= stop_index + 1 <= N); kept as a guard. Otherwise
        # Bk[3,2] keeps its previous value.
        if state.ctrs.start_index + 2 <= state.N # (Why this condition?)
            diagonal_block(state, state.ctrs.start_index+2)
            Bk[3,2] = state.A[2, 1]
        end
        # make first three elements of c1,c2,c3
        # c1 = real(-l1i⋅l2i + ⅈ⋅l1i⋅l2r - ⅈ⋅l1i⋅t₁₁ + ⅈ⋅l1r⋅l2i + l1r⋅l2r - l1r⋅t₁₁ - ⅈ⋅l2i⋅t₁₁ - l2r⋅t₁₁ + t₁₁^2 + t₁₂⋅t₂₁)
        # c2 = real(-ⅈ⋅l1i⋅t₂₁ - l1r⋅t₂₁ - ⅈ⋅l2i⋅t₂₁ - l2r⋅t₂₁ + t₁₁⋅t₂₁ + t₂₁⋅t₂₂)
        # c3 = real(t₂₁⋅t₃₂)
        c1 = -l1i * l2i + l1r*l2r -l1r*Bk[1,1] -l2r * Bk[1,1] + Bk[1,1]^2 + Bk[1,2] * Bk[2,1]
        c2 = -l1r * Bk[2,1] - l2r * Bk[2,1] + Bk[1,1]* Bk[2,1] + Bk[2,1] * Bk[2,2]
        c3 = Bk[2,1] * Bk[3,2]
        # V zeroes c3 against c2; U then zeroes the resulting norm against c1
        c,s, nrm = givensrot(c2, c3, Val{true})
        vals!(state.V, c,-s)
        idx!(state.V, state.ctrs.start_index + 1)
        c,s, tmp = givensrot(c1, nrm)
        vals!(state.U, c, -s)
        idx!(state.U, state.ctrs.start_index)
    end
end
## make W on left side
#
# initial Q0
# we do turnover U1' Q1 --> U1' --> U1' --> Q1
# V1' Q2 Q1 V1' Q2 Q1 (V1'Q2) W1 Q2
# With this, W will be kept on the left side until the last step, U,V
# move through left to right by one step, right to left by unitariness
#
# Q0 Q0 Q0 Q0
# U1' Q1 U1' Q1* -> U1 --> U1*
# V1' Q3 -> V1' Q3 Q1** V1' Q3 W (V1'Q3)
#
# Q0 is (p,0) rotator, p 1 or -1. We have
# Q0 --> Q0
# R (r, pr2)
# Insert the bulge at the top of the active window (k = start_index):
# - W <- Q[k], with the sign p of the P rotator Q[k-1] applied (p = 1 when k == 1)
# - turn the adjoints U', V' over W (Val{:right}: W moves to index k+1)
# - fuse V' into Q[k+1]
# - store the sign-adjusted U' as the new Q[k]
# Afterwards W is the misfit carried on the left by chase_bulge.
function prepare_bulge{T}(state::RealDoubleShift{T})
    # println("prepare bulge")
    # as_full(V', N+1)* as_full(U', N+1)* full(state) * as_full(V, N+1) * as_full(U, N+1) |> eigvals |> println
    k = state.ctrs.start_index
    Ut = state.U'; Vt = state.V'
    copy!(state.W, state.Q[k])
    p = k == 1 ? one(T) : state.Q[k-1].xs[1] # zero index implies Q0 = RR(1,0) or RR(-1,0)
    pflip(state.W, p)
    turnover(Ut, Vt, state.W, Val{:right})
    fuse(Vt, state.Q[k+1], Val{:right}) # V' Q3
    pflip(Ut, p)
    vals!(state.Q[k], vals(Ut)...) # Ut.xs[1], p * Ut.xs[2])
end
# Chase the bulge down the active window. Each pass of the loop moves V and U
# one index down through B and C' (or through B only, see below), then
# through Q, and passes W through them (turnover with Val{:left}). It stops
# once idx(V) reaches stop_index.
# For i <= tr the "C_i = B_i" shortcut is used: copies of V and U are turned
# over B, then C' is set to the adjoint of B for the three affected rotators,
# so U and V keep their values.
function chase_bulge{T}(state::RealDoubleShift{T}, tr)
    # println(" begin chase at level $(state.V.i)")
    # as_full(state.W, state.N+1)* full(state) * as_full(state.V, state.N+1) * as_full(state.U, state.N+1) |> eigvals |> println
    # one step
    i = idx(state.V)
    #println("k=$(state.ctrs.stop_index); Is Qk+1 identity? ", state.Q[state.ctrs.stop_index+1].xs[2])
    ## The i <= tr branch is the speed up described in "Exploiting C_i = B_i" in early stages
    while i < state.ctrs.stop_index # i starts at idx(V) and grows by one per pass
        if i <= tr
            turnover(state.B[i], state.B[i+1], copy(state.V))
            turnover(state.B[i-1], state.B[i], copy(state.U))
            for k in -1:1
                a,b = vals(state.B[i+k])
                vals!(state.Ct[i+k], a, -b) # using copy!(Ct, B') is slower
            end
            # reset indices explicitly so U, V sit at i-1, i regardless of
            # whether `copy` shares the index vector with its source
            idx!(state.U, i-1); idx!(state.V, i)
        else
            turnover(state.B[i], state.B[i+1], state.V)
            turnover(state.Ct[i+1], state.Ct[i], state.V)
            j = idx(state.U)
            turnover(state.B[j], state.B[j+1], state.U)
            turnover(state.Ct[j+1], state.Ct[j], state.U)
        end
        turnover(state.Q[i], state.Q[i+1], state.V)
        turnover(state.Q[i-1], state.Q[i], state.U)
        turnover(state.W, state.V, state.U, Val{:left})
        i = idx(state.V)
    end
    # println("end chase")
    # as_full(state.W, state.N+1)* full(state) * as_full(state.V, state.N+1) * as_full(state.U, state.N+1) |> eigvals |> println
end
# Absorb the bulge at the bottom of the active window:
# 1. V is turned over B and C', sign-adjusted by the P rotator Q[i+1] and
#    fused into Q[i].
# 2. U is turned over B, C' and Q, then W*U is fused into U.
# 3. That combined U is turned over B and C', sign-adjusted by Q[j+1] and
#    fused into Q[j], completing the similarity transformation.
# getp errors if Q[i+1] / Q[j+1] is not a "P" rotator.
function absorb_bulge{T}(state::RealDoubleShift{T})
    # println("absorb 0")
    # as_full(state.W, state.N+1) * full(state) * as_full(state.V, state.N+1) * as_full(state.U, N+1) |> eigvals |> println
    # first V goes through B, C then fuses with Q
    i = idx(state.V)
    turnover(state.B[i], state.B[i+1], state.V, Val{:right})
    turnover(state.Ct[i+1], state.Ct[i], state.V)
    ## We may be fusing Q P --> (Q')
    # RR(-1,0) RR(-1,0)
    #
    p = getp(state.Q[i+1])
    pflip(state.V, p)
    fuse(state.Q[i], state.V, Val{:left}) # fuse Q*V -> Q
    # println("absorb 1")
    # as_full(state.W, state.N+1) * full(state) * as_full(state.U, state.N+1) |> eigvals |> println
    # Then bring U through B, C, and Q to fuse with W
    j = idx(state.U)
    turnover(state.B[j], state.B[j+1], state.U)
    turnover(state.Ct[j+1], state.Ct[j], state.U)
    turnover(state.Q[j], state.Q[j+1], state.U)
    fuse(state.W, state.U, Val{:right}) # U <- W*U
    # println("absorb 2")
    # as_full(state.U, state.N+1) * full(state) |> eigvals |> println
    # similarity transformation, bring through then fuse with Q
    j = idx(state.U)
    turnover(state.B[j], state.B[j+1], state.U, Val{:right})
    turnover(state.Ct[j+1], state.Ct[j], state.U)
    p = getp(state.Q[j+1])
    pflip(state.U, p)
    fuse(state.Q[j], state.U, Val{:left})
    # println("absorb final")
    # full(state) |> eigvals |> println
end
# One implicit double-shift step: build the bulge (U, V), insert it (W),
# chase it down the active window and absorb it at the bottom.
# `tr` is forwarded to chase_bulge for the C_i = B_i shortcut.
function bulge_step(state::RealDoubleShift{T}, tr) where {T}
    create_bulge(state)
    prepare_bulge(state)
    chase_bulge(state, tr)
    return absorb_bulge(state)
end
# Initialize the factored matrix for the monic polynomial with non-leading
# coefficients ps = state.POLY = [q_{N-1}, ..., q_0]:
# - Q[1..N-1] get values (0, 1) and Q[N] is the identity, each at its own index.
# - Ct and B are filled bottom-up by a chain of Givens rotations. The first
#   uses (-q_0, -1); each later one uses (-ps[ii-1], r), where r is the norm
#   returned by the previous step. A parity sign s = +-1 is applied to Ct[N].
# Requires N >= 1 (Q[N] is indexed).
function init_state{T}(state::RealDoubleShift{T})
    N, ps= state.N, state.POLY
    Q, Ct, B = state.Q, state.Ct, state.B
    for ii = 1:(N-1)
        vals!(Q[ii], zero(T), one(T)); idx!(Q[ii], ii)
    end
    vals!(Q[N], one(T), zero(T)); idx!(Q[N], N)
    # play with signs here. (NOTE(review): sign depends on the parity of N;
    # the original comment suggests it was chosen empirically)
    s = iseven(N) ? one(T) : -one(T)
    a, b, temp = givensrot(-ps[N], -one(T), Val{true})
    vals!(Ct[N], -s*a, -s*b); idx!(Ct[N], N)
    vals!(B[N], -b, -a); idx!(B[N], N)
    for ii in 2:N
        # temp carries the norm from the previous rotation
        a, b, temp = givensrot(-ps[ii-1], temp, Val{true})
        vals!(Ct[N-ii + 1], a, -b); idx!(Ct[N-ii+1], N-ii+1)
        vals!(B[N-ii + 1], a, b); idx!(B[N-ii+1], N-ii+1)
    end
end
## Main algorithm of AMV&W
##
## Run the double-shift iteration on a state prepared by init_state until
## every eigenvalue has been deflated. Results go into state.REIGS (real
## parts) and state.IEIGS (imaginary parts). After 30*N iterations without
## finishing, a message is printed and the function returns.
function AMVW_algorithm(state::RealDoubleShift{T}) where {T}
    # Degree one: x + q0 has the single root -q0. The loop below would exit
    # immediately (stop_index == 0) without recording it.
    if state.N == 1
        state.REIGS[1] = -state.POLY[1]
        return
    end
    it_max = 30 * state.N
    kk = 0
    tr = state.N - 2 # chase_bulge may use the C_i = B_i shortcut for indices <= tr
    while kk <= it_max
        ## finished up!
        state.ctrs.stop_index <= 0 && return
        check_deflation(state)
        kk += 1
        # show_status(state)
        k = state.ctrs.stop_index
        if state.ctrs.stop_index - state.ctrs.zero_index >= 2
            # active window of size >= 3: one double-shift step
            bulge_step(state, tr)
            state.ctrs.it_count += 1
            tr -= 2
        elseif state.ctrs.stop_index - state.ctrs.zero_index == 1
            # trailing 2x2 block has split off: record its two eigenvalues
            diagonal_block(state, k + 1)
            eigvals(state)
            state.REIGS[k], state.IEIGS[k] = state.e2
            state.REIGS[k+1], state.IEIGS[k+1] = state.e1
            if state.ctrs.stop_index == 2
                # the remaining 1x1 block A[1,1] is the last eigenvalue
                diagonal_block(state, 2)
                state.REIGS[1] = state.A[1,1]
            end
            state.ctrs.zero_index = 0
            state.ctrs.start_index = 1
            state.ctrs.stop_index = state.ctrs.stop_index - 2
        elseif state.ctrs.stop_index - state.ctrs.zero_index == 0
            # Q[stop_index] is a P rotator: a 1x1 block has split off
            diagonal_block(state, state.ctrs.stop_index + 1)
            e1, e2 = state.A[1,1], state.A[2,2] # eigvals(Bk[1:2, 1:2])
            if state.ctrs.stop_index == 1
                state.REIGS[state.ctrs.stop_index] = e1
                state.REIGS[state.ctrs.stop_index+1] = e2
                state.ctrs.stop_index = 0
            else
                state.REIGS[state.ctrs.stop_index+1] = e2
                k = state.ctrs.stop_index
                state.ctrs.zero_index = 0
                state.ctrs.start_index = 1
                state.ctrs.stop_index = k - 1
            end
        end
    end
    println("error -- didn't work!!! Take care of this")
end
"""
Use the AMVW double-shift algorithm to find the roots
of the polynomial p_0 + p_1 x + p_2 x^2 + ... + p_n x^n encoded as
`[p_0, p_1, ..., p_n]` (the same ordering used by `Polynomials`).

Returns an object of type `RealDoubleShift`. The roots are
`complex.(state.REIGS, state.IEIGS)`.

Roots at zero (leading zeros of `ps`) are factored out by `reverse_poly` and
are *not* part of the returned state. Errors for the zero polynomial and for
polynomials with no nonzero roots (constants and monomials `c*x^k`).

Example: API needs work!
```
using Polynomials
p = poly([i/10 for i in 5:10])   # roots 0.5, 0.6, ..., 1.0
state = amvw(p.a)
complex.(state.REIGS, state.IEIGS)
```
"""
function amvw(ps::Vector{T}) where {T <: Real}
    all(iszero, ps) && error("0 polynomial")
    qs, k = reverse_poly(ps)
    # k is number of 0 factors (roots at zero); they are not returned here
    n = length(qs)
    n == 0 && error("polynomial has no nonzero roots")
    state = RealDoubleShift(qs)
    init_state(state)
    AMVW_algorithm(state)
    state
end
end
# test
using AMVW, Polynomials
# Short alias for the module. `const` so that functions below which use it
# do not go through an untyped global.
const A = AMVW
## some interface for polynomials
##
## All roots of the Polynomials.Poly `p` as a complex vector. Roots at zero,
## which reverse_poly factors out, are appended at the end. Errors for the
## zero polynomial.
function amvw(p::Poly)
    all(iszero, p.a) && error("0 polynomial")
    qs, k = A.reverse_poly(p.a)
    zero_roots = zeros(Complex{eltype(qs)}, k)
    # p == c*x^k: no eigenproblem left (init_state would index Q[0])
    isempty(qs) && return zero_roots
    state = A.RealDoubleShift(qs)
    A.init_state(state)
    A.AMVW_algorithm(state)
    vcat(complex.(state.REIGS, state.IEIGS), zero_roots)
end
## quick hack -- doesn't work with complex, big, ...
##
## Compare root residuals of AMVW against Polynomials.roots for p. Both sets
## of roots first get one Newton step. Returns (r1/rr1, r2/rr2, r3/rr3): the
## AMVW residual divided by the roots() residual, so values < 1 favor AMVW.
## Only the real parts of the AMVW roots are used. NOTE(review): sortperm on
## complex eigenvalues of C would throw, so this presumably only works for
## polynomials with real roots.
function residual_check(p::Poly)
    # r1 = |P(lambda)/P'(lambda)|
    # r2 = |P(lambda)/P'(lambda)/lambda|
    # r3 = ||Cv-lambda v||/||C||/||v||
    state = A.amvw(p.a)
    #lambdas = complex.(state.REIGS, state.IEIGS)
    lambdas = state.REIGS
    lambdas = lambdas - p.(lambdas) ./ polyder(p).(lambdas) # one Newton step
    sort!(lambdas) # complex???
    r1 = norm(p.(lambdas) ./ polyder(p).(lambdas))
    r2 = norm(p.(lambdas) ./ polyder(p).(lambdas) ./ lambdas)
    # rebuild the initial dense matrix (dropping the extra last row/column)
    A.init_state(state)
    C = full(state)[1:end-1, 1:end-1]
    es = eigfact(C)
    vals = es.values
    vecs = es.vectors
    ind = sortperm(vals)
    vecs = vecs[:,ind] # eigenvectors in the same (sorted) order as lambdas
    r3 = 0.0
    for i in eachindex(vals)
        lambda = lambdas[i]
        v = vecs[:,i]
        r3 += norm(C*v - lambda * v) / norm(C) / norm(v)
    end
    ## now for roots
    lambdas = sort(roots(p))
    lambdas = lambdas - p.(lambdas) ./ polyder(p).(lambdas)
    rr1 = norm(p.(lambdas) ./ polyder(p).(lambdas))
    rr2 = norm(p.(lambdas) ./ polyder(p).(lambdas) ./ lambdas)
    rr3 = 0.0
    for i in eachindex(vals)
        lambda = lambdas[i]
        v = vecs[:,i]
        rr3 += norm(C*v - lambda * v) / norm(C) / norm(v)
    end
    (r1/rr1, r2/rr2, r3/rr3)
end
## Tests
## Time compared to `roots` -- slower
##
# julia> n = 15; p = poly(linspace(1/n, 1, n));
# julia> using BenchmarkTools; @benchmark roots(p)
# BenchmarkTools.Trial:
# memory estimate: 41.58 KiB
# allocs estimate: 51
# --------------
# minimum time: 189.171 μs (0.00% GC)
# median time: 191.314 μs (0.00% GC)
# mean time: 205.429 μs (2.56% GC)
# maximum time: 3.588 ms (88.18% GC)
# --------------
# samples: 10000
# evals/sample: 1
# time tolerance: 5.00%
# memory tolerance: 1.00%
# #
# julia> @benchmark amvw(p)
# BenchmarkTools.Trial:
# memory estimate: 48.06 KiB
# allocs estimate: 1068
# --------------
# minimum time: 490.720 μs (0.00% GC)
# median time: 797.224 μs (0.00% GC)
# mean time: 816.253 μs (1.21% GC)
# maximum time: 6.539 ms (81.94% GC)
# --------------
# samples: 5915
# evals/sample: 1
# time tolerance: 5.00%
# memory tolerance: 1.00%
## and for
# julia> n = 25; p = poly(linspace(1/n, 1, n));
# julia> @benchmark roots(p)
# BenchmarkTools.Trial:
# memory estimate: 46.44 KiB
# allocs estimate: 60
# --------------
# minimum time: 301.662 μs (0.00% GC)
# median time: 310.948 μs (0.00% GC)
# mean time: 324.381 μs (1.90% GC)
# maximum time: 4.811 ms (90.75% GC)
# --------------
# samples: 10000
# evals/sample: 1
# time tolerance: 5.00%
# memory tolerance: 1.00%
# julia> @benchmark amvw(p)
# BenchmarkTools.Trial:
# memory estimate: 64.08 KiB
# allocs estimate: 1640
# --------------
# minimum time: 1.114 ms (0.00% GC)
# median time: 1.859 ms (0.00% GC)
# mean time: 1.956 ms (0.94% GC)
# maximum time: 8.569 ms (71.05% GC)
# --------------
# samples: 2512
# evals/sample: 1
# time tolerance: 5.00%
# memory tolerance: 1.00%
## scaling in n is neither linear nor quadratic
# ns = [4,8,16,32,64,128,256]
# ts = zeros(length(ns))
# for i in eachindex(ns)
# n = ns[i]
# p = poly(rand(n))
# ts[i] = time()
# amvw(p)
# ts[i] = time() - ts[i]
# end
# julia> ts[2:end] ./ ts[1:end-1]
# 6-element Array{Float64,1}:
# 0.81346
# 3.22482
# 3.85639
# 2.87251
# 4.80438
# 4.24259
## Accuracy
# for n in 5:10
# p = poly(1.0:n)
# print("n=$n: "); println(residual_check(p))
# end
# ## ratios of residuals of amvw to roots
# julia> for n in 5:10
# p = poly(1.0:n)
# print("n=$n: "); println(residual_check(p))
# end
# n=5: (1.2249137731278716, 1.373964400725604, 0.8348154805166964)
# n=6: (0.973487711587202, 0.8271239165425018, 0.9888450594647631)
# n=7: (0.6176463305637532, 0.5859554825199832, 1.00385484319229)
# n=8: (0.9089421995514068, 0.8834377341004336, 0.9318743185837587)
# n=9: (1.2426448919951487, 1.3151521779378252, 0.9904264553424764)
# n=10: (2.280076313005534, 2.30722853091966, 0.9395012114789933)
## testing
## Wilkinson-type test polynomial: Polynomials' `poly` builds it from the
## roots 1.0, 2.0, ..., n.
function wpoly(n)
    return poly(1.0:n)
end
amvw(wpoly(10))
#residual_check(wpoly(10))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment