@ozansener
Created August 1, 2018 14:41
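import numpy as np


def projection2simplex(y):
    """
    Given y, solve argmin_z ||y - z||_2 s.t. sum_i z_i = 1 and z_i >= 0 for all i
    (the upper bound z_i <= 1 then holds automatically).
    """
    m = len(y)
    # Sort the entries in descending order.
    sorted_y = np.flip(np.sort(y), axis=0)
    tmpsum = 0.0
    # Fallback threshold, used when every entry stays positive after the shift.
    tmax_f = (np.sum(y) - 1.0) / m
    for i in range(m - 1):
        tmpsum += sorted_y[i]
        # Candidate threshold if only the i+1 largest entries stay positive.
        tmax = (tmpsum - 1) / (i + 1.0)
        if tmax > sorted_y[i + 1]:
            tmax_f = tmax
            break
    # Shift by the threshold and clip negatives to zero.
    return np.maximum(y - tmax_f, np.zeros(y.shape))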
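The same sort-and-threshold idea used by `projection2simplex` can also be written without the explicit loop. The sketch below is an alternative formulation, not the gist author's code: the name `project_to_simplex_vectorized` is hypothetical, and it assumes a 1-D, non-empty input. It finds the largest k for which the k-th largest entry still exceeds the candidate threshold, using a cumulative sum.

```python
import numpy as np


def project_to_simplex_vectorized(y):
    """Euclidean projection of a 1-D vector onto the probability simplex.

    Hypothetical helper for illustration; assumes y is 1-D and non-empty.
    """
    y = np.asarray(y, dtype=float)
    u = np.sort(y)[::-1]              # entries in descending order
    css = np.cumsum(u) - 1.0          # css[k-1] = (sum of k largest) - 1
    k = np.arange(1, y.size + 1)
    # Largest index where u_k > css_k / k; index 0 always qualifies,
    # so the selection below is never empty.
    rho = np.nonzero(u * k > css)[0][-1]
    theta = css[rho] / (rho + 1.0)    # shift that makes the result sum to 1
    return np.maximum(y - theta, 0.0)


if __name__ == "__main__":
    z = project_to_simplex_vectorized(np.array([2.0, 0.0]))
    print(z)          # all mass moves to the first coordinate: [1. 0.]
    print(z.sum())    # 1.0
```

A point already on the simplex is left unchanged, and the output always has non-negative entries that sum to 1 up to floating-point error.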