Last active
September 12, 2024 21:44
-
-
Save mblondel/6f3b7aaad90606b98f71 to your computer and use it in GitHub Desktop.
Projection onto the simplex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
License: BSD
Author: Mathieu Blondel
Implements three algorithms for projecting a vector onto the simplex: sort, pivot and bisection.
For details and references, see the following paper:
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
ICPR 2014.
http://www.mblondel.org/publications/mblondel-icpr2014.pdf
""" | |
import numpy as np | |
def projection_simplex_sort(v, z=1):
    """Project ``v`` onto the simplex {w : w >= 0, sum(w) = z} by sorting.

    Runs in O(n log n) because of the sort.

    Parameters
    ----------
    v : array-like, shape (n_features,)
        Vector to project. Lists and other array-likes are accepted and are
        converted with ``np.asarray``.
    z : float, default=1
        Radius (total mass) of the simplex. Must be positive: for z <= 0 no
        index satisfies the selection condition and IndexError is raised.

    Returns
    -------
    w : ndarray, shape (n_features,)
        The Euclidean projection of ``v``.
    """
    v = np.asarray(v)
    n_features = v.shape[0]
    u = np.sort(v)[::-1]  # entries in descending order
    cssv = np.cumsum(u) - z
    ind = np.arange(n_features) + 1
    # rho is the largest (1-based) index whose sorted entry stays positive
    # after subtracting the running threshold cssv / ind.
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_pivot(v, z=1, random_state=None):
    """Project ``v`` onto the simplex {w : w >= 0, sum(w) = z} by random pivoting.

    This is a quickselect-style search for the threshold ``theta``. At each
    step a random pivot splits the remaining candidates into those greater
    than or equal to the pivot and those strictly below it.

    Parameters
    ----------
    v : array-like, shape (n_features,)
        Vector to project. It is converted with ``np.asarray``.
    z : float, default=1
        Radius (total mass) of the simplex. Assumed positive.
    random_state : int, array-like or None, default=None
        Seed passed to ``np.random.RandomState`` to choose pivots.

    Returns
    -------
    w : ndarray, shape (n_features,)
        The Euclidean projection of ``v``.
    """
    v = np.asarray(v)
    rs = np.random.RandomState(random_state)
    candidates = np.arange(v.shape[0])
    s = 0    # running sum of entries known to lie above the threshold
    rho = 0  # running count of those entries
    while candidates.size > 0:
        k = candidates[rs.randint(0, candidates.size)]
        pivot = v[k]
        values = v[candidates]
        # The pivot itself is excluded from `greater`; it is counted through
        # the "+ 1" in drho and the initial `pivot` term of ds.
        greater = candidates[(values >= pivot) & (candidates != k)]
        lesser = candidates[values < pivot]
        ds = pivot + v[greater].sum()
        drho = greater.size + 1
        if s + ds - (rho + drho) * pivot < z:
            # The pivot and everything above it belong to the support.
            s += ds
            rho += drho
            candidates = lesser
        else:
            candidates = greater
    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):
    """Project ``v`` onto the simplex {w : w >= 0, sum(w) = z} by bisection on theta.

    The result is ``max(v - theta, 0)``. ``theta`` is found by bisecting the
    non-increasing function ``f(theta) = sum(max(v - theta, 0)) - z``.

    Parameters
    ----------
    v : array-like, shape (n_features,)
        Vector to project. It is converted with ``np.asarray``.
    z : float, default=1
        Radius (total mass) of the simplex. Assumed positive.
    tau : float, default=0.0001
        Stop once ``|sum(w) - z| <= tau``.
    max_iter : int, default=1000
        Maximum number of bisection steps. With ``max_iter <= 0`` the
        midpoint of the initial bracket is used.

    Returns
    -------
    w : ndarray, shape (n_features,)
        Approximate Euclidean projection of ``v``.
    """
    v = np.asarray(v)
    # Bracket: f(lower) >= 0 because every entry of v - lower is >= z / n,
    # and f(upper) = -z <= 0 because v - max(v) <= 0.
    lower = np.min(v) - z / len(v)
    upper = np.max(v)
    # Initialise so that the return below is well defined even if the loop
    # body never runs.
    midpoint = (upper + lower) / 2.0
    for _ in range(max_iter):
        midpoint = (upper + lower) / 2.0
        value = np.sum(np.maximum(v - midpoint, 0)) - z
        if abs(value) <= tau:
            break
        if value <= 0:
            upper = midpoint
        else:
            lower = midpoint
    return np.maximum(v - midpoint, 0)
if __name__ == '__main__':
    v = np.array([1.1, 0.2, 0.2])
    z = 2
    # Print sum(w) for each solver; each projection should sum to z.
    solvers = (
        projection_simplex_sort,
        projection_simplex_pivot,
        projection_simplex_bisection,
    )
    for solve in solvers:
        print(np.sum(solve(v, z)))
Thanks @joseortiz3. I fixed the bug (there was a mistake in the bracketing interval).
In case someone needs to project each column of a 2D array onto the simplex, here is an adapted code sample (Python 3.9.5):
def projection_simplex_sort_2d(v, z=1):
    """Project every column of ``v`` onto the simplex {w : w >= 0, sum(w) = z}.

    ``v`` is an array of shape (n_features, n_samples); each column is
    projected independently with the sort-based algorithm.
    """
    n_rows, n_cols = v.shape
    desc = np.sort(v, axis=0)[::-1]  # each column sorted in descending order
    shifted_cumsum = np.cumsum(desc, axis=0) - z
    counts = np.arange(1, n_rows + 1)[:, np.newaxis]
    positive = desc - shifted_cumsum / counts > 0
    # Row index of the last True entry in each column: argmax on the
    # row-reversed mask finds the first True from the bottom.
    last = n_rows - 1 - np.argmax(positive[::-1], axis=0)
    theta = shifted_cumsum[last, np.arange(n_cols)] / (last + 1)
    return np.maximum(v - theta, 0)
Hehe, I also wrote a vectorized version here https://gist.github.com/mblondel/c99e575a5207c76a99d714e8c6e08e89
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
At least in Python 3.8 with NumPy 1.20.2, after replacing `xrange` with `range` and putting parentheses around `print()` for Python 3 compatibility, I find the following: for [example input lost in extraction] the script prints [output lost in extraction].
Is this an issue with `projection_simplex_bisection`, or just a pathological example?