Created
November 6, 2016 06:48
-
-
Save mblondel/5105786d740693a6996bcb8e482c7083 to your computer and use it in GitHub Desktop.
Efficient implementation of FISTA
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Efficient implementation of FISTA. | |
""" | |
# Author: Mathieu Blondel | |
# License: BSD 3 clause | |
import numpy as np | |
def fista(sfunc, nsfunc, x0, max_iter=500, max_linesearch=20, eta=2.0, tol=1e-3,
          verbose=0):
    """Minimize sfunc(x) + nsfunc-penalty(x) by FISTA with backtracking.

    Parameters
    ----------
    sfunc : callable
        Smooth part of the objective. ``sfunc(x)`` returns the objective
        value; ``sfunc(x, True)`` returns ``(value, gradient)``.
    nsfunc : callable
        Non-smooth part. ``nsfunc(x, L)`` returns the proximal point of x
        with step size 1/L and the penalty value at that point.
    x0 : ndarray
        Starting point; copied, never modified in place.
    max_iter : int
        Maximum number of outer (FISTA) iterations.
    max_linesearch : int
        Maximum number of backtracking steps per outer iteration.
    eta : float
        Factor by which the Lipschitz estimate L grows on a rejected step.
    tol : float
        Stop when the distance between successive proximal points is <= tol.
    verbose : int
        Print per-iteration progress when nonzero.

    Returns
    -------
    ndarray
        The last proximal point computed (the solution estimate).
    """
    y = x0.copy()
    x = y
    L = 1.0  # running estimate of the Lipschitz constant of grad(sfunc)
    t = 1.0  # Nesterov momentum coefficient
    for it in range(max_iter):  # BUGFIX: xrange is Python 2 only
        f_old, grad = sfunc(y, True)
        for ls in range(max_linesearch):
            # Proximal-gradient step from y with step size 1/L.
            y_proj, g = nsfunc(y - grad / L, L)
            diff = (y_proj - y).ravel()
            sqdist = np.dot(diff, diff)
            dist = np.sqrt(sqdist)
            f = sfunc(y_proj)
            F = f + g
            # Quadratic upper bound of the objective at y (Beck & Teboulle):
            # accept the step when the actual value stays below it.
            Q = f_old + np.dot(diff, grad.ravel()) + 0.5 * L * sqdist + g
            if F <= Q:
                break
            L *= eta
        else:
            # BUGFIX: only warn when backtracking truly exhausted its budget;
            # the original also warned when the break hit on the last attempt.
            if verbose:
                print("Line search did not converge.")
        if verbose:
            print("%d. %f" % (it + 1, dist))
        if dist <= tol:
            if verbose:
                print("Converged.")
            break
        # Nesterov momentum update.
        x_next = y_proj
        t_next = (1 + np.sqrt(1 + 4 * t ** 2)) / 2.
        y = x_next + (t - 1) / t_next * (x_next - x)
        t = t_next
        x = x_next
    return y_proj
def test_l1(verbose):
    """Smoke-test FISTA on lasso (squared loss + l1 penalty).

    Builds a random dense regression problem, runs FISTA with the
    soft-thresholding prox, and prints the squared distance between the
    true and the recovered coefficient vectors.
    """
    n_samples = 500
    n_features = 1000
    rng = np.random.RandomState(0)
    X = rng.randn(n_samples, n_features)
    w = rng.randn(n_features)
    y = np.dot(X, w)
    y += rng.randn(n_samples) * np.std(y)  # add noise at the signal's scale
    lam = 1e2  # l1 regularization strength

    def sfunc(w, grad=False):
        # Smooth part: 0.5 * ||Xw - y||^2 (and its gradient on request).
        y_pred = np.dot(X, w)
        diff = y_pred - y
        obj = 0.5 * np.dot(diff, diff)
        if not grad:
            return obj
        grad = np.dot(X.T, diff)
        return obj, grad

    def nsfunc(w, L):
        # Prox of the l1 norm with step 1/L: elementwise soft-thresholding.
        w = np.sign(w) * np.maximum(np.abs(w) - lam / L, 0)
        val = lam * np.sum(np.abs(w))
        return w, val

    w0 = np.zeros_like(w)
    w_fit = fista(sfunc, nsfunc, w0, verbose=verbose)
    # BUGFIX: Python 2 print statement is a syntax error in Python 3.
    print(np.sum((w - w_fit) ** 2))
def test_nuclear(verbose):
    """Smoke-test FISTA with a nuclear-norm penalty.

    Fits a symmetric matrix W in the quadratic model x_i' W x_i via
    squared loss plus a nuclear-norm penalty (prox = singular value
    soft-thresholding), then prints the squared recovery error.
    """
    from scipy.linalg import svd

    def _predict(X, W):
        # Per-sample quadratic form: diag(X W X').
        XW = np.dot(X, W)
        return np.einsum("ij,ji->i", X, XW.T)

    n_samples = 500
    n_features = 50
    rng = np.random.RandomState(0)
    X = rng.randn(n_samples, n_features)
    W = rng.randn(n_features, n_features)
    W = 0.5 * (W + W.T)  # symmetrize the ground truth
    y = _predict(X, W)
    y += rng.randn(n_samples) * np.std(y)  # add noise at the signal's scale
    lam = 1e2  # nuclear-norm regularization strength

    def sfunc(W, grad=False):
        # Smooth part: 0.5 * ||predict(X, W) - y||^2.
        y_pred = _predict(X, W)
        diff = y_pred - y
        obj = 0.5 * np.dot(diff, diff)
        if not grad:
            return obj
        grad = np.dot(X.T * diff, X)
        return obj, grad

    def nsfunc(W, L):
        # Prox of the nuclear norm: soft-threshold the singular values.
        U, s, V = svd(W, full_matrices=False)
        s = np.maximum(s - lam / L, 0)
        U *= s
        W = np.dot(U, V)
        val = np.sum(s)
        return W, val

    W0 = np.zeros_like(W)
    W_fit = fista(sfunc, nsfunc, W0, verbose=verbose)
    # BUGFIX: Python 2 print statement is a syntax error in Python 3.
    print(np.sum((W - W_fit) ** 2))
if __name__ == '__main__':
    # Run both smoke tests quietly; removed the unused `import sys`.
    test_l1(verbose=0)
    test_nuclear(verbose=0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment