Convolutive NMF by block coordinate descent
import numpy as np
from tqdm import trange
import matplotlib.pyplot as plt


# TODO: subclass np.ndarray?
class ShiftMatrix(object):
    """
    Thin wrapper around a numpy matrix to support shifting along the second
    axis and padding with zeros.
    """

    def __init__(self, X, L):
        """
        X : numpy matrix
        L : int, largest shift
        """
        # ShiftMatrix behaves like the original matrix
        self.shape = X.shape
        self.size = X.size

        # padded version of X
        self.L = L
        self.X = np.pad(X, ((0, 0), (L, L)), mode='constant')

    def shift(self, l):
        """Returns a view of X shifted by l columns, padded with zeros."""
        if np.abs(l) > self.L:
            raise ValueError('requested too large of a shift.')
        r = slice(self.L - l, self.L + self.shape[1] - l)
        return self.X[:, r]

    def assign(self, Xnew):
        """Overwrites the unpadded portion of X."""
        self.X[:, self.L:-self.L] = Xnew
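
# A quick illustration of the shift semantics (a hand-checked sketch, not
# part of the original gist): positive shifts move columns to the right and
# pad with zeros on the left; negative shifts do the opposite.
#
#   >>> S = ShiftMatrix(np.array([[1., 2., 3.]]), 1)
#   >>> S.shift(1)
#   array([[0., 1., 2.]])
#   >>> S.shift(-1)
#   array([[2., 3., 0.]])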
class ConvNMF(object):

    def __init__(self, n_components, maxlag, tol=1e-5, n_iter_max=100,
                 l1_W=0.0, l1_H=0.0):
        """
        n_components (int) : number of components / model rank
        maxlag (int) : largest allowed time shift
        tol (float) : convergence tolerance on the decrease in loss
        n_iter_max (int) : maximum number of iterations
        l1_W (float) : strength of sparsity penalty on W
        l1_H (float) : strength of sparsity penalty on H
        """
        self.n_components = n_components
        self.maxlag = maxlag
        self.W = None
        self.H = None
        self.tol = tol
        self.n_iter_max = n_iter_max
        self.l1_W = l1_W
        self.l1_H = l1_H
        self._shifts = np.arange(maxlag*2 + 1) - maxlag
        self.loss_hist = None
    def fit(self, data):
        # check input
        if (data < 0).any():
            raise ValueError('Negative values in data to fit')

        data = ShiftMatrix(data, self.maxlag)
        m, n = data.shape

        # initialize W and H
        self.W = np.random.rand(self.maxlag*2 + 1, m, self.n_components)
        self.H = ShiftMatrix(np.random.rand(self.n_components, n), self.maxlag)

        # initial calculation of the loss and W gradient
        loss_1, grad_W = self._compute_gW(data)
        self.loss_hist = [loss_1]

        # alternate projected gradient steps on W and H
        for itr in trange(self.n_iter_max):
            # update W
            # TODO: compute Lipschitz constant for optimal learning rate!
            self.W = np.maximum(self.W - 0.00001*grad_W, 0)

            # compute gradient of H (at the new W)
            _, grad_H = self._compute_gH(data)

            # update H
            # TODO: compute Lipschitz constant for optimal learning rate!
            self.H.assign(np.maximum(self.H.shift(0) - 0.00001*grad_H, 0))

            # compute loss and gradient of W (at the new H)
            loss_2, grad_W = self._compute_gW(data)
            self.loss_hist += [loss_2]

            # check convergence
            if (loss_1 - loss_2) < self.tol:
                break

            # move to next iteration
            loss_1 = loss_2

        return self
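
    # Usage sketch (the `data` array here is hypothetical; it must be a
    # nonnegative (neurons x timepoints) numpy array):
    #
    #   model = ConvNMF(n_components=3, maxlag=10).fit(data)
    #   X_hat = model.predict()
    #   final_loss = model.loss_hist[-1]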
    def predict(self):
        """Returns the low-rank reconstruction of the data."""
        # check that W and H are fit
        self._check_is_fitted()
        W, H = self.W, self.H

        # dimensions
        m, n = W.shape[1], H.shape[1]

        # preallocate result
        result = np.zeros((m, n))

        # sum the contribution of each lag
        for w, t in zip(W, self._shifts):
            result += np.dot(w, H.shift(t))

        return result
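
    # The loop above implements the convolutive NMF reconstruction
    #   X_hat[i, t] = sum_{l = -maxlag..maxlag} sum_k W_l[i, k] * H[k, t - l]
    # where W_l is the slice of W at lag l: each temporal factor H[k] is
    # convolved with a (2*maxlag + 1)-tap filter of neural loadings, and
    # the components are summed.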
    def _check_is_fitted(self):
        if self.W is None or self.H is None:
            raise ValueError('This ConvNMF instance is not fitted yet. '
                             'Call \'fit\' with appropriate arguments '
                             'before using this method.')
    def _compute_loss(self, data):
        """Root mean squared error of the reconstruction."""
        resid = (self.predict() - data.shift(0)).ravel()
        return np.sqrt(np.mean(resid ** 2))
    def _compute_gW(self, data):
        # compute residuals
        resid = self.predict() - data.shift(0)

        # TODO: replace with broadcasting
        Wgrad = np.empty(self.W.shape)
        for l, t in enumerate(self._shifts):
            Wgrad[l] = np.dot(resid, self.H.shift(t).T)

        # l1 sparsity penalty (constant subgradient, since W >= 0)
        Wgrad += self.l1_W

        # compute loss (RMSE, used to monitor convergence)
        r = resid.ravel()
        loss = np.sqrt(np.mean(r ** 2))

        return loss, Wgrad
    def _compute_gH(self, data):
        # compute residuals
        resid = self.predict() - data.shift(0)

        # compute gradient
        # TODO: speed up with broadcasting
        Hgrad = np.zeros(self.H.shape)
        for l, t in enumerate(self._shifts):
            dh = np.dot(self.W[l].T, resid)
            Hgrad += _shift(dh, -t)

        # l1 sparsity penalty (constant subgradient, since H >= 0)
        Hgrad += self.l1_H

        # compute loss (RMSE, used to monitor convergence)
        r = resid.ravel()
        loss = np.sqrt(np.mean(r ** 2))

        return loss, Hgrad
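
    # Note on the gradients above (a sketch of the derivation): with
    #   f(W, H) = 0.5 * ||X_hat - X||_F^2,   resid = X_hat - X,
    # the chain rule gives
    #   df/dW_l = resid @ shift(H, l).T
    #   df/dH   = sum_l shift(W_l.T @ resid, -l)
    # plus the constant l1 subgradients (valid since W, H >= 0). The RMSE
    # returned alongside each gradient is only for monitoring convergence.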
def seq_nmf_data(N, T, L, K):
    """Creates a synthetic dataset for conv NMF

    Args
    ----
    N : number of neurons
    T : number of timepoints
    L : max sequence length
    K : number of factors / rank

    Returns
    -------
    data : N x T matrix
    W : N x K ground-truth loadings
    H : K x T ground-truth temporal factors
    """
    # low-rank data
    W, H = np.random.rand(N, K), np.random.rand(K, T)
    W[W < .5] = 0
    H[H < .8] = 0
    lrd = np.dot(W, H)

    # add a random shift to each row (note: np.roll wraps around rather
    # than zero-padding, so shifted activity can wrap from the end of a
    # row back to its beginning)
    lags = np.random.randint(0, L, size=N)
    data = np.array([np.roll(row, l, axis=-1) for row, l in zip(lrd, lags)])

    return data, W, H
def _shift(X, l):
    """Shifts matrix X along second axis and zero pads."""
    if l < 0:
        return np.pad(X, ((0, 0), (0, -l)), mode='constant')[:, -l:]
    elif l > 0:
        return np.pad(X, ((0, 0), (l, 0)), mode='constant')[:, :-l]
    else:
        return X
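
# Hand-checked examples of the zero-padded shift (not in the original
# gist); the convention mirrors ShiftMatrix.shift, but for plain arrays:
#
#   >>> _shift(np.array([[1., 2., 3.]]), 1)
#   array([[0., 1., 2.]])
#   >>> _shift(np.array([[1., 2., 3.]]), -1)
#   array([[2., 3., 0.]])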
if __name__ == '__main__':
    # synthetic data: 100 neurons, 101 timepoints, max lag 10, true rank 5
    data, W, H = seq_nmf_data(100, 101, 10, 5)

    # fit models of increasing rank and record the final loss of each
    losses = []
    for k in range(1, 10):
        model = ConvNMF(k, 15).fit(data)
        plt.plot(model.loss_hist)
        losses.append(model.loss_hist[-1])

    # final loss as a function of model rank
    plt.figure()
    plt.plot(range(1, 10), losses)
    plt.show()
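
    # The first figure overlays the per-iteration loss curves for each rank
    # k; the second plots the final loss against k, which should flatten out
    # (an "elbow") near the true rank (K = 5 in the synthetic data above).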