Skip to content

Instantly share code, notes, and snippets.

@mblondel
Last active April 12, 2023 15:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mblondel/32b1906efca9636068d37c663b16b84f to your computer and use it in GitHub Desktop.
Save mblondel/32b1906efca9636068d37c663b16b84f to your computer and use it in GitHub Desktop.
# Author: Mathieu Blondel
# License: BSD
import numpy as np
def find_r(w, k):
    """Find the index ``r`` that defines the k-support norm of ``w``.

    With ``beta[0] = +inf`` and ``beta[1:]`` the magnitudes ``|w|`` sorted in
    decreasing order (1-based indexing), ``r`` is the first value in
    ``0, ..., k-1`` such that::

        beta[k-r-1] > (1 / (r+1)) * sum_{i=k-r}^{D} beta[i] >= beta[k-r]

    Here ``D`` is the length of ``beta[1:]``. If no ``r < k-1`` satisfies
    this, ``r = k-1`` is returned. That value always satisfies the left
    inequality because ``beta[0]`` is ``+inf``.

    Parameters
    ----------
    w : array-like, shape (d,)
        Input vector.
    k : int
        Sparsity parameter, ``k >= 1``. If ``k > d``, the sorted magnitudes
        are padded with zeros up to length ``k``. Zero-padding does not
        change the norm, and with ``k`` equal to the padded length the
        k-support norm is the l2 norm.

    Returns
    -------
    r : int
        The selected index.
    tmp : float
        The tail sum ``sum_{i=k-r}^{D} beta[i]``.
    beta : ndarray, shape (max(d, k),)
        ``|w|`` sorted in decreasing order, zero-padded if ``k > d``
        (0-based, without the ``+inf`` sentinel).

    Raises
    ------
    ValueError
        If ``k < 1``.
    """
    if k < 1:
        raise ValueError("k must be a positive integer, got %r" % (k,))
    w = np.asarray(w)
    d = w.shape[0]
    # beta[0] is a +inf sentinel; beta[1:] holds the sorted magnitudes.
    # The zero padding only kicks in when k > d and keeps every index
    # used below inside the array.
    beta = np.r_[np.inf, np.sort(np.abs(w))[::-1], np.zeros(max(k - d, 0))]
    # Running tail sum for the current r: starts at sum_{i=k}^{D} beta[i]
    # (r = 0) and grows by one element per rejected r.
    tmp = np.sum(beta[k:])
    for r in range(k - 1):  # r = 0, ..., k-2; r = k-1 is the fallback
        mean = tmp / (r + 1)
        if beta[k - r - 1] > mean >= beta[k - r]:
            return r, tmp, beta[1:]
        tmp += beta[k - r - 1]
    return k - 1, tmp, beta[1:]
def k_support_norm(w, k, squared=False):
    """Compute the k-support norm of ``w``.

    The norm is defined in Argyriou, Foygel & Srebro, "Sparse Prediction
    with the k-Support Norm" (NIPS 2012). With ``r`` and the tail sum
    ``tmp`` obtained from ``find_r``::

        ||w||^2 = sum_{i=1}^{k-r-1} beta_i^2 + tmp^2 / (r+1)

    Here ``beta`` is ``|w|`` sorted in decreasing order.

    Parameters
    ----------
    w : array-like, shape (d,)
        Input vector.
    k : int
        Sparsity parameter, ``k >= 1``. Values larger than ``d`` are clamped
        to ``d``, where the k-support norm coincides with the l2 norm.
    squared : bool, default=False
        If True, return ``0.5 * ||w||^2`` instead of ``||w||``.

    Returns
    -------
    float
        The norm (or half its square). Returns 0.0 for an empty ``w``.

    Raises
    ------
    ValueError
        If ``k < 1``.
    """
    if k < 1:
        raise ValueError("k must be a positive integer, got %r" % (k,))
    w = np.asarray(w)
    d = w.shape[0]
    if d == 0:
        return 0.0
    # For k >= d the norm is the l2 norm; clamping keeps find_r in range.
    k = min(k, d)
    r, tmp, beta = find_r(w, k)
    # The k-r-1 largest entries contribute their squares. The remaining
    # tail is replaced by (r+1) copies of its mean.
    sqnorm = np.sum(beta[0:k - r - 1] ** 2)
    sqnorm += (tmp ** 2) / (r + 1)
    if squared:
        return 0.5 * sqnorm
    return np.sqrt(sqnorm)
# Naive implementation with two loops.
def k_support_norm2(w, k, squared=False):
    """Compute the k-support norm of ``w`` with a straightforward double loop.

    This computes the same quantity as ``k_support_norm``, but recomputes
    the tail sum from scratch for every candidate ``r``, for O(k * d) work.

    Parameters
    ----------
    w : array-like, shape (d,)
        Input vector.
    k : int
        Sparsity parameter, ``k >= 1``. Values larger than ``d`` are clamped
        to ``d``, where the k-support norm coincides with the l2 norm.
    squared : bool, default=False
        If True, return ``0.5 * ||w||^2`` instead of ``||w||``.

    Returns
    -------
    float
        The norm (or half its square). Returns 0.0 for an empty ``w``.

    Raises
    ------
    ValueError
        If ``k < 1``.
    """
    if k < 1:
        raise ValueError("k must be a positive integer, got %r" % (k,))
    w = np.asarray(w)
    d = len(w)
    if d == 0:
        return 0.0
    # Clamping keeps beta[k - r] (indices up to k <= d) within bounds.
    k = min(k, d)
    ind = np.argsort(np.abs(w))[::-1]
    # beta[0] = +inf sentinel, beta[1:] = |w| sorted in decreasing order,
    # so beta can be indexed 1-based like the formula.
    beta = np.r_[np.inf, np.abs(w[ind])]
    for r in range(0, k):  # from r = 0 to k-1
        # Tail sum beta[k-r] + ... + beta[d], recomputed naively.
        tmp = 0
        for i in range(k - r, d + 1):  # from i = k-r to d
            tmp += beta[i]
        if beta[k - r - 1] > tmp / (r + 1) >= beta[k - r]:
            break
    # If no r passed the test, the loop leaves r = k-1 with its tail sum.
    sqnorm = np.sum(beta[1:k - r] ** 2)
    sqnorm += (tmp ** 2) / (r + 1)
    if squared:
        return 0.5 * sqnorm
    return np.sqrt(sqnorm)
def dual_k_support_norm(w, k, squared=False):
    """Compute the l2 norm of the ``k`` largest-magnitude entries of ``w``.

    Parameters
    ----------
    w : ndarray, shape (d,)
        Input vector.
    k : int
        Number of largest-magnitude entries to keep.
    squared : bool, default=False
        If True, return half of the squared value instead of the value.

    Returns
    -------
    float
    """
    # Indices of w ordered by decreasing |w_i|, truncated to the first k.
    top_k = w[np.argsort(np.abs(w))[::-1][:k]]
    total = np.sum(top_k ** 2)
    return 0.5 * total if squared else np.sqrt(total)
if __name__ == '__main__':
    rng = np.random.RandomState(None)
    a = rng.randn(10)
    k = 3
    # Each norm is printed in this order for both test vectors.
    norm_fns = (k_support_norm, k_support_norm2, dual_k_support_norm)

    # Non-sparse vector
    for norm_fn in norm_fns:
        print(norm_fn(a, k))

    # k-sparse vector: zero out every entry except the k entries with the
    # smallest magnitude.
    order = np.argsort(np.abs(a))
    a[order[k:]] = 0
    for norm_fn in norm_fns:
        print(norm_fn(a, k))

    # Plot half the squared 2-support norm of (x, 0.5, -0.7) as x varies.
    import matplotlib.pyplot as plt
    xs = np.linspace(-2, 2, 100)
    plt.figure()
    values = [
        k_support_norm(np.array([x, 0.5, -0.7]), k=2, squared=True)
        for x in xs
    ]
    plt.plot(xs, values)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment