Last active
April 4, 2019 16:28
-
-
Save briochemc/e73f323cf9bdce2be98f38c1d51aa451 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 1. Load packages
using DualMatrixTools, DualNumbers, SparseArrays
using LinearAlgebra: I  # identity, used below to keep the random real part invertible
# 2. Create a random dual-valued vector x
n = 10
a, b = rand(n), rand(n)  # real and dual parts
x = a + ε * b            # dual-valued x
# 3. Create a sparse random dual-valued matrix M
# A purely random sparse A can be singular (e.g. it may contain an empty row), which
# would make `factorize` throw; adding I fills the diagonal so A is almost surely
# invertible.
A, B = sprand(n, n, 5/n) + I, sprand(n, n, 5/n)
M = A + ε * B            # dual-valued M
# 4. Factorize M
Mf = factorize(M)        # stores B and the factors of A in Mf
# 5. Update only the dual part in Mf with 2B
Mf = factorize(Mf, 2M)   # does not factorize 2A: Mf keeps the factors of A
# 6. Solve the dual-valued linear system (A + 2εB) * sol = x
#    (after step 5, Mf represents A + 2εB rather than M itself)
sol = Mf \ x             # forward and back substitution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 1. Load packages
using HyperDualMatrixTools, HyperDualNumbers, SparseArrays
using LinearAlgebra: I  # identity, used below to keep the random real part invertible
# 2. Create a random hyperdual-valued vector x
n = 10
a, b, c, d = rand(n), rand(n), rand(n), rand(n)
x = a + ε₁ * b + ε₂ * c + ε₁ε₂ * d  # hyperdual-valued x
# 3. Create a sparse random hyperdual-valued matrix M
# A purely random sparse A can be singular (e.g. it may contain an empty row), which
# would make `factorize` throw; adding I fills the diagonal so A is almost surely
# invertible.
A, B = sprand(n, n, 5/n) + I, sprand(n, n, 5/n)
C, D = sprand(n, n, 5/n), sprand(n, n, 5/n)
M = A + ε₁ * B + ε₂ * C + ε₁ε₂ * D  # hyperdual-valued M
# 4. Factorize M
Mf = factorize(M)       # stores the factors of A and the non-real parts B, C, D
# 5. Update only the non-real parts in Mf with those of 2M
Mf = factorize(Mf, 2M)  # does not factorize 2A: Mf keeps the factors of A
# 6. Solve a hyperdual-valued linear system, (A + 2ε₁B + 2ε₂C + 2ε₁ε₂D) * sol = x
#    (after step 5, Mf no longer represents M itself)
sol = Mf \ x            # forward and back substitution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    buffer(p, s, A, ∇s, ∇ₓf)

Mutable cache for the quantities that depend on the parameters `p`. All fields are
untyped, so any value (including placeholders) may be stored and later replaced.

- `p`   : the parameters the cached values were last computed for
- `s`   : s(p), the value returned by the solver for F(·, p)
- `A`   : the factors of ∇ₓF(s(p), p)
- `∇s`  : ∇s(p)
- `∇ₓf` : ∇ₓf(s(p), p)
"""
mutable struct buffer
    p::Any
    s::Any
    A::Any
    ∇s::Any
    ∇ₓf::Any
end
"""
    update_buffer!(f, F, ∇ₓf, ∇ₓF, buffer, p)

Recompute the `p`-dependent quantities cached in `buffer`, unless `buffer.p` already
equals `p`:

- `buffer.s`   ← `solve(x -> F(x, p), x -> ∇ₓF(x, p), s)`, where the current `s` is
  passed as the third argument (presumably an initial guess)
- `buffer.A`   ← `factorize(∇ₓF(s, p))`
- `buffer.∇s`  ← `A \\ -∇ₚF` (Eq. (5)), with ∇ₚF built column by column from dual
  parts (Eq. (14))
- `buffer.∇ₓf` ← `∇ₓf(s, p)`
- `buffer.p`   ← a copy of `p`

`buffer.s`, `buffer.∇s` and `buffer.∇ₓf` are overwritten in place (`.=`), so they must
already hold arrays of the right shapes. `f` is not used here. Returns `nothing`.
"""
function update_buffer!(f, F, ∇ₓf, ∇ₓF, buffer, p)
    p == buffer.p && return nothing  # cache hit: nothing to recompute
    m = length(p)
    basis(j) = [i == j for i in 1:m]  # j-th canonical basis vector of length m
    s = buffer.s
    s .= solve(x -> F(x, p), x -> ∇ₓF(x, p), s)  # update s (solver)
    # ∇ₚF column by column from dual parts, Eq. (14). `reduce(hcat, …)` concatenates
    # the columns into a matrix; a bare `hcat([…])` would instead give a one-column
    # matrix whose entries are vectors.
    ∇ₚF = reduce(hcat, [D(F(s, p + ε * basis(j))) for j in 1:m])
    # A factorization is not an array, so it is stored by rebinding the field,
    # not by broadcasting into the old one.
    A = factorize(∇ₓF(s, p))
    buffer.A = A
    buffer.∇s .= A \ -∇ₚF     # update ∇s via Eq. (5)
    buffer.∇ₓf .= ∇ₓf(s, p)   # update ∇ₓf(s(p), p)
    # Record p last and as a copy: keeping a reference to the caller's array would
    # make later in-place changes to p invisible to the `==` check above, and
    # recording it last leaves the cache marked stale if any step above throws.
    buffer.p = copy(p)
    return nothing
end
"""
    ∇f̂!(f, F, ∇ₓf, ∇ₓF, buffer, p)

Gradient of f̂(p) = f(s(p), p) with respect to `p`, via Eq. (4):
`∇ₓf(s, p) * ∇s(p) + ∇ₚf(s, p)`.

First refreshes `buffer` for `p` with `update_buffer!` (hence the `!`). Then it
computes ∇ₚf from dual-number perturbations of `p`, one parameter at a time
(Eq. (15)). Returns a row of length `m = length(p)`.

NOTE(review): this assumes `buffer.∇ₓf` is stored as a row (1×n). The `'` applied to it
in the Hessian code suggests so; confirm.
"""
function ∇f̂!(f, F, ∇ₓf, ∇ₓF, buffer, p)
    update_buffer!(f, F, ∇ₓf, ∇ₓF, buffer, p)  # update buffer
    s, ∇s, m = buffer.s, buffer.∇s, length(p)
    basis(j) = [i == j for i in 1:m]           # j-th canonical basis vector
    # ∂f/∂pⱼ is the dual part of f at p + εeⱼ, Eq. (15). The trailing ' makes it a
    # 1×m row so that it can be added to the 1×m product ∇ₓf * ∇s.
    ∇ₚf = [D(f(s, p + ε * basis(j))) for j in 1:m]'
    return buffer.∇ₓf * ∇s + ∇ₚf               # Eq. (4)
end
"""
    ∇²f̂!(f, F, ∇ₓf, ∇ₓF, buffer, p)

Hessian of f̂(p) = f(s(p), p) with respect to `p`, via Eq. (13), using hyperdual
numbers.

For each pair j ≤ k, `f` and `F` are evaluated at the hyperdual point
x = s + ε₁∇s[:, j] + ε₂∇s[:, k], p + ε₁eⱼ + ε₂eₖ. Entry (j, k) is then
`H(f) - H(F)' * A⁻ᵀ∇ₓfᵀ`, where `A⁻ᵀ∇ₓfᵀ = A' \\ ∇ₓf'` is computed once.

Only the upper triangle is evaluated; it is mirrored into the lower triangle. First
refreshes `buffer` for `p` with `update_buffer!`. Returns an `m × m` matrix with
`m = length(p)`.
"""
function ∇²f̂!(f, F, ∇ₓf, ∇ₓF, buffer, p)
    update_buffer!(f, F, ∇ₓf, ∇ₓF, buffer, p)  # update buffer
    s, A, ∇s, m = buffer.s, buffer.A, buffer.∇s, length(p)
    basis(j) = [i == j for i in 1:m]           # j-th canonical basis vector
    A⁻ᵀ∇ₓfᵀ = vec(A' \ buffer.∇ₓf')            # independent of (j,k)
    out = zeros(m, m)                          # preallocate
    for j in 1:m, k in j:m                     # upper triangle of (13)
        pⱼₖ = p + ε₁ * basis(j) + ε₂ * basis(k)  # hyperdual p
        # ∇s[:, j] is ∇s * eⱼ; taking the column avoids building a full n×m
        # hyperdual matrix on every iteration.
        xⱼₖ = s + ε₁ * ∇s[:, j] + ε₂ * ∇s[:, k]  # hyperdual x
        out[j, k] = H(f(xⱼₖ, pⱼₖ)) - H(F(xⱼₖ, pⱼₖ))' * A⁻ᵀ∇ₓfᵀ  # (13)
        out[k, j] = out[j, k]                  # symmetry (a no-op when j == k)
    end
    return out
end
# Helper functions
"""
    e(j, n = m)

Return the `j`-th canonical basis vector of length `n` as a `Vector{Bool}`.
`n` defaults to the global `m`, which keeps the original one-argument form working.
"""
e(j, n = m) = [i == j for i in 1:n]
# Dual (ε) part of a dual number, broadcast over arrays; needs DualNumbers loaded.
D(x) = DualNumbers.dualpart.(x)
# ε₁ε₂ part of a hyperdual number, broadcast over arrays; needs HyperDualNumbers loaded.
H(x) = HyperDualNumbers.ε₁ε₂part.(x)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment