Created
February 10, 2019 17:12
-
-
Save msakai/9a5c31882a81ec83f60e046d2db4f1db to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# numpy/scipy translation of https://github.com/locuslab/qpth/blob/5485219028a7687b76107c8431625aaedfd7bc36/qpth/solvers/pdipm/single.py | |
import numpy as np | |
import scipy | |
import scipy.linalg | |
# TODO: Add more comments describing the math here. | |
# https://stanford.edu/~boyd/papers/pdf/code_gen_impl.pdf | |
def get_sizes(G, A=None):
    """Return the problem dimensions ``(nineq, nz, neq, 1)``.

    ``G`` is the inequality-constraint matrix; its row count is ``nineq`` and
    its column count is ``nz`` (number of variables).  ``neq`` is the number
    of rows of the equality-constraint matrix ``A``, or 0 when ``A`` is None.

    The trailing ``1`` is ignored by every caller in this file; presumably it
    is the batch size kept for parity with the qpth code this file translates
    (TODO confirm).
    """
    nineq, nz = G.shape[0], G.shape[1]
    # Previously this assigned to a misspelled name (``new``), leaving ``neq``
    # unbound and raising UnboundLocalError whenever A was supplied.
    neq = 0 if A is None else A.shape[0]
    return (nineq, nz, neq, 1)
def forward(inputs_i, Q, G, A, b, h, U_Q, U_S, R, verbose=False, max_iter=20):
    """Solve one QP with a primal-dual interior-point method.

    The residuals computed below are those of the KKT conditions of

        minimize    0.5 x^T Q x + inputs_i^T x
        subject to  G x + s = h,  s >= 0   (multipliers z >= 0)
                    A x = b                (multipliers y; only if A given)

    Each iteration takes a Mehrotra-style predictor-corrector step.  First an
    affine-scaling direction is computed.  Then a centering/corrector
    direction is added, with sigma = (gap after the affine step / current
    gap)^3.

    The factorization arguments must come from a prior call
        U_Q, U_S, R = pre_factor_kkt(Q, G, A)
    ``U_S`` is modified in place: ``factor_kkt`` refactors its trailing
    nineq x nineq block whenever the scaling vector ``d`` changes.

    Args:
        inputs_i: linear cost vector (length nz).
        Q, G, A, b, h: problem data; ``A`` and ``b`` may be None when there
            are no equality constraints.
        U_Q, U_S, R: cached factorizations from ``pre_factor_kkt``.
        verbose: print residuals and kappa(d) every iteration.
        max_iter: maximum number of interior-point iterations (default 20,
            the previously hard-coded value).

    Returns:
        (x, y, z): primal solution, equality multipliers (None when there
        are no equality constraints) and inequality multipliers.  The current
        iterate is returned early in three cases: the combined residual stops
        improving, it drops below 1e-6, or the step is non-finite or larger
        than 1e5.
    """
    nineq, nz, neq, _ = get_sizes(G, A)

    # Initial point: solve the KKT system with D = I, then shift s and z
    # into the strict interior (CVXGEN initialization, Mattingley & Boyd).
    d = np.ones(nineq, dtype=Q.dtype)
    nb = -b if b is not None else None
    factor_kkt(U_S, R, d)
    x, s, z, y = solve_kkt(
        U_Q, d, G, A, U_S,
        inputs_i, np.zeros(nineq, dtype=Q.dtype), -h, nb)
    # Shift when the minimum is <= 0, not only < 0.  An exactly-zero entry
    # would make d = z / s zero or non-finite, and factor_kkt inverts d.
    if np.min(s) <= 0:
        s -= np.min(s) - 1
    if np.min(z) <= 0:
        z -= np.min(z) - 1

    prev_resid = None
    for _ in range(max_iter):
        # KKT residuals at the current iterate.
        rx = (A.T.dot(y) if neq > 0 else 0.) + G.T.dot(z) + Q.dot(x) + inputs_i
        rs = z
        rz = G.dot(x) + s - h
        ry = A.dot(x) - b if neq > 0 else np.zeros(0, dtype=Q.dtype)
        mu = s.dot(z) / nineq
        pri_resid = np.linalg.norm(ry) + np.linalg.norm(rz)
        dual_resid = np.linalg.norm(rx)
        resid = pri_resid + dual_resid + nineq * mu
        d = z / s
        if verbose:
            print(("primal_res = {0:.5g}, dual_res = {1:.5g}, " +
                   "gap = {2:.5g}, kappa(d) = {3:.5g}").format(
                pri_resid, dual_resid, mu, np.min(d) / np.max(d)))
        # Stop once the residual grows by more than 1e-6 or is negligible.
        improved = (prev_resid is None) or (resid < prev_resid + 1e-6)
        if not improved or resid < 1e-6:
            return x, y, z
        prev_resid = resid

        # Predictor: affine-scaling direction.
        factor_kkt(U_S, R, d)
        dx_aff, ds_aff, dz_aff, dy_aff = \
            solve_kkt(U_Q, d, G, A, U_S, rx, rs, rz, ry)

        # Corrector: centering parameter from the affine step's gap reduction.
        alpha = min(get_step(z, dz_aff), get_step(s, ds_aff), 1.0)
        sig = (np.dot(s + alpha * ds_aff, z + alpha * dz_aff)
               / np.dot(s, z)) ** 3
        dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
            U_Q, d, G, A, U_S,
            np.zeros(nz, dtype=Q.dtype),
            (-mu * sig * np.ones(nineq, dtype=Q.dtype) + ds_aff * dz_aff) / s,
            np.zeros(nineq, dtype=Q.dtype),
            np.zeros(neq, dtype=Q.dtype))

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        # Back off 0.1% from the boundary so s and z stay strictly positive.
        alpha = min(1.0, 0.999 * min(get_step(s, ds), get_step(z, dz)))
        dx_norm = np.linalg.norm(dx)
        dz_norm = np.linalg.norm(dz)
        # Overflow / NaN guard on both directions: keep the last good iterate.
        if (not np.isfinite(dx_norm) or not np.isfinite(dz_norm)
                or dx_norm > 1e5 or dz_norm > 1e5):
            return x, y, z

        x += alpha * dx
        s += alpha * ds
        z += alpha * dz
        y = y + alpha * dy if neq > 0 else None
    return x, y, z
def get_step(v, dv):
    """Return the step length t at which ``v + t * dv`` first hits zero.

    Only coordinates where ``dv`` is strictly negative can reach zero.  For
    them the crossing point is ``-v[i] / dv[i]``, and the smallest such value
    is returned.  For positive ``v`` (the interior-point iterates s and z)
    this is the largest step that keeps ``v + t * dv >= 0``.  When no
    coordinate decreases, the integer 1 is returned.
    """
    shrinking = dv < 0
    if not np.any(shrinking):
        return 1
    return np.min(-v[shrinking] / dv[shrinking])
def solve_kkt(U_Q, d, G, A, U_S, rx, rs, rz, ry, dbg=False):
    """Solve the scaled KKT system for one step direction.

    With D = diag(d), the returned direction satisfies

        Q dx           + G^T dz + A^T dy = -rx
               D ds  +      dz           = -rs
        G dx +   ds                      = -rz
        A dx                             = -ry      (only if A is given)

    The system is solved by block elimination.  dx is eliminated through the
    Cholesky factor of Q, and the reduced system in w = [dy; dz] uses the
    factor ``U_S`` of

        S = [A; G] Q^{-1} [A; G]^T + blkdiag(0, D^{-1}).

    Args:
        U_Q: upper Cholesky factor of Q (first element of cho_factor's output).
        d: scaling vector defining D.
        G, A: constraint matrices; A may be None (no equality constraints).
        U_S: upper Cholesky factor of S; expected to have been built for
            this same ``d`` (see factor_kkt).
        rx, rs, rz, ry: right-hand-side residuals; ``ry`` is only read when
            there are equality constraints.
        dbg: unused; kept for interface compatibility.

    Returns:
        (dx, ds, dz, dy), with dy None when there are no equality constraints.
    """
    neq = get_sizes(G, A)[2]
    with_eq = neq > 0

    def apply_inv_q(vec):
        # Q^{-1} vec through the cached upper Cholesky factor of Q.
        return scipy.linalg.cho_solve((U_Q, False), vec)

    q_inv_rx = apply_inv_q(rx)
    if with_eq:
        rhs = np.concatenate(
            (A.dot(q_inv_rx) - ry, G.dot(q_inv_rx) + rs / d - rz), 0)
    else:
        rhs = G.dot(q_inv_rx) + rs / d - rz

    # Reduced-system solution; its leading neq entries are dy, the rest dz.
    w = -scipy.linalg.cho_solve((U_S, False), rhs)
    dy_part, dz_part = w[:neq], w[neq:]

    grad = -rx - G.T.dot(dz_part)
    if with_eq:
        grad -= A.T.dot(dy_part)

    dx = apply_inv_q(grad)
    ds = (-rs - dz_part) / d
    return dx, ds, dz_part, (dy_part if with_eq else None)
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    The KKT solves reduce to the matrix

        S = [ A Q^{-1} A^T    A Q^{-1} G^T          ]
            [ G Q^{-1} A^T    G Q^{-1} G^T + D^{-1} ]

    whose upper Cholesky factor has the block form

        U_S = [ U11  U12 ]      U11^T U11 = A Q^{-1} A^T
              [  0   U22 ]      U11^T U12 = A Q^{-1} G^T
                                U22^T U22 = R + D^{-1}

    with R = G Q^{-1} G^T - U12^T U12.  Everything except U22 is independent
    of D and is computed here.  U22 is written later by factor_kkt once D is
    known.

    Args:
        Q: square matrix; must be positive definite (it is Cholesky-factored).
        G: inequality-constraint matrix with nineq rows.
        A: equality-constraint matrix with neq rows, or None.

    Returns:
        U_Q: factor of Q from scipy.linalg.cho_factor (upper triangle valid,
            other entries unspecified).
        U_S: (neq + nineq) square array holding U11 and U12; the trailing
            nineq x nineq block is left zero for factor_kkt.
        R: G Q^{-1} G^T - U12^T U12 (just G Q^{-1} G^T when neq == 0).

    Raises:
        numpy.linalg.LinAlgError: if Q or A Q^{-1} A^T is not positive
            definite.
    """
    # Same sizing rule as get_sizes, derived locally so this function does
    # not depend on it.
    nineq = G.shape[0]
    neq = 0 if A is None else A.shape[0]

    U_Q, _ = scipy.linalg.cho_factor(Q)
    U_S = np.zeros((neq + nineq, neq + nineq), dtype=Q.dtype)
    R = G @ scipy.linalg.cho_solve((U_Q, False), G.T)
    if neq > 0:
        invQ_AT = scipy.linalg.cho_solve((U_Q, False), A.T)
        A_invQ_AT = A @ invQ_AT
        G_invQ_AT = G @ invQ_AT
        # cho_factor leaves unspecified data in the unused lower triangle, so
        # a general solve against its transpose (the old code) was wrong for
        # neq >= 2.  cholesky() zeroes that triangle.
        U11 = scipy.linalg.cholesky(A_invQ_AT, lower=False)
        # Forward substitution for U11^T U12 = (G Q^{-1} A^T)^T.
        U12 = scipy.linalg.solve_triangular(
            U11, G_invQ_AT.T, trans='T', lower=False)
        U_S[:neq, :neq] = U11
        U_S[:neq, neq:] = U12
        R = R - U12.T @ U12
    return U_Q, U_S, R
def factor_kkt(U_S, R, d):
    """Store the Cholesky factor of ``R + diag(1/d)`` in the trailing block of ``U_S``.

    This is the U22 block of the partial factorization started by
    pre_factor_kkt.  It depends on the scaling vector ``d``, so it has to be
    recomputed whenever ``d`` changes.  ``U_S`` is updated in place and
    nothing is returned.  As with scipy.linalg.cho_factor, only the upper
    triangle of the written block is meaningful.
    """
    size = R.shape[0]
    schur = R + np.diag(1.0 / d)
    factor = scipy.linalg.cho_factor(schur)[0]
    U_S[-size:, -size:] = factor
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment