# Source: GitHub Gist EdwardRaff/f4f4cf0c927c2addfb39 by @EdwardRaff,
# forked from mblondel/projection_simplex.py.
# Last active: August 29, 2015.
"""
Implements four algorithms for projecting a vector onto the simplex: sort, pivot, bisection, and brent.
For details and references, see the following paper:
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
ICPR 2014.
http://www.mblondel.org/publications/mblondel-icpr2014.pdf
"""
import numpy as np
from scipy.optimize import brentq
def projection_simplex_sort(v, z=1):
    """Project ``v`` onto the simplex ``{w : w >= 0, sum(w) == z}`` by sorting.

    Parameters
    ----------
    v : array-like, 1-D
        Vector to project. Lists and integer arrays are accepted; the input
        is converted to a float ndarray first.
    z : number, default 1
        Value the entries of the result sum to.

    Returns
    -------
    numpy.ndarray
        ``np.maximum(v - theta, 0)`` where ``theta`` is the threshold derived
        from the sorted entries of ``v``.
    """
    v = np.asarray(v, dtype=float)
    n_features = v.shape[0]

    # Entries in decreasing order, and their running sums shifted by z.
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z

    # 1-based count of entries included in each running sum.
    ind = np.arange(n_features) + 1

    # An entry stays strictly positive after thresholding by the mean excess
    # of the entries up to and including it; the last such position gives rho.
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)

    return np.maximum(v - theta, 0)
def projection_simplex_pivot(v, z=1, random_state=None):
    """Project ``v`` onto the simplex ``{w : w >= 0, sum(w) == z}`` by pivoting.

    A pivot is drawn at random from the remaining candidate indices and the
    candidates are split around it, so the threshold is found without sorting
    the whole vector.

    Parameters
    ----------
    v : array-like, 1-D
        Vector to project. Lists and integer arrays are accepted; the input
        is converted to a float ndarray first.
    z : number, default 1
        Value the entries of the result sum to.
    random_state : optional
        Seed handed to ``np.random.RandomState``; it only affects which
        pivots are drawn.

    Returns
    -------
    numpy.ndarray
        ``np.maximum(v - theta, 0)`` with ``theta = (s - z) / rho``, where
        ``s`` and ``rho`` are the accumulated sum and count of the entries
        accepted during pivoting.
    """
    v = np.asarray(v, dtype=float)
    rs = np.random.RandomState(random_state)

    candidates = np.arange(v.shape[0])
    s = 0.0   # sum of the entries accepted so far
    rho = 0   # number of entries accepted so far

    while candidates.size > 0:
        k = candidates[rs.randint(0, candidates.size)]
        pivot = v[k]

        values = v[candidates]
        at_least = values >= pivot  # includes the pivot itself
        below = values < pivot

        ds = values[at_least].sum()
        drho = int(at_least.sum())

        if s + ds - (rho + drho) * pivot < z:
            # The pivot and everything >= it are accepted; keep searching
            # among the strictly smaller entries.
            s += ds
            rho += drho
            candidates = candidates[below]
        else:
            # The pivot is rejected; only the other entries that are >= it
            # remain candidates.
            candidates = candidates[at_least & (candidates != k)]

    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):
    """Project ``v`` onto the simplex ``{w : w >= 0, sum(w) == z}`` by bisection.

    Bisects on the threshold ``theta`` until
    ``sum(np.maximum(v - theta, 0)) - z`` is non-positive and smaller than
    ``tau * z`` in magnitude, or until ``max_iter`` iterations have run.
    The number of iterations performed is printed to stdout.

    Parameters
    ----------
    v : array-like, 1-D
        Vector to project. Lists and integer arrays are accepted; the input
        is converted to a float ndarray first.
    z : number, default 1
        Value the entries of the result should sum to.
    tau : float, default 0.0001
        Relative tolerance on the sum constraint.
    max_iter : int, default 1000
        Maximum number of bisection steps.

    Returns
    -------
    numpy.ndarray
        ``np.maximum(v - theta, 0)`` for the last ``theta`` evaluated.
    """
    v = np.asarray(v, dtype=float)

    upper = np.max(v)
    lower = 0
    if np.sum(np.maximum(v, 0)) < z:
        # theta = 0 already gives a sum below z, so the root is negative and
        # [0, max(v)] would not bracket it. At min(v) - 2*z/n every entry
        # exceeds the threshold by at least 2*z/n, so the sum is >= 2*z > z.
        lower = np.min(v) - 2.0 * z / v.shape[0]

    current = np.inf
    n_iter = 0
    for _ in range(max_iter):
        # Accept an exact hit (current == 0) as well; with a strict "< 0"
        # test an exact root would never terminate the loop early.
        if np.abs(current) / z < tau and current <= 0:
            break
        theta = (upper + lower) / 2.0
        w = np.maximum(v - theta, 0)
        current = np.sum(w) - z
        if current <= 0:
            upper = theta
        else:
            lower = theta
        n_iter += 1

    # A single pre-formatted argument prints the same text on Python 2 and 3.
    print("bisection took  %s" % n_iter)
    return w
def projection_simplex_brent(v, z=1, tau=1e-9):
    """Project ``v`` onto the simplex ``{w : w >= 0, sum(w) == z}`` with Brent's method.

    Uses ``scipy.optimize.brentq`` to find the threshold ``theta`` at which
    ``sum(np.maximum(v - theta, 0)) - z`` is zero. The iteration and
    function-call counts reported by ``brentq`` are printed to stdout.

    Parameters
    ----------
    v : array-like, 1-D
        Vector to project. Lists and integer arrays are accepted; the input
        is converted to a float ndarray first.
    z : number, default 1
        Value the entries of the result should sum to.
    tau : float, default 1e-9
        Passed to ``brentq`` as ``xtol``.

    Returns
    -------
    numpy.ndarray
        ``np.maximum(v - theta, 0)`` for the root ``theta`` found by ``brentq``.
    """
    v = np.asarray(v, dtype=float)

    def residual(theta):
        return np.sum(np.maximum(v - theta, 0.0)) - z

    upper = np.max(v)
    lower = 0
    if residual(lower) < 0:
        # The residual is already negative at theta = 0, so [0, max(v)] has
        # the same sign at both ends and brentq would raise ValueError.
        # At min(v) - 2*z/n every entry exceeds the threshold by at least
        # 2*z/n, so the residual there is >= z > 0 with slack for rounding.
        lower = np.min(v) - 2.0 * z / v.shape[0]

    x0, r = brentq(residual, lower, upper, xtol=tau, full_output=True)

    # A single pre-formatted argument prints the same text on Python 2 and 3.
    print("brent took  %s  iterations and  %s  function calls"
          % (r.iterations, r.function_calls))
    return np.maximum(v - x0, 0)
if __name__ == '__main__':
    # Demo: project one random vector with each algorithm and print the sum
    # of every result next to the target z so they can be compared by eye.
    rs = np.random.RandomState(0)
    v = rs.rand(1000)
    z = np.sum(v) * 0.5

    # Single-argument print(...) produces the same output on Python 2 and 3.
    print(z)
    for project in (projection_simplex_sort,
                    projection_simplex_pivot,
                    projection_simplex_bisection,
                    projection_simplex_brent):
        w = project(v, z)
        print(np.sum(w))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment