@tonyduan
Last active March 12, 2023 11:05
Compute Euclidean projections onto the L1-ball in PyTorch.
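The function handles each example as one row of the flattened batch, with d entries. For each row it computes the closed-form projection of Duchi et al. [1]. The restatement below is our reading of the code, not part of the gist. The variables mu, cumsum, rho and theta correspond to μ, S, ρ and θ:

```latex
\begin{aligned}
u^\star &= \arg\min_{\|u\|_1 \le \epsilon} \|x - u\|_2, \\
\mu_1 \ge \dots \ge \mu_d &\text{ are the sorted } |x_i|, \qquad S_j = \sum_{r=1}^{j} \mu_r, \\
\rho &= \max\{\, j \in \{1, \dots, d\} : j\,\mu_j > S_j - \epsilon \,\}, \qquad \theta = \frac{S_\rho - \epsilon}{\rho}, \\
u^\star_i &= \operatorname{sign}(x_i)\,\max(|x_i| - \theta,\, 0) \quad \text{if } \|x\|_1 > \epsilon, \qquad u^\star = x \text{ otherwise.}
\end{aligned}
```

For example, x = (3, 1) with ε = 2 gives ρ = 1, θ = 1 and u* = (2, 0).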
```python
import torch


def project_onto_l1_ball(x, eps):
    """
    Compute Euclidean projection onto the L1 ball for a batch.

      min ||x - u||_2 s.t. ||u||_1 <= eps

    Inspired by the corresponding numpy version by Adrien Gaidon.

    Parameters
    ----------
    x: (batch_size, *) torch.Tensor
      batch of arbitrary-size tensors to project, possibly on GPU

    eps: float > 0
      radius of the L1 ball to project onto

    Returns
    -------
    u: (batch_size, *) torch.Tensor
      batch of projected tensors, reshaped to match the original

    Notes
    -----
    The complexity of this algorithm is O(d log d) per example, where d is the
    number of elements in each example, because it sorts |x|.

    References
    ----------
    [1] Efficient Projections onto the l1-Ball for Learning in High Dimensions
        John Duchi, Shai Shalev-Shwartz, Yoram Singer, and Tushar Chandra.
        International Conference on Machine Learning (ICML 2008)
    """
    original_shape = x.shape
    # Flatten each example into a row; reshape (unlike view) also accepts non-contiguous input
    x = x.reshape(x.shape[0], -1)
    # Rows already inside the ball are their own projection
    inside = x.abs().sum(dim=1) <= eps
    # Sort |x| in descending order and take running sums S_j
    mu, _ = torch.sort(torch.abs(x), dim=1, descending=True)
    cumsum = torch.cumsum(mu, dim=1)
    arange = torch.arange(1, x.shape[1] + 1, device=x.device)
    # rho: largest 1-based index j with j * mu_j > S_j - eps (always >= 1 when eps > 0)
    rho, _ = torch.max((mu * arange > (cumsum - eps)) * arange, dim=1)
    # theta = (S_rho - eps) / rho, gathered on x's device (no .cpu() round trip)
    theta = (cumsum.gather(1, (rho - 1).unsqueeze(1)).squeeze(1) - eps) / rho
    # Soft-threshold |x| by theta and restore the signs
    proj = (torch.abs(x) - theta.unsqueeze(1)).clamp(min=0) * torch.sign(x)
    # torch.where keeps x's dtype (a float32 mask would upcast half-precision input)
    x = torch.where(inside.unsqueeze(1), x, proj)
    return x.reshape(original_shape)
```
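You can sanity-check the sort-based routine by solving for θ independently. The sketch below is not part of the gist. It uses a different technique, bisection on θ. It relies on a standard fact: for ||x||_1 > eps, the projection is sign(x) * max(|x| - θ, 0), where θ is the threshold that makes the result's L1 norm exactly eps. The name `project_onto_l1_ball_bisect` and the `n_iter` parameter are hypothetical.

```python
import torch


def project_onto_l1_ball_bisect(x, eps, n_iter=100):
    """Hypothetical sort-free cross-check: find the threshold theta by bisection.

    For rows with ||x||_1 > eps, theta >= 0 solves sum(max(|x| - theta, 0)) = eps.
    """
    original_shape = x.shape
    x = x.reshape(x.shape[0], -1)
    a = x.abs()
    # sum(max(a - theta, 0)) is non-increasing in theta: it equals ||x||_1 > eps
    # at theta = 0 (for rows outside the ball) and 0 at theta = max(a).
    lo = torch.zeros(x.shape[0], 1, dtype=x.dtype, device=x.device)
    hi = a.max(dim=1, keepdim=True).values
    for _ in range(n_iter):
        mid = (lo + hi) / 2
        # If the thresholded L1 norm still exceeds eps, theta must be larger
        too_big = (a - mid).clamp(min=0).sum(dim=1, keepdim=True) > eps
        lo = torch.where(too_big, mid, lo)
        hi = torch.where(too_big, hi, mid)
    theta = (lo + hi) / 2
    proj = (a - theta).clamp(min=0) * torch.sign(x)
    # Rows already inside the ball are returned unchanged
    inside = a.sum(dim=1, keepdim=True) <= eps
    return torch.where(inside, x, proj).reshape(original_shape)


torch.manual_seed(0)
x = torch.randn(4, 3, 5, dtype=torch.float64)
u = project_onto_l1_ball_bisect(x, eps=1.0)
print(u.reshape(4, -1).abs().sum(dim=1))  # each entry approximately 1.0
```

If the gist's function is defined in the same session, `torch.allclose(project_onto_l1_ball(x, 1.0), project_onto_l1_ball_bisect(x, 1.0))` should return True up to the bisection tolerance. Bisection costs O(n_iter · d) per example rather than O(d log d), so it is meant as a check, not a replacement.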