Skip to content

Instantly share code, notes, and snippets.

@nmayorov
Created February 1, 2018 10:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nmayorov/69a6794ffe8be42903349d83232408c8 to your computer and use it in GitHub Desktop.
Canonical constraint implementation for scipy.optimize
class CanonicalConstraint(object):
    """Constraint in canonical form ``eq(x) == 0`` and ``ineq(x) <= 0``.

    Parameters
    ----------
    n_eq, n_ineq : int
        Number of equality and inequality components.
    fun : callable
        ``fun(x) -> (eq, ineq)``, two 1-D arrays.
    jac : callable
        ``jac(x) -> (J_eq, J_ineq)``, one row per component (dense ndarray
        or scipy.sparse matrix, following the user Jacobian).
    hess : callable
        ``hess(x, v_eq, v_ineq)`` -> object with a ``.dot`` method.
        Assuming the user function's ``hess(x, v)`` is the Hessian of
        ``v @ fun(x)`` (not verified here), this is the Hessian of
        ``v_eq @ eq + v_ineq @ ineq``.
    keep_feasible : ndarray of bool
        One flag per component: equality components first, then inequality
        components, i.e. aligned with ``np.hstack(fun(x))``.
    """

    def __init__(self, n_eq, n_ineq, fun, jac, hess, keep_feasible):
        self.n_eq = n_eq
        self.n_ineq = n_ineq
        self.fun = fun
        self.jac = jac
        self.hess = hess
        self.keep_feasible = keep_feasible

    @classmethod
    def from_user_constraint(cls, cfun, bounds, keep_feasible):
        """Build a canonical constraint from ``lb <= cfun.fun(x) <= ub``.

        Parameters
        ----------
        cfun : object
            Must provide ``fun(x)``, ``jac(x)`` and ``hess(x, v)``; only these
            three attributes are used.
        bounds : tuple (lb, ub)
            Array-likes with one entry per component of ``cfun.fun(x)``;
            ``-np.inf`` / ``np.inf`` mark a missing side.
        keep_feasible : bool or array_like of bool
            Per-bound flag; a scalar applies to every bound.
        """
        lb, ub = bounds
        # Float arrays: accept lists and avoid integer arithmetic downstream
        # (the Lagrange-multiplier vector in ``hess`` must be float).
        lb = np.asarray(lb, dtype=float)
        ub = np.asarray(ub, dtype=float)
        n_bounds = lb.shape[0]

        keep_feasible = np.asarray(keep_feasible, dtype=bool)
        if keep_feasible.ndim == 0:
            keep_feasible = np.full(n_bounds, keep_feasible.item())

        if np.all(lb == ub):
            fun, jac, hess = cls._equal_to_canonical(cfun, lb)
            n_eq = n_bounds
            n_ineq = 0
        elif np.all(lb == -np.inf):
            fun, jac, hess = cls._less_to_canonical(cfun, ub)
            n_eq = 0
            n_ineq = n_bounds
        elif np.all(ub == np.inf):
            fun, jac, hess = cls._greater_to_canonical(cfun, lb)
            n_eq = 0
            n_ineq = n_bounds
        else:
            (fun, jac, hess, keep_feasible, n_eq, n_ineq) = \
                cls._interval_to_canonical(cfun, lb, ub, keep_feasible)
        return cls(n_eq, n_ineq, fun, jac, hess, keep_feasible)

    @classmethod
    def concatenate(cls, canonical_constraints, sparse_jacobian):
        """Stack several canonical constraints into one.

        The equality parts of all constraints come first (in input order),
        followed by all inequality parts. This layout is used consistently by
        ``fun``, ``jac``, ``hess`` and ``keep_feasible``.

        Parameters
        ----------
        canonical_constraints : iterable of CanonicalConstraint
        sparse_jacobian : bool
            Stack Jacobians with ``scipy.sparse.vstack`` if True, otherwise
            with ``np.vstack``.
        """
        from scipy.sparse.linalg import LinearOperator

        # Materialize once: the sequence is iterated several times below and
        # again inside every closure call.
        canonical_constraints = list(canonical_constraints)

        def fun(x):
            parts = [c.fun(x) for c in canonical_constraints]
            return (np.hstack([eq for eq, _ in parts]),
                    np.hstack([ineq for _, ineq in parts]))

        vstack = sps.vstack if sparse_jacobian else np.vstack

        def jac(x):
            parts = [c.jac(x) for c in canonical_constraints]
            return (vstack([eq for eq, _ in parts]),
                    vstack([ineq for _, ineq in parts]))

        def hess(x, v_eq, v_ineq):
            hess_all = []
            index_eq = 0
            index_ineq = 0
            for c in canonical_constraints:
                vc_eq = v_eq[index_eq:index_eq + c.n_eq]
                vc_ineq = v_ineq[index_ineq:index_ineq + c.n_ineq]
                hess_all.append(c.hess(x, vc_eq, vc_ineq))
                index_eq += c.n_eq
                index_ineq += c.n_ineq

            # Return an operator (like every other ``hess`` here), whose
            # product with ``p`` is the sum of the individual products.
            def matvec(p):
                result = np.zeros_like(p, dtype=float)
                for h in hess_all:
                    result += h.dot(p)
                return result

            n = x.shape[0]
            return LinearOperator((n, n), matvec=matvec, dtype=float)

        n_eq = sum(c.n_eq for c in canonical_constraints)
        n_ineq = sum(c.n_ineq for c in canonical_constraints)

        # Each constraint stores flags as [its eq flags, its ineq flags];
        # regroup them into [all eq flags, all ineq flags] to match ``fun``.
        # Lists, not generators: np.hstack rejects generators (NumPy >= 1.24).
        flags = [np.broadcast_to(c.keep_feasible, (c.n_eq + c.n_ineq,))
                 for c in canonical_constraints]
        keep_feasible = np.hstack(
            [f[:c.n_eq] for f, c in zip(flags, canonical_constraints)] +
            [f[c.n_eq:] for f, c in zip(flags, canonical_constraints)])
        return cls(n_eq, n_ineq, fun, jac, hess, keep_feasible)

    @staticmethod
    def _empty_jac(J):
        """Return a 0-row Jacobian block with ``J``'s column count and kind.

        A plain ``np.empty(0)`` has shape ``(0,)``, which ``np.vstack``
        promotes to ``(1, 0)`` and then cannot stack with ``(m, n)`` blocks.
        Assumes ``J`` is 2-D.
        """
        n = J.shape[1]
        if sps.issparse(J):
            return sps.csr_matrix((0, n))
        return np.empty((0, n))

    @staticmethod
    def _equal_to_canonical(cfun, value):
        """All components are equalities ``cfun.fun(x) - value == 0``."""
        empty = np.zeros(0)

        def fun(x):
            return cfun.fun(x) - value, empty

        def jac(x):
            J = cfun.jac(x)
            return J, CanonicalConstraint._empty_jac(J)

        def hess(x, v_eq, v_ineq):
            return cfun.hess(x, v_eq)

        return fun, jac, hess

    @staticmethod
    def _less_to_canonical(cfun, ub):
        """All components are ``cfun.fun(x) - ub <= 0``."""
        empty = np.empty(0)

        def fun(x):
            return empty, cfun.fun(x) - ub

        def jac(x):
            J = cfun.jac(x)
            return CanonicalConstraint._empty_jac(J), J

        def hess(x, v_eq, v_ineq):
            return cfun.hess(x, v_ineq)

        return fun, jac, hess

    @staticmethod
    def _greater_to_canonical(cfun, lb):
        """All components are ``lb - cfun.fun(x) <= 0``."""
        empty = np.empty(0)

        def fun(x):
            return empty, lb - cfun.fun(x)

        def jac(x):
            J = cfun.jac(x)
            return CanonicalConstraint._empty_jac(J), -J

        def hess(x, v_eq, v_ineq):
            return -cfun.hess(x, v_ineq)

        return fun, jac, hess

    @staticmethod
    def _interval_to_canonical(cfun, lb, ub, keep_feasible):
        """Split mixed bounds into equality and inequality components.

        Inequalities are laid out as ``(le, ge, il, ig)``:

        - ``f - ub`` for upper-only bounds;
        - ``lb - f`` for lower-only bounds;
        - then both sides of each two-sided interval.

        Components with ``lb == -inf`` and ``ub == inf`` impose nothing and
        are dropped.

        Returns
        -------
        (fun, jac, hess, keep_feasible, n_eq, n_ineq)
        """
        is_equal = lb == ub
        lb_inf = lb == -np.inf
        ub_inf = ub == np.inf
        # Disjoint partition of the components.
        equal = np.nonzero(is_equal)[0]
        less = np.nonzero(lb_inf & ~ub_inf & ~is_equal)[0]
        greater = np.nonzero(ub_inf & ~lb_inf & ~is_equal)[0]
        interval = np.nonzero(~is_equal & ~lb_inf & ~ub_inf)[0]

        n_equal = equal.shape[0]
        n_less = less.shape[0]
        n_greater = greater.shape[0]
        n_interval = interval.shape[0]

        keep_feasible = np.hstack((keep_feasible[equal],
                                   keep_feasible[less],
                                   keep_feasible[greater],
                                   keep_feasible[interval],
                                   keep_feasible[interval]))

        def fun(x):
            f = cfun.fun(x)
            eq = f[equal] - lb[equal]
            le = f[less] - ub[less]
            ge = lb[greater] - f[greater]
            il = f[interval] - ub[interval]
            ig = lb[interval] - f[interval]
            return eq, np.hstack((le, ge, il, ig))

        def jac(x):
            J = cfun.jac(x)
            eq = J[equal]
            le = J[less]
            ge = -J[greater]
            il = J[interval]
            ig = -il
            if sps.issparse(J):
                ineq = sps.vstack((le, ge, il, ig))
            else:
                ineq = np.vstack((le, ge, il, ig))
            return eq, ineq

        def hess(x, v_eq, v_ineq):
            # Explicit offsets into the (le, ge, il, ig) layout; negative
            # slices would break when n_interval == 0 (v[-0:] is everything).
            end_less = n_less
            end_greater = end_less + n_greater
            end_il = end_greater + n_interval
            v_l = v_ineq[:end_less]
            v_g = v_ineq[end_less:end_greater]
            v_il = v_ineq[end_greater:end_il]
            v_ig = v_ineq[end_il:end_il + n_interval]

            # Float multipliers even if lb has an integer dtype.
            v = np.zeros_like(lb, dtype=float)
            v[equal] = v_eq
            v[less] = v_l
            v[greater] = -v_g
            v[interval] = v_il - v_ig
            return cfun.hess(x, v)

        return (fun, jac, hess, keep_feasible, n_equal,
                n_less + n_greater + 2 * n_interval)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment