# Projection onto the simplex
"""
Implements three algorithms for projecting a vector onto the simplex:
sort, pivot and bisection.

For details and references, see the following paper:

    Large-scale Multiclass Support Vector Machine Training via Euclidean
    Projection onto the Simplex.
    Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
    ICPR 2014.
    http://www.mblondel.org/publications/mblondel-icpr2014.pdf
"""
import numpy as np | |
def projection_simplex_sort(v, z=1):
    """Project ``v`` onto the simplex ``{w : w >= 0, sum(w) == z}`` by sorting.

    Parameters
    ----------
    v : array-like, 1-D
        Vector to project. Lists and tuples are accepted and converted
        with ``np.asarray``.
    z : number, default 1
        Required sum of the projected vector; must be strictly positive.

    Returns
    -------
    numpy.ndarray
        ``np.maximum(v - theta, 0)``, where ``theta`` is the threshold
        derived from the sorted cumulative sums.

    Raises
    ------
    ValueError
        If ``v`` is empty or ``z`` is not strictly positive.
    """
    v = np.asarray(v)
    if v.size == 0:
        raise ValueError("v must contain at least one element")
    if z <= 0:
        raise ValueError("z must be strictly positive")

    n_features = v.shape[0]
    # Entries in decreasing order and their running sums, shifted by z.
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(n_features) + 1
    # For finite input the first entry always satisfies the condition,
    # because u[0] - (u[0] - z) / 1 == z > 0, so the selection is non-empty.
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_pivot(v, z=1, random_state=None):
    """Project ``v`` onto the simplex ``{w : w >= 0, sum(w) == z}`` by pivoting.

    A random pivot repeatedly splits the remaining candidate indices into
    those with a value at least as large as the pivot and those strictly
    smaller, accumulating the sum ``s`` and count ``rho`` of the entries
    that define the threshold.

    Parameters
    ----------
    v : array-like, 1-D
        Vector to project. Lists and tuples are accepted and converted
        with ``np.asarray``.
    z : number, default 1
        Required sum of the projected vector; must be strictly positive.
    random_state : optional
        Seed passed to ``np.random.RandomState`` for pivot selection.

    Returns
    -------
    numpy.ndarray
        ``np.maximum(v - theta, 0)`` with ``theta = (s - z) / rho``.

    Raises
    ------
    ValueError
        If ``v`` is empty or ``z`` is not strictly positive.
    """
    v = np.asarray(v)
    if v.size == 0:
        raise ValueError("v must contain at least one element")
    if z <= 0:
        raise ValueError("z must be strictly positive")

    rs = np.random.RandomState(random_state)
    candidates = np.arange(v.shape[0])
    s = 0
    rho = 0
    while candidates.size > 0:
        k = candidates[rs.randint(0, candidates.size)]
        pivot = v[k]
        values = v[candidates]
        # Indices (other than the pivot) at least as large as the pivot,
        # and indices strictly smaller than it.
        greater = candidates[(values >= pivot) & (candidates != k)]
        smaller = candidates[values < pivot]

        ds = pivot + v[greater].sum()
        drho = greater.size + 1
        if s + ds - (rho + drho) * pivot < z:
            # The pivot and everything at least as large are accepted;
            # keep searching among the smaller values.
            s += ds
            rho += drho
            candidates = smaller
        else:
            # The pivot is rejected; only larger values can still be accepted.
            candidates = greater

    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):
    """Project ``v`` onto the simplex ``{w : w >= 0, sum(w) == z}`` by bisection.

    Bisects on the threshold ``theta`` until ``sum(max(v - theta, 0))`` is
    below ``z`` by a relative amount smaller than ``tau``, or until
    ``max_iter`` iterations have run.

    Parameters
    ----------
    v : array-like, 1-D
        Vector to project. Lists and tuples are accepted and converted
        with ``np.asarray``.
    z : number, default 1
        Required sum of the projected vector; must be strictly positive.
    tau : float, default 0.0001
        Relative tolerance on ``|sum(w) - z| / z``.
    max_iter : int, default 1000
        Maximum number of bisection steps; must be at least 1.

    Returns
    -------
    numpy.ndarray
        ``np.maximum(v - theta, 0)`` for the last threshold evaluated. When
        the tolerance test stops the loop, ``sum(w)`` is slightly below ``z``.

    Raises
    ------
    ValueError
        If ``v`` is empty, ``z`` is not strictly positive, or
        ``max_iter`` is smaller than 1.
    """
    v = np.asarray(v)
    if v.size == 0:
        raise ValueError("v must contain at least one element")
    if z <= 0:
        raise ValueError("z must be strictly positive")
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    # Bracket the threshold. At theta = min(v) - z / n every entry of
    # v - theta is at least z / n, so the sum is >= z; at theta = max(v)
    # the sum is 0 < z. A fixed lower bound of 0 would miss the negative
    # thresholds needed whenever sum(v) < z.
    lower = np.min(v) - z / float(v.size)
    upper = np.max(v)

    for _ in range(max_iter):
        theta = (upper + lower) / 2.0
        w = np.maximum(v - theta, 0)
        current = np.sum(w) - z
        if current <= 0:
            upper = theta
        else:
            lower = theta
        # Stop only on the feasible side (sum(w) < z) within tolerance.
        if current < 0 and np.abs(current) / z < tau:
            break
    return w
if __name__ == '__main__':
    # Demo: project a random vector onto the simplex of size z with each
    # algorithm and print the resulting sums, which should all be close to z.
    # Single-argument print(...) calls behave the same on Python 2 and 3.
    rs = np.random.RandomState(0)
    v = rs.rand(1000)
    z = np.sum(v) * 0.5
    print(z)
    w = projection_simplex_sort(v, z)
    print(np.sum(w))
    w = projection_simplex_pivot(v, z)
    print(np.sum(w))
    w = projection_simplex_bisection(v, z)
    print(np.sum(w))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment