Last active
August 29, 2015 14:26
-
-
Save TomDLT/7655df7b365f6df8fa85 to your computer and use it in GitHub Desktop.
Greedy coordinate descent for NMF
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Tom Dupre la Tour | |
# License: BSD 3 clause | |
Greedy Coordinate Descent for Non-Negative Matrix Factorization (NMF)
in scikit-learn.
To switch the NMF solver from standard coordinate descent to greedy
coordinate descent, replace the call
    _update_cdnmf_fast(W, HHt, XHt, shuffle, seed)
in nmf.py with
    _update_greedy_(W, HHt, XHt, shuffle, seed)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# cython: cdivision=True | |
# cython: boundscheck=False | |
# cython: wraparound=False | |
# Author: Tom Dupre la Tour | |
# License: BSD 3 clause | |
import numpy as np | |
cimport cython | |
cimport numpy as np | |
from libc.math cimport fabs, fmax, fmin | |
def _update_greedy_cdnmf_fast(double[:, ::1] W, double[:, :] HHt,
                              double[:, :] G, bint shuffle, int seed,
                              int max_inner, double tol):
    """Greedy coordinate-descent update of W, performed in place.

    Rows of W are processed one at a time. At each inner step the
    coordinate r with the largest decrease D[i, r] is updated. A row stops
    after ``max_inner`` steps, or earlier once its best remaining decrease
    falls below ``tol * p_init``. Here ``p_init`` is the largest decrease
    found over the whole matrix before any update.

    Parameters
    ----------
    W : double[:, ::1]
        Factor being updated (n_samples x n_components); modified in place.
    HHt : double[:, :]
        Square matrix indexed by component on both axes; only read.
    G : double[:, :]
        Gradient term, same shape as W. It is updated in place
        (``G[i, :] += s * HHt[qi, :]``) so that it stays consistent with W.
    shuffle : bint
        If true, the initialization pass visits rows in a random order.
    seed : int
        Seed of the ``RandomState`` used for that order when ``shuffle``.
    max_inner : int
        Maximum number of coordinate updates per row.
    tol : double
        Relative stopping threshold, compared against ``p_init``.

    Returns
    -------
    violation : double
        Sum of the absolute projected gradients of the coordinates chosen
        at every greedy step, each measured before that step. Returns 0.
        (leaving W and G untouched) when no coordinate can decrease the
        objective.
    """
    cdef double violation = 0.
    cdef double pg
    cdef double s = 0.
    cdef double p_init = 0.
    cdef double Di_max, Dir
    cdef int n_samples = W.shape[0]  # n_features for H update
    cdef int n_components = W.shape[1]
    cdef int qi, t, r
    cdef Py_ssize_t i, j

    # q[i]: coordinate with the largest positive decrease in row i at
    # initialization (0 when none is positive).
    cdef int[::1] q = np.zeros(n_samples, dtype=np.int32)
    # S[i, r]: step along coordinate r, clipped so W[i, r] + S[i, r] >= 0.
    cdef double[:, ::1] S = np.zeros((n_samples, n_components))
    # D[i, r]: loss decrease obtained by applying S[i, r].
    cdef double[:, ::1] D = np.zeros((n_samples, n_components))
    # np.intp always has the size of Py_ssize_t. A C ``long`` buffer would
    # reject np.int_ arrays on platforms where the two sizes differ
    # (e.g. Windows with NumPy >= 2).
    cdef Py_ssize_t[::1] permutation

    # Each row only writes its own S/D/q entries and the pass reduces to a
    # maximum, so the visiting order does not change the result.
    if shuffle:
        rng = np.random.RandomState(seed)
        permutation = np.asarray(rng.permutation(n_samples), dtype=np.intp)
    else:
        permutation = np.arange(n_samples, dtype=np.intp)

    # Compute S, D, q and p_init = max(D).
    with nogil:
        for j in range(n_samples):
            i = permutation[j]
            Di_max = 0.
            q[i] = 0
            for r in range(n_components):
                # Step amplitude
                if HHt[r, r] != 0:
                    S[i, r] = fmax(W[i, r] - G[i, r] / HHt[r, r], 0.) - W[i, r]
                else:
                    S[i, r] = 0.
                # Loss difference
                D[i, r] = -(G[i, r] + HHt[r, r] / 2. * S[i, r]) * S[i, r]
                # q[i] = argmax_r(D[i, r]) over positive decreases
                if D[i, r] > Di_max:
                    q[i] = r
                    Di_max = D[i, r]
            if Di_max > p_init:
                p_init = Di_max

    # No coordinate can decrease the loss: nothing to update.
    if p_init == 0.:
        return 0.

    with nogil:
        for i in range(n_samples):
            qi = q[i]
            Di_max = D[i, qi]
            for t in range(max_inner):
                # Stop once the best remaining decrease of this row is small
                # relative to the best initial decrease of the whole matrix.
                if Di_max < tol * p_init:
                    break
                # Projected gradient for violation, taken before the update.
                pg = fmin(0, G[i, qi]) if W[i, qi] == 0 else G[i, qi]
                violation += fabs(pg)

                # Apply the step and keep G[i, :] consistent with W[i, :].
                s = S[i, qi]
                W[i, qi] += s
                for r in range(n_components):
                    G[i, r] += s * HHt[qi, r]

                # Recompute the steps of this row and greedily pick the next
                # coordinate (r == 0 seeds the argmax).
                for r in range(n_components):
                    if HHt[r, r] != 0:
                        S[i, r] = (fmax(W[i, r] - G[i, r] / HHt[r, r], 0)
                                   - W[i, r])
                    else:
                        S[i, r] = 0.
                    Dir = -(G[i, r] + HHt[r, r] / 2. * S[i, r]) * S[i, r]
                    if r == 0 or Dir > Di_max:
                        qi = r
                        Di_max = Dir
    return violation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Tom Dupre la Tour | |
# License: BSD 3 clause | |
from greedy_cd_nmf_fast import _update_greedy_cdnmf_fast | |
# [...] | |
def _update_coordinate_descent(X, W, Ht, alpha, l1_ratio, shuffle, random_state):
    """Run one coordinate-descent update of W and return the kernel's result.

    Dispatches either to the greedy kernel wrapper ``_update_greedy_`` or to
    the standard ``_update_cdnmf_fast`` kernel. Both receive the same
    arguments (W, HHt, XHt, shuffle, seed).
    """
    # [...]
    # NOTE(review): ``greedy``, ``HHt``, ``XHt`` and ``seed`` are not
    # parameters of this function; they must be defined in the elided code
    # above or at module level. Confirm, otherwise this raises NameError.
    if greedy:
        return _update_greedy_(W, HHt, XHt, shuffle, seed)
    else:
        return _update_cdnmf_fast(W, HHt, XHt, shuffle, seed)
def _update_greedy_(W, HHt, XHt, shuffle, seed, tol=0.001, max_inner=None):
    """Update W in place with the greedy coordinate-descent kernel.

    Takes the same positional arguments as the
    ``_update_cdnmf_fast(W, HHt, XHt, shuffle, seed)`` call it replaces.

    Parameters
    ----------
    W : ndarray
        Factor to update; modified in place by the kernel. It must be
        C-contiguous float64, because the kernel types it ``double[:, ::1]``.
    HHt : ndarray
        Square matrix indexed by component; ``HHt.shape[1]`` gives
        ``n_components``.
    XHt : ndarray
        Same shape as ``W @ HHt``.
    shuffle, seed :
        Forwarded unchanged to the kernel.
    tol : float, default=0.001
        Relative stopping threshold of the greedy inner loop.
    max_inner : int or None, default=None
        Maximum number of coordinate updates per row. ``None`` means
        ``n_components ** 2``.

    Returns
    -------
    violation : float
        Value returned by ``_update_greedy_cdnmf_fast``.
    """
    # Gradient term the kernel consumes; it updates G in place alongside W.
    G = fast_dot(W, HHt) - XHt
    if max_inner is None:
        n_components = HHt.shape[1]
        max_inner = n_components ** 2
    return _update_greedy_cdnmf_fast(W, HHt, G, shuffle, seed,
                                     max_inner, tol)
# [...] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment