Skip to content

Instantly share code, notes, and snippets.

@msakai
Created February 10, 2019 17:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save msakai/9a5c31882a81ec83f60e046d2db4f1db to your computer and use it in GitHub Desktop.
Save msakai/9a5c31882a81ec83f60e046d2db4f1db to your computer and use it in GitHub Desktop.
# numpy/scipy translation of https://github.com/locuslab/qpth/blob/5485219028a7687b76107c8431625aaedfd7bc36/qpth/solvers/pdipm/single.py
import numpy as np
import scipy
import scipy.linalg
# TODO: Add more comments describing the math here.
# https://stanford.edu/~boyd/papers/pdf/code_gen_impl.pdf
def get_sizes(G, A=None):
    """Return the problem dimensions ``(nineq, nz, neq, 1)``.

    Args:
        G: inequality-constraint matrix; its rows count the inequalities
            (``nineq``) and its columns the variables (``nz``).
        A: equality-constraint matrix, or ``None`` when there are no
            equality constraints.

    Returns:
        ``(nineq, nz, neq, 1)``. ``neq`` is ``A.shape[0]``, or 0 when ``A``
        is ``None``. The trailing 1 is a constant that the callers in this
        file discard.
    """
    # The original assigned to a misspelled name (`new`), leaving `neq`
    # unbound whenever A was supplied.
    neq = 0 if A is None else A.shape[0]
    return (G.shape[0], G.shape[1], neq, 1)
def forward(inputs_i, Q, G, A, b, h, U_Q, U_S, R, verbose=False):
    """Solve one QP with a primal-dual interior-point method.

    Judging from the residuals computed in the loop below, the problem is

        minimize    1/2 x^T Q x + inputs_i^T x
        subject to  G x + s = h,  s >= 0,  A x = b

    where ``z`` multiplies the inequality rows (``G``) and ``y`` the
    equality rows (``A``) in the dual residual. Each iteration takes a
    Mehrotra-style predictor-corrector step: an affine-scaling direction,
    then a centering/second-order correction direction. At most 20
    iterations are run.

    Args:
        inputs_i: linear objective term (it enters the dual residual ``rx``).
        Q, G, A, b, h: problem data. ``A`` (and ``b``) may be ``None`` when
            there are no equality constraints.
        U_Q, U_S, R: cached factorizations, obtained as
            ``U_Q, U_S, R = pre_factor_kkt(Q, G, A)``. The trailing
            ``nineq x nineq`` block of ``U_S`` is overwritten in place by
            ``factor_kkt`` during the solve.
        verbose: print primal/dual residuals and the gap every iteration.

    Returns:
        ``(x, y, z)``. ``y`` is ``None`` when there are no equality
        constraints.
    """
    nineq, nz, neq, _ = get_sizes(G, A)

    # find initial values
    # One KKT solve with D = I (d = 1) and right-hand side (inputs_i, 0, -h, -b)
    # gives the starting x, s, z, y.
    d = np.ones(nineq, dtype=Q.dtype)
    nb = -b if b is not None else None
    factor_kkt(U_S, R, d)
    x, s, z, y = solve_kkt(
        U_Q, d, G, A, U_S,
        inputs_i, np.zeros(nineq, dtype=Q.dtype), -h, nb)

    # If any entry is negative, shift the whole vector so its minimum becomes 1.
    # NOTE(review): an entry that is exactly 0 is not shifted, and
    # `d = z / s` below would then divide by zero.
    if np.min(s) < 0:
        s -= np.min(s) - 1
    if np.min(z) < 0:
        z -= np.min(z) - 1

    prev_resid = None
    for i in range(20):
        # affine scaling direction
        # KKT residuals at the current iterate:
        #   rx: dual residual       A^T y + G^T z + Q x + inputs_i
        #   rs: complementarity     z (with d = z / s, the step satisfies
        #                           z*ds + s*dz = -s*z)
        #   rz: inequality residual G x + s - h
        #   ry: equality residual   A x - b (empty when neq == 0)
        rx = (A.T.dot(y) if neq > 0 else 0.) + G.T.dot(z) + Q.dot(x) + inputs_i
        rs = z
        rz = G.dot(x) + s - h
        ry = A.dot(x) - b if neq > 0 else np.zeros(0, dtype=Q.dtype)
        # Average complementarity gap s^T z / nineq.
        mu = s.dot(z) / nineq
        pri_resid = np.linalg.norm(ry) + np.linalg.norm(rz)
        dual_resid = np.linalg.norm(rx)
        resid = pri_resid + dual_resid + nineq * mu
        # Diagonal scaling D = diag(z / s) used in this iteration's KKT system.
        d = z / s
        if verbose:
            # "kappa(d)" is printed as min(d) / max(d), i.e. the reciprocal of
            # the condition number of diag(d).
            print(("primal_res = {0:.5g}, dual_res = {1:.5g}, " +
                   "gap = {2:.5g}, kappa(d) = {3:.5g}").format(
                pri_resid, dual_resid, mu, min(d) / max(d)))
        # if (pri_resid < 5e-4 and dual_resid < 5e-4 and mu < 4e-4):
        # Stop when the combined residual stops decreasing (allowing 1e-6
        # slack) or falls below 1e-6.
        # NOTE(review): on the "not improved" exit this returns the current
        # iterate, whose residual is larger than the previous one's.
        improved = (prev_resid is None) or (resid < prev_resid + 1e-6)
        if not improved or resid < 1e-6:
            return x, y, z
        prev_resid = resid

        # Predictor: refactor with the new d and solve for the affine direction.
        factor_kkt(U_S, R, d)
        dx_aff, ds_aff, dz_aff, dy_aff = \
            solve_kkt(U_Q, d, G, A, U_S, rx, rs, rz, ry)

        # compute centering directions
        # Step length for the affine direction, capped at 1.
        alpha = min(min(get_step(z, dz_aff), get_step(s, ds_aff)), 1.0)
        # Centering parameter: cube of the ratio of the gap after the affine
        # step to the current gap.
        sig = (np.dot(s + alpha * ds_aff, z +
                      alpha * dz_aff) / (np.dot(s, z)))**3
        # Corrector: only the s-row right-hand side is nonzero. It targets
        # z*ds + s*dz = sig*mu - ds_aff*dz_aff (centering plus second-order
        # correction).
        dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
            U_Q, d, G, A, U_S,
            np.zeros(nz, dtype=Q.dtype),
            (-mu * sig * np.ones(nineq, dtype=Q.dtype) + ds_aff * dz_aff) / s,
            np.zeros(nineq, dtype=Q.dtype),
            np.zeros(neq, dtype=Q.dtype))

        # Combined search direction.
        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        # Step 0.1% short of the boundary so that s and z stay positive.
        alpha = min(1.0, 0.999 * min(get_step(s, ds), get_step(z, dz)))
        dx_norm = np.linalg.norm(dx)
        dz_norm = np.linalg.norm(dz)
        if np.isnan(dx_norm) or dx_norm > 1e5 or dz_norm > 1e5:
            # Overflow, return early
            return x, y, z

        # x, s, z are updated in place; y is rebound (None without equalities).
        x += alpha * dx
        s += alpha * ds
        z += alpha * dz
        y = y + alpha * dy if neq > 0 else None

    return x, y, z
def get_step(v, dv):
    """Ratio test for a step along direction ``dv`` starting from ``v``.

    Considers only the entries where ``dv`` is negative. Returns the
    largest ``t`` such that ``v + t * dv >= 0`` holds on all of them, which
    is ``min(-v / dv)`` over those entries. Returns the integer 1 when no
    entry of ``dv`` is negative. The result is not capped at 1; callers
    clip it themselves.
    """
    shrinking = dv < 0
    if not np.any(shrinking):
        return 1
    ratios = -v[shrinking] / dv[shrinking]
    return np.min(ratios)
def solve_kkt(U_Q, d, G, A, U_S, rx, rs, rz, ry, dbg=False):
    """Solve the KKT equations for one search direction.

    The returned ``(dx, ds, dz, dy)`` satisfy

        Q dx + G^T dz + A^T dy = -rx
        d * ds + dz            = -rs      (elementwise)
        G dx + ds              = -rz
        A dx                   = -ry

    provided ``U_S`` is the upper Cholesky factor of the Schur complement

        S = [ A Q^{-1} A^T   A Q^{-1} G^T              ]
            [ G Q^{-1} A^T   G Q^{-1} G^T + diag(1/d)  ]

    built for this same ``d``. ``dx`` and ``ds`` are eliminated, then
    ``S [dy; dz] = -rhs`` is solved and the result back-substituted.

    Args:
        U_Q: factor of Q in ``cho_factor`` upper format.
        d: diagonal scaling, applied elementwise.
        G, A: constraint matrices; ``A`` may be ``None``.
        U_S: factor of S in ``cho_factor`` upper format.
        rx, rs, rz, ry: right-hand-side residuals; ``ry`` is ignored when
            there are no equality constraints.
        dbg: unused.

    Returns:
        ``(dx, ds, dz, dy)``. ``dy`` is ``None`` when there are no equality
        constraints. ``dz`` and ``dy`` are views into one solution array.
    """
    n_eq = get_sizes(G, A)[2]
    has_eq = n_eq > 0

    q_inv_rx = scipy.linalg.cho_solve((U_Q, False), rx)
    ineq_rhs = G.dot(q_inv_rx) + rs / d - rz
    if has_eq:
        rhs = np.concatenate((A.dot(q_inv_rx) - ry, ineq_rhs), 0)
    else:
        rhs = ineq_rhs

    sol = -scipy.linalg.cho_solve((U_S, False), rhs)
    sol_eq = sol[:n_eq]
    sol_ineq = sol[n_eq:]

    # Back-substitute: Q dx = -rx - G^T dz - A^T dy.
    x_rhs = -rx - G.T.dot(sol_ineq)
    if has_eq:
        x_rhs -= A.T.dot(sol_eq)
    dx = scipy.linalg.cho_solve((U_Q, False), x_rhs)
    ds = (-rs - sol_ineq) / d
    dy = sol_eq if has_eq else None
    return dx, ds, sol_ineq, dy
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    The Schur complement of the KKT system is

        S = [ A Q^{-1} A^T    A Q^{-1} G^T           ]
            [ G Q^{-1} A^T    G Q^{-1} G^T + D^{-1}  ]

    with upper Cholesky factor

        U_S = [ U11  U12 ]      U_S^T U_S = S.
              [  0   U22 ]

    Only ``D`` changes between interior-point iterations. U11 and U12 are
    computed here, and ``factor_kkt`` later fills in U22, which satisfies
    ``U22^T U22 = R + D^{-1}``.

    Args:
        Q: square, symmetric positive-definite matrix (nz x nz).
        G: inequality matrix (nineq x nz).
        A: equality matrix (neq x nz), or ``None`` for no equalities.

    Returns:
        U_Q: Cholesky factor of Q in ``scipy.linalg.cho_factor`` upper
            format. Use it as ``cho_solve((U_Q, False), ...)``.
        U_S: ``(neq + nineq)`` square array holding U11 and U12. The
            trailing ``nineq x nineq`` block is left as zeros.
        R: ``G Q^{-1} G^T - U12^T U12``.

    Raises:
        numpy.linalg.LinAlgError: if Q or ``A Q^{-1} A^T`` is not
            positive definite.
    """
    nineq = G.shape[0]
    # Same convention as get_sizes: no equality block when A is None.
    neq = 0 if A is None else A.shape[0]

    (U_Q, _) = scipy.linalg.cho_factor(Q)
    U_S = np.zeros((neq + nineq, neq + nineq), dtype=Q.dtype)
    # R starts as G Q^{-1} G^T and is reduced by U12^T U12 below.
    R = G @ scipy.linalg.cho_solve((U_Q, False), G.T)
    if neq > 0:
        invQ_AT = scipy.linalg.cho_solve((U_Q, False), A.T)
        A_invQ_AT = A @ invQ_AT
        G_invQ_AT = G @ invQ_AT
        # cho_factor leaves arbitrary data in the unused lower triangle, which
        # would corrupt the transposed solve below. cholesky returns a clean
        # upper-triangular factor.
        U11 = scipy.linalg.cholesky(A_invQ_AT, lower=False)
        # U11^T U12 = A Q^{-1} G^T (= G_invQ_AT.T, since Q is symmetric).
        U12 = scipy.linalg.solve_triangular(
            U11, G_invQ_AT.T, trans='T', lower=False)
        U_S[:neq, :neq] = U11
        U_S[:neq, neq:] = U12
        R -= U12.T @ U12
    return U_Q, U_S, R
def factor_kkt(U_S, R, d):
    """Fill in the U22 block of the KKT factor once the scaling ``d`` is known.

    Factors ``R + diag(1 / d)`` with ``scipy.linalg.cho_factor`` (upper
    form) and writes the result, in place, into the trailing
    ``nineq x nineq`` block of ``U_S``, where ``nineq`` is ``R.shape[0]``.
    As with ``cho_factor`` output, the strictly-lower part of the written
    block is not zeroed. ``cho_solve((U_S, False), ...)`` reads only the
    upper triangle.

    Returns:
        None. ``U_S`` is modified in place.
    """
    size = R.shape[0]
    lower_right = R + np.diag(1.0 / d)
    U_S[-size:, -size:] = scipy.linalg.cho_factor(lower_right)[0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment