# mblondel/projection_simplex.py

Last active May 15, 2021
Projection onto the simplex
```python
"""
Implements three algorithms for projecting a vector onto the simplex:
sort, pivot and bisection.

For details and references, see the following paper:

Large-scale Multiclass Support Vector Machine Training via Euclidean
Projection onto the Simplex
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
ICPR 2014.
http://www.mblondel.org/publications/mblondel-icpr2014.pdf
"""

import numpy as np


def projection_simplex_sort(v, z=1):
    n_features = v.shape[0]
    # Sort in decreasing order and find the largest prefix that stays positive.
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(n_features) + 1
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    w = np.maximum(v - theta, 0)
    return w


def projection_simplex_pivot(v, z=1, random_state=None):
    rs = np.random.RandomState(random_state)
    n_features = len(v)
    U = np.arange(n_features)
    s = 0
    rho = 0
    while len(U) > 0:
        G = []
        L = []
        # Pick a random pivot and split U into entries >= pivot (G) and < pivot (L).
        k = U[rs.randint(0, len(U))]
        ds = v[k]
        for j in U:
            if v[j] >= v[k]:
                if j != k:
                    ds += v[j]
                    G.append(j)
            elif v[j] < v[k]:
                L.append(j)
        drho = len(G) + 1
        if s + ds - (rho + drho) * v[k] < z:
            s += ds
            rho += drho
            U = L
        else:
            U = G
    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)


def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):
    func = lambda x: np.sum(np.maximum(v - x, 0)) - z
    # Bracketing interval: func(lower) >= 0 and func(upper) < 0.
    lower = np.min(v) - z / len(v)
    upper = np.max(v)

    for it in range(max_iter):
        midpoint = (upper + lower) / 2.0
        value = func(midpoint)

        if abs(value) <= tau:
            break

        if value <= 0:
            upper = midpoint
        else:
            lower = midpoint

    return np.maximum(v - midpoint, 0)


if __name__ == '__main__':
    v = np.array([1.1, 0.2, 0.2])
    z = 2
    w = projection_simplex_sort(v, z)
    print(np.sum(w))
    w = projection_simplex_pivot(v, z)
    print(np.sum(w))
    w = projection_simplex_bisection(v, z)
    print(np.sum(w))
```
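As a quick sanity check that is not part of the original gist, the sketch below repeats the sort-based routine so that it runs on its own, then verifies on a few inputs that the result is non-negative, sums to `z`, and has the thresholded form `max(v - theta, 0)`. The helper name `check_projection` and the tolerance `atol` are assumptions chosen for illustration.

```python
import numpy as np


def projection_simplex_sort(v, z=1):
    # Same sort-based routine as in the gist, repeated so this block runs alone.
    n_features = v.shape[0]
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(n_features) + 1
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    return np.maximum(v - theta, 0)


def check_projection(v, w, z, atol=1e-8):
    # Hypothetical helper: tests feasibility and the thresholding structure.
    nonneg = bool(np.all(w >= 0))
    sums_to_z = bool(np.isclose(np.sum(w), z, atol=atol))
    # On the support, v - w must equal one shared threshold theta;
    # off the support, v must not exceed that threshold.
    support = w > 0
    theta = np.mean(v[support] - w[support])
    same_shift = bool(np.allclose(v[support] - w[support], theta, atol=atol))
    below = bool(np.all(v[~support] <= theta + atol))
    return nonneg and sums_to_z and same_shift and below


rng = np.random.RandomState(0)
cases = [
    (rng.randn(20), 1),
    (rng.randn(50) * 10, 2),
    (np.array([1.1, 0.2, 0.2]), 2),
]
results = [check_projection(v, projection_simplex_sort(v, z), z) for v, z in cases]
print(all(results))
```

The same helper could be applied to the output of the pivot and bisection variants, with a looser `atol` for bisection since it stops once the constraint is met to within `tau`.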

### joseortiz3 commented Apr 25, 2021 • edited

At least in Python 3.8, numpy version 1.20.2, after replacing `xrange` -> `range` and putting parentheses around `print()` for Py3 compatibility, I find this: for

```
v = np.array([1+0.1,0+0.2,0+0.2])
z = 2
```

the script prints

```
2
2.0
2.0
1.5
```

Is this an issue with `projection_simplex_bisection`, or just a pathological example?

### mblondel commented May 15, 2021

 Thanks @joseortiz3. I fixed the bug (there was a mistake in the bracketing interval).
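To illustrate why the bracketing interval matters, here is a minimal sketch (not the gist's code) that evaluates the bisection objective at the endpoints `min(v) - z / len(v)` and `max(v)` for the example from this thread. The name `objective` is an assumption. The thread does not show the earlier interval, so the sketch only checks that these endpoints enclose a sign change and that the root for this input is negative.

```python
import numpy as np

v = np.array([1.1, 0.2, 0.2])
z = 2


def objective(theta):
    # Same quantity the bisection routine drives to zero.
    return np.sum(np.maximum(v - theta, 0)) - z


lower = np.min(v) - z / len(v)
upper = np.max(v)

# Bisection needs opposite signs at the two endpoints.
print(objective(lower) >= 0, objective(upper) < 0)

# All three entries stay positive here, so theta solves sum(v) - 3 * theta = z.
theta = (np.sum(v) - z) / len(v)
print(theta < 0, np.isclose(objective(theta), 0.0))
```

Because `theta` is negative for this input, an interval whose lower end is 0 or higher cannot contain it. With `theta = 0` the output would sum to `np.sum(v) = 1.5`, which matches the value reported above, although whether that was the exact cause is an assumption.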