Efficient implementation of FISTA
""" | |
Efficient implementation of FISTA. | |
""" | |
# Author: Mathieu Blondel | |
# License: BSD 3 clause | |
import numpy as np | |
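
# Conventions used throughout this file (inferred from the code below):
#   sfunc(x)        -> f(x), the smooth part of the objective
#   sfunc(x, True)  -> (f(x), gradient of f at x)
#   nsfunc(x, L)    -> (proximal operator of g with step 1 / L applied to x,
#                       g evaluated at the resulting point)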


def fista(sfunc, nsfunc, x0, max_iter=500, max_linesearch=20, eta=2.0, tol=1e-3,
          verbose=0):
    """Minimize f(x) + g(x) by FISTA with backtracking line search
    (Beck & Teboulle, 2009)."""
    y = x0.copy()
    x = y
    L = 1.0  # current estimate of the Lipschitz constant of grad f
    t = 1.0  # momentum coefficient
    for it in range(max_iter):
        f_old, grad = sfunc(y, True)

        # Backtracking line search: increase L until the quadratic upper
        # bound Q dominates the objective value F at the proximal point.
        for ls in range(max_linesearch):
            y_proj, g = nsfunc(y - grad / L, L)
            diff = (y_proj - y).ravel()
            sqdist = np.dot(diff, diff)
            dist = np.sqrt(sqdist)
            f = sfunc(y_proj)
            F = f + g
            Q = f_old + np.dot(diff, grad.ravel()) + 0.5 * L * sqdist + g
            if F <= Q:
                break
            L *= eta

        if ls == max_linesearch - 1 and verbose:
            print("Line search did not converge.")

        if verbose:
            print("%d. %f" % (it + 1, dist))

        if dist <= tol:
            if verbose:
                print("Converged.")
            break

        # Accelerated (momentum) update of the auxiliary sequence y.
        x_next = y_proj
        t_next = (1 + np.sqrt(1 + 4 * t ** 2)) / 2.
        y = x_next + (t - 1) / t_next * (x_next - x)
        t = t_next
        x = x_next

    return y_proj
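

# A minimal usage sketch (not part of the original gist): solving a ridge-like
# problem with fista().  Here the "non-smooth" part is g(w) = 0.5 * lam * ||w||^2,
# whose proximal operator with step 1 / L has the closed form w / (1 + lam / L).
# The function name and problem sizes below are illustrative only.
def example_ridge(verbose=0):
    rng = np.random.RandomState(0)
    X = rng.randn(100, 20)
    w_true = rng.randn(20)
    y = np.dot(X, w_true)
    lam = 1.0

    def sfunc(w, grad=False):
        # Smooth part: squared loss 0.5 * ||Xw - y||^2.
        diff = np.dot(X, w) - y
        obj = 0.5 * np.dot(diff, diff)
        if not grad:
            return obj
        return obj, np.dot(X.T, diff)

    def nsfunc(w, L):
        # Closed-form proximal operator of 0.5 * lam * ||w||^2.
        w = w / (1.0 + lam / L)
        return w, 0.5 * lam * np.dot(w, w)

    w_fit = fista(sfunc, nsfunc, np.zeros(20), verbose=verbose)
    print(np.sum((w_true - w_fit) ** 2))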


def test_l1(verbose):
    n_samples = 500
    n_features = 1000
    rng = np.random.RandomState(0)
    X = rng.randn(n_samples, n_features)
    w = rng.randn(n_features)
    y = np.dot(X, w)
    y += rng.randn(n_samples) * np.std(y)
    lam = 1e2

    def sfunc(w, grad=False):
        # Smooth part: squared loss 0.5 * ||Xw - y||^2.
        y_pred = np.dot(X, w)
        diff = y_pred - y
        obj = 0.5 * np.dot(diff, diff)
        if not grad:
            return obj
        grad = np.dot(X.T, diff)
        return obj, grad

    def nsfunc(w, L):
        # Proximal operator of the l1 penalty: soft-thresholding.
        w = np.sign(w) * np.maximum(np.abs(w) - lam / L, 0)
        val = lam * np.sum(np.abs(w))
        return w, val

    w0 = np.zeros_like(w)
    w_fit = fista(sfunc, nsfunc, w0, verbose=verbose)
    print(np.sum((w - w_fit) ** 2))


def test_nuclear(verbose):
    from scipy.linalg import svd

    def _predict(X, W):
        # Diagonal of X W X^T, i.e. x_i^T W x_i for every sample.
        XW = np.dot(X, W)
        return np.einsum("ij,ji->i", X, XW.T)

    n_samples = 500
    n_features = 50
    rng = np.random.RandomState(0)
    X = rng.randn(n_samples, n_features)
    W = rng.randn(n_features, n_features)
    W = 0.5 * (W + W.T)
    y = _predict(X, W)
    y += rng.randn(n_samples) * np.std(y)
    lam = 1e2

    def sfunc(W, grad=False):
        y_pred = _predict(X, W)
        diff = y_pred - y
        obj = 0.5 * np.dot(diff, diff)
        if not grad:
            return obj
        grad = np.dot(X.T * diff, X)
        return obj, grad

    def nsfunc(W, L):
        # Proximal operator of the nuclear norm: shrink singular values.
        U, s, V = svd(W, full_matrices=False)
        s = np.maximum(s - lam / L, 0)
        U *= s
        W = np.dot(U, V)
        # Penalty at the prox point, lam * ||W||_* (mirrors the l1 case).
        val = lam * np.sum(s)
        return W, val

    W0 = np.zeros_like(W)
    W_fit = fista(sfunc, nsfunc, W0, verbose=verbose)
    print(np.sum((W - W_fit) ** 2))
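

# A hypothetical usage sketch (not in the original gist): because fista()
# accepts an arbitrary starting point x0, a path of l1 solutions over a grid
# of regularization strengths can warm-start each solve at the previous one.
# The helper name and its signature are illustrative only.
def example_l1_path(X, y, lams):
    def sfunc(w, grad=False):
        # Smooth part, shared across all regularization strengths.
        diff = np.dot(X, w) - y
        obj = 0.5 * np.dot(diff, diff)
        if not grad:
            return obj
        return obj, np.dot(X.T, diff)

    w = np.zeros(X.shape[1])
    path = []
    for lam in lams:
        def nsfunc(w, L, lam=lam):
            # Soft-thresholding with the current regularization strength.
            w = np.sign(w) * np.maximum(np.abs(w) - lam / L, 0)
            return w, lam * np.sum(np.abs(w))
        w = fista(sfunc, nsfunc, w)
        path.append(w.copy())
    return path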


if __name__ == '__main__':
    test_l1(verbose=0)
    test_nuclear(verbose=0)