
# mblondel/projection_simplex.py

Last active May 10, 2023 04:29
Projection onto the simplex
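For reference, the problem all three functions below solve, and the closed form that the sort-based variant evaluates, can be written as follows. The notation is chosen here to mirror the code: `u` is `v` sorted in decreasing order, and `rho` and `theta` are the quantities computed in `projection_simplex_sort`.

```latex
w^\star = \operatorname*{arg\,min}_{w \in \mathbb{R}^n} \tfrac{1}{2}\lVert w - v \rVert_2^2
\quad \text{s.t.} \quad \sum_{i=1}^{n} w_i = z,\; w_i \ge 0 \qquad (z > 0)

w^\star_i = \max(v_i - \theta, 0), \qquad \text{where } \theta \text{ solves } \sum_{i=1}^{n} \max(v_i - \theta, 0) = z

\rho = \max\Big\{ j \in \{1,\dots,n\} : u_j - \tfrac{1}{j}\Big(\textstyle\sum_{r=1}^{j} u_r - z\Big) > 0 \Big\},
\qquad
\theta = \tfrac{1}{\rho}\Big(\textstyle\sum_{r=1}^{\rho} u_r - z\Big)
```

For the example used in the script, v = (1.1, 0.2, 0.2) and z = 2, all three conditions hold, so rho = 3 and theta = (1.5 - 2)/3 = -1/6. This gives w* = (19/15, 11/30, 11/30), approximately (1.267, 0.367, 0.367), which sums to 2.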
```python
"""
License: BSD

Author: Mathieu Blondel

Implements three algorithms for projecting a vector onto the simplex:
sort, pivot and bisection.

For details and references, see the following paper:

Large-scale Multiclass Support Vector Machine Training via
Euclidean Projection onto the Simplex
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
ICPR 2014.
http://www.mblondel.org/publications/mblondel-icpr2014.pdf
"""

import numpy as np


def projection_simplex_sort(v, z=1):
    # Sort-based method: find theta from the sorted values and their cumulative sums.
    n_features = v.shape[0]
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(n_features) + 1
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    w = np.maximum(v - theta, 0)
    return w


def projection_simplex_pivot(v, z=1, random_state=None):
    # Pivot-based method: quickselect-style randomized search for the support of w.
    rs = np.random.RandomState(random_state)
    n_features = len(v)
    U = np.arange(n_features)
    s = 0
    rho = 0
    while len(U) > 0:
        G = []
        L = []
        k = U[rs.randint(0, len(U))]
        ds = v[k]
        for j in U:
            if v[j] >= v[k]:
                if j != k:
                    ds += v[j]
                    G.append(j)
            elif v[j] < v[k]:
                L.append(j)
        drho = len(G) + 1
        if s + ds - (rho + drho) * v[k] < z:
            # The pivot and every entry at least as large are in the support.
            s += ds
            rho += drho
            U = L
        else:
            U = G
    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)


def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):
    # Bisection on theta: func is non-increasing, func(lower) >= 0 and func(upper) = -z < 0.
    func = lambda x: np.sum(np.maximum(v - x, 0)) - z
    lower = np.min(v) - z / len(v)
    upper = np.max(v)

    for it in range(max_iter):
        midpoint = (upper + lower) / 2.0
        value = func(midpoint)

        if abs(value) <= tau:
            break

        if value <= 0:
            upper = midpoint
        else:
            lower = midpoint

    return np.maximum(v - midpoint, 0)


if __name__ == '__main__':
    v = np.array([1.1, 0.2, 0.2])
    z = 2
    w = projection_simplex_sort(v, z)
    print(np.sum(w))
    w = projection_simplex_pivot(v, z)
    print(np.sum(w))
    w = projection_simplex_bisection(v, z)
    print(np.sum(w))
```
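Printing `np.sum(w)` only checks one constraint. A stricter sanity check is to test the optimality conditions of the projection directly: `w` must be nonnegative and sum to `z`, and there must be a single threshold theta with `w_i = v_i - theta` on the support and `v_i <= theta` off it. The helper `is_simplex_projection` is a hypothetical name introduced for this sketch; to stay self-contained it uses the hand-computed theta = -1/6 for the script's example instead of calling the gist's functions.

```python
import numpy as np


def is_simplex_projection(v, w, z=1, atol=1e-8):
    """Hypothetical helper: check that w is the Euclidean projection of v
    onto {w : w >= 0, sum(w) = z} via the optimality conditions."""
    v = np.asarray(v, dtype=float)
    w = np.asarray(w, dtype=float)
    # Feasibility: nonnegative entries that sum to z.
    if np.any(w < -atol) or abs(w.sum() - z) > atol:
        return False
    support = w > atol
    if not np.any(support):
        return False
    # On the support, w_i = v_i - theta for one shared threshold theta.
    thetas = v[support] - w[support]
    theta = thetas.mean()
    if np.max(np.abs(thetas - theta)) > atol:
        return False
    # Off the support, v_i <= theta (otherwise w_i would have to be positive).
    return bool(np.all(v[~support] <= theta + atol))


v = np.array([1.1, 0.2, 0.2])
z = 2
theta = -1.0 / 6.0  # (sum(v) - z) / 3, since all three coordinates stay positive
w = np.maximum(v - theta, 0)

print(is_simplex_projection(v, w, z))  # True
print(is_simplex_projection(v, v, z))  # False: sum(v) is 1.5, not 2
```

The same check can be applied to the outputs of `projection_simplex_sort`, `projection_simplex_pivot` and `projection_simplex_bisection`. The bisection result only satisfies the sum constraint up to roughly `tau`, so a looser `atol` would be needed for it.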

### joseortiz3 commented Apr 25, 2021

At least in Python 3.8 with NumPy 1.20.2, after replacing `xrange` with `range` and adding parentheses to the `print` statements for Python 3 compatibility, I find the following. For

```
v = np.array([1+0.1,0+0.2,0+0.2])
z = 2
```

the script prints

```
2
2.0
2.0
1.5
```

Is this an issue with `projection_simplex_bisection`, or just a pathological example?

### mblondel commented May 15, 2021

Thanks @joseortiz3. I fixed the bug (there was a mistake in the bracketing interval).
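The reported 1.5 is exactly `sum(v)`, which is what bisection returns if theta is confined to `[0, max(v)]` while the true threshold is negative. The sketch below assumes that this was the earlier bracket (the pre-fix code is not shown on this page) and contrasts it with the bracket in the current version. `bisection_residual` is a name introduced for illustration. The new lower end works because at `lower = min(v) - z/n` every term `v_i - lower` is at least `z/n`, so the residual is nonnegative, while at `upper = max(v)` it equals `-z`.

```python
import numpy as np


def bisection_residual(v, z, theta):
    # f(theta) = sum_i max(v_i - theta, 0) - z, a non-increasing function of theta.
    return np.sum(np.maximum(v - theta, 0)) - z


v = np.array([1.1, 0.2, 0.2])
z = 2

# sum(v) = 1.5 < z, so f(0) < 0 and the root theta* lies below 0:
# an interval [0, max(v)] cannot contain it.
print(bisection_residual(v, z, 0.0) < 0)  # True

# Bracket from the fixed version: f(lower) >= 0 and f(upper) = -z < 0.
lower = np.min(v) - z / len(v)
upper = np.max(v)

for _ in range(100):
    midpoint = (upper + lower) / 2.0
    if bisection_residual(v, z, midpoint) <= 0:
        upper = midpoint  # root is at or below midpoint
    else:
        lower = midpoint  # root is above midpoint

theta = (upper + lower) / 2.0
w = np.maximum(v - theta, 0)
print(theta, w.sum())  # theta is close to -1/6 and w sums to (about) 2
```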

### flbbb commented Nov 29, 2021

In case someone needs to project each column of a 2D array onto the simplex, here is an adapted code sample (Python 3.9.5):

```python
import numpy as np


def projection_simplex_sort_2d(v, z=1):
    """v array of shape (n_features, n_samples)."""
    p, n = v.shape
    # Sort each column in decreasing order.
    u = np.sort(v, axis=0)[::-1, ...]
    pi = np.cumsum(u, axis=0) - z
    ind = (np.arange(p) + 1).reshape(-1, 1)
    mask = (u - pi / ind) > 0
    # 0-based index of the last True entry in each column.
    rho = p - 1 - np.argmax(mask[::-1, ...], axis=0)
    theta = pi[tuple([rho, np.arange(n)])] / (rho + 1)
    # Broadcasting subtracts theta[j] from every entry of column j.
    w = np.maximum(v - theta, 0)
    return w
```
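One way to check the column-wise function is to compare it with the 1D sort-based projection applied to each column separately. The sketch below repeats both functions so it runs on its own: the 1D one with `v.shape[0]`, and the 2D one indexed as `pi[rho, np.arange(n)]`, which selects the same entries as the `tuple([...])` form. The random test data is arbitrary.

```python
import numpy as np


def projection_simplex_sort(v, z=1):
    # 1D sort-based projection, as in the gist.
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(v.shape[0]) + 1
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    return np.maximum(v - theta, 0)


def projection_simplex_sort_2d(v, z=1):
    # Column-wise variant; v has shape (n_features, n_samples).
    p, n = v.shape
    u = np.sort(v, axis=0)[::-1, ...]
    pi = np.cumsum(u, axis=0) - z
    ind = (np.arange(p) + 1).reshape(-1, 1)
    mask = (u - pi / ind) > 0
    rho = p - 1 - np.argmax(mask[::-1, ...], axis=0)
    theta = pi[rho, np.arange(n)] / (rho + 1)
    return np.maximum(v - theta, 0)


rng = np.random.RandomState(0)
V = rng.randn(5, 4)  # 4 columns, each a vector of 5 features

W = projection_simplex_sort_2d(V, z=1)
W_ref = np.column_stack([projection_simplex_sort(V[:, j], z=1) for j in range(V.shape[1])])
print(np.allclose(W, W_ref))           # True
print(np.allclose(W.sum(axis=0), 1))   # True

# To project rows instead of columns, transpose on the way in and out.
X = rng.randn(3, 5)
print(np.allclose(projection_simplex_sort_2d(X.T).T.sum(axis=1), 1))  # True
```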

### mblondel commented Nov 29, 2021

Hehe, I also wrote a vectorized version here https://gist.github.com/mblondel/c99e575a5207c76a99d714e8c6e08e89