Skip to content

Instantly share code, notes, and snippets.

@frapac
Created August 11, 2021 13:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save frapac/41a0ac615c78ed47f7c4850952351257 to your computer and use it in GitHub Desktop.
Save frapac/41a0ac615c78ed47f7c4850952351257 to your computer and use it in GitHub Desktop.
"""
    AbstractKKTSystem

Abstract supertype intended for KKT system representations. The generic
functions declared below (`compress!`, `factorize!`, `triangular_solve!`,
`mul!`, `hess_mul!`, `jac_mul!`, `update_hess!`, `update_jac!`) make up the
interface that implementations are expected to provide.
"""
abstract type AbstractKKTSystem end
"""
    compress!(kkt)

Assemble the KKT matrix from the values currently stored in `kkt`.
"""
function compress! end
"""
    factorize!

Factorize the KKT matrix with the given linear solver.
"""
function factorize! end
"""
    triangular_solve!

Solve the KKT system `Wx = b` with the given linear solver.
"""
function triangular_solve! end
# NOTE(review): this declares a function named `mul!` in the current module. If
# the intent is to extend `LinearAlgebra.mul!`, it has to be imported instead;
# confirm which one is wanted.
"""
    mul!

Multiply with the entire KKT matrix.
"""
function mul! end
"""
    hess_mul!

Multiply with the Hessian.
"""
function hess_mul! end
"""
    jac_mul!

Multiply with the Jacobian.
"""
function jac_mul! end
# Note: I am not sure about these two last callbacks (`update_hess!` and
# `update_jac!`).
"""
    update_hess!(kkt, nlp, x, l, σ)

Update the Hessian values stored in the KKT system `kkt` by calling the
`nlp.lag_hess!` callback.
"""
function update_hess! end
"""
    update_jac!(kkt, nlp, x)

Update the Jacobian values stored in the KKT system `kkt` by calling the
`nlp.con_jac!` callback.
"""
function update_jac! end
#=
SparseKKTSystem
=#
"""
    SparseKKTSystem{T, MT} <: AbstractKKTSystem

KKT system whose Hessian and Jacobian values are stored as flat vectors of
element type `T`, together with raw COO matrices (`aug_raw`, `jac_raw`) and
their converted counterparts of type `MT` (`aug_com`, `jac_com`).

`compress!` is meant to refresh `aug_com` and `jac_com` from the COO data.
"""
struct SparseKKTSystem{T, MT} <: AbstractKKTSystem
    # Hessian values; filled by `nlp.lag_hess!` in `update_hess!`.
    hess::StrideOneVector{T}
    # Jacobian values; filled by `nlp.con_jac!` in `update_jac!`.
    jac::StrideOneVector{T}
    # Presumably the primal and dual diagonal terms of the KKT matrix -- TODO confirm.
    pr_diag::StrideOneVector{T}
    du_diag::StrideOneVector{T}
    # Optional bound-related vectors (`nothing` when absent). Their exact
    # meaning is not visible in this file -- TODO document once confirmed.
    l_diag::Union{Nothing,StrideOneVector{T}}
    u_diag::Union{Nothing,StrideOneVector{T}}
    l_lower::Union{Nothing,StrideOneVector{T}}
    u_lower::Union{Nothing,StrideOneVector{T}}
    # Augmented matrix: raw COO storage and its converted form.
    aug_raw::SparseMatrixCOO{T,Int32}
    aug_com::MT
    # Jacobian matrix: raw COO storage and its converted form.
    jac_raw::SparseMatrixCOO{T,Int32}
    jac_com::MT
end
"""
    compress!(kkt::SparseKKTSystem)

Placeholder for the sparse assembly step. It is intended to convert the COO
matrices into the `MT` format, refreshing `kkt.aug_com` and `kkt.jac_com`.
Currently it does nothing and returns `nothing`.
"""
function compress!(kkt::SparseKKTSystem)
    # TODO: COO to MT conversion (update aug_com and jac_com).
    return nothing
end
"""
    update_hess!(kkt::SparseKKTSystem, nlp, x, l, σ)

Refresh the Hessian values stored in `kkt.hess`.

The multipliers `l` are scaled elementwise by `con_scale` into the buffer
`_w1l`, then `nlp.lag_hess!` is called with the first `nlp.n` entries of `x`,
the scaled multipliers and `σ`. Returns `nothing`.
"""
function update_hess!(kkt::SparseKKTSystem, nlp, x, l, σ)
    # NOTE(review): `_w1l` and `con_scale` are neither arguments nor locals, so
    # they resolve to names from the enclosing scope -- confirm where they live.
    @. _w1l = l * con_scale
    x_head = view(x, 1:nlp.n)
    nlp.lag_hess!(kkt.hess, x_head, _w1l, σ)
    return nothing
end
"""
    update_jac!(kkt::SparseKKTSystem, nlp, x)

Refresh the Jacobian values stored in `kkt.jac`.

`nlp.con_jac!` is called with the first `nlp.n` entries of `x` to fill
`kkt.jac`. The entries `n_jac-ns+1:n_jac` are then set to `-1` and the whole
vector is scaled elementwise by `con_jac_scale`. Returns `nothing`.
"""
function update_jac!(kkt::SparseKKTSystem, nlp, x)
    # Fix: post-process the buffer the callback just filled. The original code
    # used an unqualified `jac` that was never bound to `kkt.jac`.
    jac = kkt.jac
    nlp.con_jac!(jac, view(x, 1:nlp.n))
    # NOTE(review): `n_jac`, `ns` and `con_jac_scale` are neither arguments nor
    # locals; they resolve to names from the enclosing scope -- confirm where
    # they live. Presumably the last `ns` entries are the slack part of the
    # Jacobian -- TODO confirm.
    jac[n_jac-ns+1:n_jac] .= -1.0
    jac .*= con_jac_scale
    return
end
#=
DenseKKTSystem
=#
"""
    DenseKKTSystem{T, VT, MT} <: AbstractKKTSystem

KKT system whose Hessian and Jacobian are stored as matrices of type `MT`, with
vectors of type `VT` for the diagonal terms.

`update_hess!` saves the Hessian diagonal into `diag_hess`, and `compress!`
writes `diag_hess[i] + pr_diag[i]` back onto the diagonal of `hess`.
"""
struct DenseKKTSystem{T, VT, MT} <: AbstractKKTSystem
    # Copy of the Hessian diagonal, refreshed by `update_hess!`.
    diag_hess::VT
    # Hessian; filled by `nlp.lag_hess!`, its diagonal is overwritten by `compress!`.
    hess::MT
    jac::MT
    # Added to `diag_hess` on the diagonal of `hess` by `compress!`.
    pr_diag::VT
    du_diag::VT
    # Optional bound-related vectors (`nothing` when absent). Their exact
    # meaning is not visible in this file -- TODO document once confirmed.
    # NOTE(review): these are hard-coded to `Float64` while `T` is otherwise
    # unused in the fields; they probably should be `StrideOneVector{T}` as in
    # `SparseKKTSystem`. Left unchanged to keep existing constructor calls valid.
    l_diag::Union{Nothing,StrideOneVector{Float64}}
    u_diag::Union{Nothing,StrideOneVector{Float64}}
    l_lower::Union{Nothing,StrideOneVector{Float64}}
    u_lower::Union{Nothing,StrideOneVector{Float64}}
end
"""
    compress!(kkt::DenseKKTSystem)

Assemble the dense KKT matrix in place: every diagonal entry of `kkt.hess` is
overwritten with `kkt.diag_hess[i] + kkt.pr_diag[i]`. Returns `nothing`.
"""
function compress!(kkt::DenseKKTSystem{T, VT, MT}) where {T, VT, MT}
    # Fix: the dimension comes from `kkt.hess`; the original referenced an
    # undefined `solver` variable here.
    for i in axes(kkt.hess, 1)
        kkt.hess[i, i] = kkt.diag_hess[i] + kkt.pr_diag[i]
    end
    return
end
"""
    update_hess!(kkt::DenseKKTSystem, nlp, x, l, σ)

Refresh the dense Hessian stored in `kkt.hess` and the saved copy of its
diagonal.

The multipliers `l` are scaled elementwise by `con_scale` into the buffer
`_w1l`, then `nlp.lag_hess!` is called with the first `nlp.n` entries of `x`,
the scaled multipliers and `σ`. Finally `diag!` is called to update
`kkt.diag_hess` from `kkt.hess`. Returns `nothing`.
"""
function update_hess!(kkt::DenseKKTSystem{T, VT, MT}, nlp, x, l, σ) where {T, VT, MT}
    # NOTE(review): `_w1l` and `con_scale` are neither arguments nor locals, so
    # they resolve to names from the enclosing scope -- confirm where they live.
    @. _w1l = l * con_scale
    x_head = view(x, 1:nlp.n)
    nlp.lag_hess!(kkt.hess, x_head, _w1l, σ)
    # Keep the stored diagonal in sync with the freshly evaluated Hessian.
    diag!(kkt.diag_hess, kkt.hess)
    return nothing
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment