Skip to content

Instantly share code, notes, and snippets.

@mblondel
Created November 6, 2016 06:48
Show Gist options
  • Save mblondel/5105786d740693a6996bcb8e482c7083 to your computer and use it in GitHub Desktop.
Save mblondel/5105786d740693a6996bcb8e482c7083 to your computer and use it in GitHub Desktop.
Efficient implementation of FISTA
"""
Efficient implementation of FISTA.
"""
# Author: Mathieu Blondel
# License: BSD 3 clause
import numpy as np
def fista(sfunc, nsfunc, x0, max_iter=500, max_linesearch=20, eta=2.0, tol=1e-3,
          verbose=0):
    """Minimize ``f(x) + g(x)`` with FISTA and a backtracking line search.

    Parameters
    ----------
    sfunc : callable
        Smooth part ``f``. ``sfunc(x)`` must return the objective value and
        ``sfunc(x, True)`` must return ``(value, gradient)``, where the
        gradient has the same shape as ``x``.
    nsfunc : callable
        Non-smooth part ``g``. ``nsfunc(v, L)`` is called with the gradient
        step ``v = y - grad / L`` and must return ``(prox, g_value)``: the
        proximal point of ``v`` for step size ``1 / L`` and the value of
        ``g`` at that point.
    x0 : ndarray
        Starting point. It is copied and never modified.
    max_iter : int
        Maximum number of outer FISTA iterations.
    max_linesearch : int
        Maximum number of backtracking trials per outer iteration (>= 1).
    eta : float
        Factor by which ``L`` is multiplied each time a trial step fails the
        sufficient-decrease test.
    tol : float
        Stop once the Euclidean distance between the current point and its
        proximal step is at most ``tol``.
    verbose : int
        If non-zero, print per-iteration progress.

    Returns
    -------
    ndarray
        The last proximal point computed, or a copy of ``x0`` when
        ``max_iter`` is 0.

    Raises
    ------
    ValueError
        If ``max_linesearch`` is smaller than 1.
    """
    if max_linesearch < 1:
        # Without at least one trial step there is no proximal point to use.
        raise ValueError("max_linesearch must be >= 1, got %r"
                         % (max_linesearch,))

    y = x0.copy()  # point at which the gradient is evaluated
    x = y          # previous proximal iterate
    y_proj = y     # returned unchanged when max_iter == 0
    L = 1.0        # step-size parameter: trial gradient steps have length 1 / L
    t = 1.0        # momentum sequence

    for it in range(max_iter):
        f_old, grad = sfunc(y, True)
        grad_flat = grad.ravel()  # loop-invariant across the line search

        for _ in range(max_linesearch):
            y_proj, g = nsfunc(y - grad / L, L)
            diff = (y_proj - y).ravel()
            sqdist = np.dot(diff, diff)
            dist = np.sqrt(sqdist)

            # Sufficient-decrease test: accept the step when the objective at
            # y_proj lies below the quadratic model of f around y (g is added
            # to both sides). Otherwise increase L, i.e. shrink the step.
            f = sfunc(y_proj)
            F = f + g
            Q = f_old + np.dot(diff, grad_flat) + 0.5 * L * sqdist + g
            if F <= Q:
                break
            L *= eta
        else:
            # for/else: runs only when no trial hit `break`, so a success on
            # the final allowed trial is not reported as a failure.
            if verbose:
                print("Line search did not converge.")

        if verbose:
            print("%d. %f" % (it + 1, dist))

        if dist <= tol:
            if verbose:
                print("Converged.")
            break

        # Momentum extrapolation between the two latest proximal iterates.
        x_next = y_proj
        t_next = (1 + np.sqrt(1 + 4 * t ** 2)) / 2.
        y = x_next + (t - 1) / t_next * (x_next - x)
        t = t_next
        x = x_next

    return y_proj
def test_l1(verbose):
    """Fit an L1-regularized least-squares problem with ``fista``.

    Draws a random dense design matrix and coefficient vector, adds Gaussian
    noise to the targets, fits with ``fista`` and prints the squared
    Euclidean distance between the true and the fitted coefficients.

    Parameters
    ----------
    verbose : int
        Forwarded to ``fista``.
    """
    n_samples = 500
    n_features = 1000
    rng = np.random.RandomState(0)
    X = rng.randn(n_samples, n_features)
    w = rng.randn(n_features)
    y = np.dot(X, w)
    # Noise scaled by the standard deviation of the clean targets.
    y += rng.randn(n_samples) * np.std(y)
    lam = 1e2

    def sfunc(w, grad=False):
        # Smooth part 0.5 * ||X w - y||^2, plus its gradient when requested.
        y_pred = np.dot(X, w)
        diff = y_pred - y
        obj = 0.5 * np.dot(diff, diff)
        if not grad:
            return obj
        grad = np.dot(X.T, diff)
        return obj, grad

    def nsfunc(w, L):
        # Prox of lam * ||w||_1 with step 1 / L: soft-threshold at lam / L.
        w = np.sign(w) * np.maximum(np.abs(w) - lam / L, 0)
        val = lam * np.sum(np.abs(w))
        return w, val

    w0 = np.zeros_like(w)
    w_fit = fista(sfunc, nsfunc, w0, verbose=verbose)
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(np.sum((w - w_fit) ** 2))
def test_nuclear(verbose):
    """Fit a nuclear-norm-regularized quadratic model with ``fista``.

    Targets are ``y[i] = X[i] @ W @ X[i]`` for a random symmetric ``W``,
    plus Gaussian noise. The model is fitted with ``fista`` and the squared
    Frobenius distance between the true and the fitted ``W`` is printed.

    Parameters
    ----------
    verbose : int
        Forwarded to ``fista``.
    """
    from scipy.linalg import svd

    def _predict(X, W):
        # Row-wise quadratic form: out[i] = X[i] @ W @ X[i].
        XW = np.dot(X, W)
        return np.einsum("ij,ji->i", X, XW.T)

    n_samples = 500
    n_features = 50
    rng = np.random.RandomState(0)
    X = rng.randn(n_samples, n_features)
    W = rng.randn(n_features, n_features)
    W = 0.5 * (W + W.T)  # symmetrize the ground-truth matrix
    y = _predict(X, W)
    # Noise scaled by the standard deviation of the clean targets.
    y += rng.randn(n_samples) * np.std(y)
    lam = 1e2

    def sfunc(W, grad=False):
        # 0.5 * ||pred - y||^2; gradient is sum_i diff[i] * outer(X[i], X[i]).
        y_pred = _predict(X, W)
        diff = y_pred - y
        obj = 0.5 * np.dot(diff, diff)
        if not grad:
            return obj
        grad = np.dot(X.T * diff, X)
        return obj, grad

    def nsfunc(W, L):
        # Prox of lam * ||W||_* with step 1 / L: soft-threshold the
        # singular values at lam / L and rebuild the matrix.
        U, s, Vt = svd(W, full_matrices=False)
        s = np.maximum(s - lam / L, 0)
        W = np.dot(U * s, Vt)
        # Penalty value at the prox point includes lam, as in the L1 example.
        val = lam * np.sum(s)
        return W, val

    W0 = np.zeros_like(W)
    W_fit = fista(sfunc, nsfunc, W0, verbose=verbose)
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(np.sum((W - W_fit) ** 2))
if __name__ == '__main__':
    # Run both demos; print("") emits a blank separator line on Python 2
    # and 3 alike (a bare `print` is a no-op expression on Python 3).
    test_l1(verbose=0)
    print("")
    test_nuclear(verbose=0)
    print("")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment