@ahwillia · Created April 10, 2018 23:54
Convolutive NMF by block coordinate descent
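
The model approximates a nonnegative data matrix X (m x n) as a sum of lagged
outer products,

    X ≈ sum_t W[t] · shift(H, t),    t = -maxlag, ..., maxlag,

where each W[t] is a nonnegative m x k matrix of factor loadings at lag t, H
is a nonnegative k x n matrix of factor activations, and shift(H, t)
translates H along the time axis with zero padding (this sum is exactly what
ConvNMF.predict computes). The fit method alternates projected gradient steps
on W and H to minimize the squared reconstruction error, with optional l1
penalties to encourage sparse factors.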

import numpy as np
from tqdm import trange
import matplotlib.pyplot as plt


# TODO: subclass np.ndarray?
class ShiftMatrix(object):
    """
    Thin wrapper around a numpy matrix to support shifting along the second
    axis and padding with zeros.
    """

    def __init__(self, X, L):
        """
        X : numpy matrix
        L : int, largest shift
        """
        # ShiftMatrix behaves like the original matrix
        self.shape = X.shape
        self.size = X.size

        # Padded version of X
        self.L = L
        self.X = np.pad(X, ((0, 0), (L, L)), mode='constant')

    def shift(self, l):
        if np.abs(l) > self.L:
            raise ValueError('requested too large of a shift.')
        r = slice(self.L - l, self.L + self.shape[1] - l)
        return self.X[:, r]

    def assign(self, Xnew):
        self.X[:, self.L:-self.L] = Xnew
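
# Example of the shift semantics: for a 1 x 3 matrix,
#
#   S = ShiftMatrix(np.array([[1., 2., 3.]]), L=1)
#   S.shift(1)   # -> [[0., 1., 2.]]  (lag right; zeros enter on the left)
#   S.shift(-1)  # -> [[2., 3., 0.]]  (lag left; zeros enter on the right)
#   S.shift(0)   # -> [[1., 2., 3.]]  (original matrix)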


class ConvNMF(object):

    def __init__(self, n_components, maxlag, tol=1e-5, n_iter_max=100,
                 l1_W=0.0, l1_H=0.0):
        """
        n_components (int) : rank of the factorization
        maxlag (int) : largest time lag considered by the model
        tol (float) : convergence tolerance on the decrease in loss
        n_iter_max (int) : maximum number of iterations
        l1_W (float) : strength of sparsity penalty on W
        l1_H (float) : strength of sparsity penalty on H
        """
        self.n_components = n_components
        self.maxlag = maxlag
        self.W = None
        self.H = None
        self.tol = tol
        self.n_iter_max = n_iter_max
        self.l1_W = l1_W
        self.l1_H = l1_H
        self._shifts = np.arange(maxlag*2 + 1) - maxlag
        self.loss_hist = None

    def fit(self, data):
        # check input
        if (data < 0).any():
            raise ValueError('Negative values in data to fit')

        data = ShiftMatrix(data, self.maxlag)
        m, n = data.shape

        # initialize W and H
        self.W = np.random.rand(self.maxlag*2 + 1, m, self.n_components)
        self.H = ShiftMatrix(np.random.rand(self.n_components, n), self.maxlag)

        # optimize
        converged = False

        # initial calculation of W gradient
        loss_1, grad_W = self._compute_gW(data)
        self.loss_hist = [loss_1]

        for itr in trange(self.n_iter_max):
            # update W by a projected gradient step
            # TODO: compute Lipschitz constant for optimal learning rate!
            self.W = np.maximum(self.W - 0.00001*grad_W, 0)

            # compute gradient of H
            _, grad_H = self._compute_gH(data)

            # update H by a projected gradient step
            # TODO: compute Lipschitz constant for optimal learning rate!
            self.H.assign(np.maximum(self.H.shift(0) - 0.00001*grad_H, 0))

            # compute gradient of W (and current loss)
            loss_2, grad_W = self._compute_gW(data)
            self.loss_hist += [loss_2]

            # check convergence
            if (loss_1 - loss_2) < self.tol:
                converged = True
                break
            else:
                # move to next iteration
                loss_1 = loss_2

        return self

    def predict(self):
        """Return low-rank reconstruction of data.
        """
        # check that W and H are fit
        self._check_is_fitted()
        W, H = self.W, self.H

        # dimensions
        m, n = W.shape[1], H.shape[1]

        # preallocate result
        result = np.zeros((m, n))

        # iterate over lags, accumulating each factor's shifted contribution
        for w, t in zip(W, self._shifts):
            result += np.dot(w, H.shift(t))

        return result

    def _check_is_fitted(self):
        if self.W is None or self.H is None:
            raise ValueError('This ConvNMF instance is not fitted yet. '
                             'Call \'fit\' with appropriate arguments '
                             'before using this method.')

    def _compute_loss(self, data):
        """Root Mean Squared Error
        """
        resid = (self.predict() - data.shift(0)).ravel()
        return np.sqrt(np.mean(resid ** 2))

    def _compute_gW(self, data):
        # compute residuals
        resid = self.predict() - data.shift(0)

        # compute gradient of the squared-error term, one lag at a time
        # TODO: replace with broadcasting
        Wgrad = np.empty(self.W.shape)
        for l, t in enumerate(self._shifts):
            Wgrad[l] = np.dot(resid, self.H.shift(t).T)

        # add the l1 sparsity penalty (constant subgradient since W >= 0)
        Wgrad += self.l1_W

        # compute loss (RMSE)
        r = resid.ravel()
        loss = np.sqrt(np.mean(r ** 2))

        return loss, Wgrad

    def _compute_gH(self, data):
        # compute residuals
        resid = self.predict() - data.shift(0)

        # compute gradient of the squared-error term
        # TODO: speed up with broadcasting
        Hgrad = np.zeros(self.H.shape)
        for l, t in enumerate(self._shifts):
            dh = np.dot(self.W[l].T, resid)
            Hgrad += _shift(dh, -t)

        # add the l1 sparsity penalty (constant subgradient since H >= 0)
        Hgrad += self.l1_H

        # compute loss (RMSE)
        r = resid.ravel()
        loss = np.sqrt(np.mean(r ** 2))

        return loss, Hgrad
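
    # Gradient derivation (for reference): with squared-error loss
    # f(W, H) = 0.5 * || X - sum_t W[t] · shift(H, t) ||_F^2 and residual
    # R = predict() - X, the partial derivatives are
    #
    #   df/dW[t] = R · shift(H, t)^T
    #   df/dH    = sum_t shift(W[t]^T · R, -t)
    #
    # which is exactly what _compute_gW and _compute_gH accumulate lag by lag.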


def seq_nmf_data(N, T, L, K):
    """Creates a synthetic dataset for convolutive NMF.

    Args
    ----
    N : number of neurons
    T : number of timepoints
    L : max sequence length
    K : number of factors / rank

    Returns
    -------
    data : N x T matrix
    W : N x K matrix of ground-truth factor loadings
    H : K x T matrix of ground-truth factor activations
    """
    # low-rank data
    W, H = np.random.rand(N, K), np.random.rand(K, T)
    W[W < .5] = 0
    H[H < .8] = 0
    lrd = np.dot(W, H)

    # add a random shift to each row
    lags = np.random.randint(0, L, size=N)
    data = np.array([np.roll(row, l, axis=-1) for row, l in zip(lrd, lags)])

    return data, W, H


def _shift(X, l):
    """Shifts matrix X along the second axis by l entries, zero-padding the
    vacated columns (same convention as ShiftMatrix.shift).
    """
    if l < 0:
        return np.pad(X, ((0, 0), (0, -l)), mode='constant')[:, -l:]
    elif l > 0:
        return np.pad(X, ((0, 0), (l, 0)), mode='constant')[:, :-l]
    else:
        return X
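
# Example: _shift(np.array([[1., 2., 3.]]), 1) -> [[0., 1., 2.]] and
# _shift(np.array([[1., 2., 3.]]), -1) -> [[2., 3., 0.]], matching the
# ShiftMatrix.shift example above.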


if __name__ == '__main__':
    # generate a small synthetic dataset and fit models of increasing rank
    data, W, H = seq_nmf_data(100, 101, 10, 5)

    losses = []
    for k in range(1, 10):
        model = ConvNMF(k, 15).fit(data)
        plt.plot(model.loss_hist)
        losses.append(model.loss_hist[-1])

    # final loss as a function of rank
    plt.figure()
    plt.plot(losses)
    plt.show()