-
-
Save feiaa/1cf89fae7c15e3bf4191fc232fef5694 to your computer and use it in GitHub Desktop.
Projection onto the simplex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Implements three algorithms for projecting a vector onto the simplex: sort, pivot and bisection. | |
For details and references, see the following paper: | |
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex | |
Mathieu Blondel, Akinori Fujino, and Naonori Ueda. | |
ICPR 2014. | |
http://www.mblondel.org/publications/mblondel-icpr2014.pdf | |
""" | |
import numpy as np | |
def projection_simplex_sort(v, z=1):
    """Project ``v`` onto the simplex {w : w >= 0, sum(w) = z} by sorting.

    The projection has the form ``max(v - theta, 0)``. The threshold
    ``theta`` is found by sorting ``v`` in decreasing order and locating the
    last position ``rho`` whose entry is still above the running threshold.
    The cost is dominated by the sort.

    Parameters
    ----------
    v : array-like, shape (n_features,)
        Vector to project (anything accepted by ``np.asarray``).
    z : float, default 1
        Radius of the simplex; must be strictly positive.

    Returns
    -------
    w : ndarray, shape (n_features,)
        Euclidean projection of ``v`` onto the simplex.

    Raises
    ------
    ValueError
        If ``z`` is not strictly positive or ``v`` is empty.
    """
    v = np.asarray(v)
    if z <= 0:
        raise ValueError("z must be strictly positive, got %r" % (z,))
    n_features = v.shape[0]
    if n_features == 0:
        raise ValueError("v must contain at least one element")
    u = np.sort(v)[::-1]
    # cssv[i] = (sum of the i+1 largest entries) - z
    cssv = np.cumsum(u) - z
    # Float 1-based positions keep ``cssv / ind`` a true division even for
    # integer input under Python 2.
    ind = np.arange(1, n_features + 1, dtype=float)
    cond = u - cssv / ind > 0
    # For z > 0, cond[0] is always True (it reduces to z > 0), so at least
    # one position qualifies; rho is the last one.
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / rho
    return np.maximum(v - theta, 0)
def projection_simplex_pivot(v, z=1, random_state=None):
    """Project ``v`` onto the simplex {w : w >= 0, sum(w) = z} by pivoting.

    Quickselect-style: a random pivot ``v[k]`` splits the undecided indices
    ``U`` into those with values >= the pivot (``G``) and those below it
    (``L``). The pivot test decides which side contains the threshold, and
    the search continues on that side only. Throughout, ``s`` and ``rho``
    are the sum and count of entries already known to be in the support.

    Parameters
    ----------
    v : array-like, shape (n_features,)
        Vector to project (anything accepted by ``np.asarray``).
    z : float, default 1
        Radius of the simplex; must be strictly positive.
    random_state : None or int, optional
        Seed passed to ``np.random.RandomState`` to choose the pivots.

    Returns
    -------
    w : ndarray, shape (n_features,)
        Euclidean projection of ``v`` onto the simplex.

    Raises
    ------
    ValueError
        If ``z`` is not strictly positive or ``v`` is empty.
    """
    v = np.asarray(v)
    if z <= 0:
        raise ValueError("z must be strictly positive, got %r" % (z,))
    n_features = len(v)
    if n_features == 0:
        raise ValueError("v must contain at least one element")
    rs = np.random.RandomState(random_state)
    U = np.arange(n_features)  # indices whose membership is still undecided
    s = 0    # sum of v over indices known to be in the support
    rho = 0  # number of indices known to be in the support
    while len(U) > 0:
        G = []  # indices j != k in U with v[j] >= v[k]
        L = []  # indices j in U with v[j] < v[k]
        k = U[rs.randint(0, len(U))]
        vk = v[k]
        ds = vk
        for j in U:
            if v[j] >= vk:
                if j != k:
                    ds += v[j]
                    G.append(j)
            elif v[j] < vk:
                L.append(j)
        drho = len(G) + 1
        # s + ds - (rho + drho) * vk is the sum of (v_i - vk) over all
        # entries >= vk. If it is below z, the threshold lies below vk, so
        # the pivot and every entry in G belong to the support.
        if s + ds - (rho + drho) * vk < z:
            s += ds
            rho += drho
            U = L
        else:
            U = G
    # For z > 0, at least the largest entry is in the support, so rho >= 1.
    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):
    """Approximately project ``v`` onto the simplex by bisection on theta.

    The projection has the form ``max(v - theta, 0)``. The sum of that
    expression is non-increasing in ``theta``, so the threshold at which it
    equals ``z`` is located by bisecting a bracket ``[lower, upper]``.

    Parameters
    ----------
    v : array-like, shape (n_features,)
        Vector to project (anything accepted by ``np.asarray``).
    z : float, default 1
        Radius of the simplex; must be strictly positive.
    tau : float, default 0.0001
        Relative tolerance: iteration stops once ``0 < z - sum(w) < tau * z``.
    max_iter : int, default 1000
        Maximum number of bisection steps; must be at least 1.

    Returns
    -------
    w : ndarray, shape (n_features,)
        The iterate from the last bisection step. When the tolerance is met,
        ``sum(w)`` is slightly below ``z``.

    Raises
    ------
    ValueError
        If ``z`` is not strictly positive, ``max_iter < 1`` or ``v`` is empty.
    """
    v = np.asarray(v)
    if z <= 0:
        raise ValueError("z must be strictly positive, got %r" % (z,))
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1, got %r" % (max_iter,))
    if v.size == 0:
        raise ValueError("v must contain at least one element")
    # At theta = max(v) every entry is clipped to 0, so the sum (0) is < z.
    upper = np.max(v)
    if np.sum(np.maximum(v, 0)) >= z:
        # The threshold is non-negative, so 0 is a valid lower bound.
        lower = 0
    else:
        # The threshold is negative (small or negative entries, or a large z).
        # At theta = min(v) - z/n every entry exceeds theta by at least z/n,
        # so the clipped sum is at least z and the bracket is valid.
        lower = np.min(v) - float(z) / v.size
    current = np.inf  # sum(w) - z for the latest iterate
    for _ in range(max_iter):
        # Stop only when the latest iterate is slightly under z and within
        # the relative tolerance.
        if np.abs(current) / z < tau and current < 0:
            break
        theta = (upper + lower) / 2.0
        w = np.maximum(v - theta, 0)
        current = np.sum(w) - z
        if current <= 0:
            upper = theta
        else:
            lower = theta
    return w
if __name__ == '__main__':
    # Demo: project a random vector onto the simplex whose radius is half
    # the vector's total mass, using each method, and print the sums of the
    # results for comparison against z. print(...) with a single argument
    # behaves the same on Python 2 and Python 3.
    rs = np.random.RandomState(0)
    v = rs.rand(1000)
    z = np.sum(v) * 0.5
    print(z)
    for projection in (projection_simplex_sort,
                       projection_simplex_pivot,
                       projection_simplex_bisection):
        w = projection(v, z)
        print(np.sum(w))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment