Created
August 11, 2021 13:55
-
-
Save frapac/41a0ac615c78ed47f7c4850952351257 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    AbstractKKTSystem

Supertype for KKT system storage types. Concrete subtypes are expected to
implement the generic functions declared below.
"""
abstract type AbstractKKTSystem end

"""
    compress!(kkt)

Assemble the KKT matrix of `kkt` from its stored components.
"""
function compress! end

"Factorize the KKT matrix with the given linear solver."
function factorize! end

"Solve the KKT system `Wx = b` with the given linear solver."
function triangular_solve! end

# NOTE(review): this declares a new function `mul!`, separate from
# `LinearAlgebra.mul!`. If the enclosing module does `using LinearAlgebra`,
# confirm whether shadowing is intended or `import LinearAlgebra: mul!` is wanted.
"Multiply with the entire KKT matrix."
function mul! end

"Multiply with the Hessian."
function hess_mul! end

"Multiply with the Jacobian."
function jac_mul! end

# NOTE: the two callback hooks below are provisional (author's note: "not sure
# about these two last callbacks").

"""
    update_hess!(kkt, nlp, x, l, σ)

Refresh the Hessian values stored in `kkt` by calling the model's callback.
"""
function update_hess! end

"""
    update_jac!(kkt, nlp, x)

Refresh the Jacobian values stored in `kkt` by calling the model's callback.
"""
function update_jac! end
#=
SparseKKTSystem
=#

"""
    SparseKKTSystem{T, MT} <: AbstractKKTSystem

KKT system stored in sparse form.

`hess` and `jac` hold the values written by the model's Hessian and Jacobian
callbacks. `aug_raw`/`jac_raw` are COO triplet matrices. `aug_com`/`jac_com`
are their counterparts in the compressed format `MT`, which `compress!` is
meant to refresh.
"""
struct SparseKKTSystem{T, MT} <: AbstractKKTSystem
    # Values filled by the model callbacks.
    hess::StrideOneVector{T}
    jac::StrideOneVector{T}
    # Diagonal terms; presumably primal / dual regularization — TODO confirm.
    pr_diag::StrideOneVector{T}
    du_diag::StrideOneVector{T}
    # Optional diagonal / lower terms; `nothing` when not used.
    l_diag::Union{Nothing,StrideOneVector{T}}
    u_diag::Union{Nothing,StrideOneVector{T}}
    l_lower::Union{Nothing,StrideOneVector{T}}
    u_lower::Union{Nothing,StrideOneVector{T}}
    # Presumably the augmented KKT matrix: COO triplets and compressed form.
    aug_raw::SparseMatrixCOO{T,Int32}
    aug_com::MT
    # Jacobian: COO triplets and compressed form.
    jac_raw::SparseMatrixCOO{T,Int32}
    jac_com::MT
end
"""
    compress!(kkt::SparseKKTSystem)

Placeholder for the COO -> `MT` conversion that should refresh `kkt.aug_com`
and `kkt.jac_com`. Not implemented yet: currently a no-op returning `nothing`.
"""
function compress!(kkt::SparseKKTSystem)
    # TODO: convert the COO triplets into the compressed matrices
    # (presumably aug_raw -> aug_com and jac_raw -> jac_com).
    return nothing
end
"""
    update_hess!(kkt::SparseKKTSystem, nlp, x, l, σ)

Scale the multipliers `l` by `con_scale` into the buffer `_w1l`. Then call
`nlp.lag_hess!` to write Hessian values into `kkt.hess`, using the first
`nlp.n` entries of `x` and passing `σ` through. Returns `nothing`.
"""
function update_hess!(kkt::SparseKKTSystem, nlp, x, l, σ)
    # NOTE(review): `_w1l` and `con_scale` are not defined in this file;
    # presumably solver-level state — TODO pass them in explicitly.
    broadcast!(*, _w1l, l, con_scale)
    primal = view(x, 1:nlp.n)
    nlp.lag_hess!(kkt.hess, primal, _w1l, σ)
    return nothing
end
"""
    update_jac!(kkt::SparseKKTSystem, nlp, x)

Update the Jacobian values stored in `kkt.jac`:

1. Call `nlp.con_jac!` with the first `nlp.n` entries of `x`.
2. Set entries `n_jac-ns+1:n_jac` to `-1`.
3. Scale all entries by `con_jac_scale`.

Returns `nothing`.
"""
function update_jac!(kkt::SparseKKTSystem, nlp, x)
    # All writes must target the KKT system's own storage; the original
    # referenced an undefined bare `jac` here.
    jac = kkt.jac
    nlp.con_jac!(jac, view(x, 1:nlp.n))
    # NOTE(review): `n_jac`, `ns` and `con_jac_scale` are not defined in this
    # file; presumably Jacobian length, slack count and constraint scaling from
    # solver state — TODO pass them in explicitly.
    jac[n_jac-ns+1:n_jac] .= -one(eltype(jac))
    # In-place broadcast: mutates `kkt.jac` through the alias.
    jac .*= con_jac_scale
    return
end
#=
DenseKKTSystem
=#

"""
    DenseKKTSystem{T, VT, MT} <: AbstractKKTSystem

KKT system stored in dense form, with element type `T`.

- `hess` and `jac` are matrices of type `MT`.
- `diag_hess`, `pr_diag` and `du_diag` are vectors of type `VT`.
- `compress!` writes `diag_hess[i] + pr_diag[i]` onto `hess[i, i]`.
- `diag_hess` is presumably a saved copy of the Hessian diagonal (refreshed by
  `update_hess!`).

Tying `VT` and `MT` to `T` lets the default constructor infer `T`.
"""
struct DenseKKTSystem{T, VT<:AbstractVector{T}, MT<:AbstractMatrix{T}} <: AbstractKKTSystem
    diag_hess::VT
    hess::MT
    jac::MT
    # Diagonal terms; presumably primal / dual regularization — TODO confirm.
    pr_diag::VT
    du_diag::VT
    # Optional terms; `nothing` when not used. Use `T` (was hard-coded
    # Float64) to match `SparseKKTSystem`.
    l_diag::Union{Nothing,StrideOneVector{T}}
    u_diag::Union{Nothing,StrideOneVector{T}}
    l_lower::Union{Nothing,StrideOneVector{T}}
    u_lower::Union{Nothing,StrideOneVector{T}}
end
"""
    compress!(kkt::DenseKKTSystem)

Assemble the dense KKT matrix in place by writing `diag_hess[i] + pr_diag[i]`
onto each diagonal entry of `kkt.hess`. Off-diagonal entries are untouched.
Returns `nothing`.
"""
function compress!(kkt::DenseKKTSystem)
    # The original read `solver.hess`, an undefined name; the matrix to
    # assemble is the one owned by `kkt`.
    hess = kkt.hess
    # Assignment (not `+=`), so repeated calls with unchanged inputs give the
    # same diagonal instead of accumulating `pr_diag`.
    for i in axes(hess, 1)
        hess[i, i] = kkt.diag_hess[i] + kkt.pr_diag[i]
    end
    return
end
"""
    update_hess!(kkt::DenseKKTSystem, nlp, x, l, σ)

Refresh the dense Hessian and its saved diagonal:

1. Scale the multipliers `l` by `con_scale` into the buffer `_w1l`.
2. Call `nlp.lag_hess!` to write into `kkt.hess`, using the first `nlp.n`
   entries of `x` and passing `σ` through.
3. Refresh `kkt.diag_hess` from `kkt.hess`.

Returns `nothing`.
"""
function update_hess!(kkt::DenseKKTSystem{T, VT, MT}, nlp, x, l, σ) where {T, VT, MT}
    # NOTE(review): `_w1l` and `con_scale` are not defined in this file;
    # presumably solver-level state — TODO pass them in explicitly.
    broadcast!(*, _w1l, l, con_scale)
    H = kkt.hess
    nlp.lag_hess!(H, view(x, 1:nlp.n), _w1l, σ)
    # `diag!` is defined elsewhere; presumably copies the diagonal of `H` into
    # `kkt.diag_hess` so it survives `compress!` rewriting `H[i, i]`.
    diag!(kkt.diag_hess, H)
    return nothing
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment