Last active
May 25, 2018 11:36
-
-
Save valentina-s/5e35c5b3f7ed785e00e79bfee222b725 to your computer and use it in GitHub Desktop.
daskNMF
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"A brute force implementation of the multiplicative method for NMF in dask: i.e. converting all array operations to `dask.array` operations. This version does not include the regularization." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from scipy import linalg\n", | |
"import matplotlib.pyplot as plt\n", | |
"import os\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"For testing we will use the face dataset used in scikit-learn, which comes from the “Labeled Faces in the Wild” dataset, also known as LFW:\n", | |
"\n", | |
"[http://vis-www.cs.umass.edu/lfw/lfw-funneled.tgz](http://vis-www.cs.umass.edu/lfw/lfw-funneled.tgz)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from dask.distributed import Client, progress\n", | |
"c = Client()\n", | |
"c.restart()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# read the faces dataset\n", | |
"path = 'lfw_funneled'\n", | |
"from dask.array.image import imread\n", | |
"import dask.array as da\n", | |
"faces = imread(os.path.join(path,'*','*.jpg'))\n", | |
"N,m,n,d = faces[:,::4,::4,:].shape\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"data = faces[:,::4,::4,0].reshape((faces.shape[0],-1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X_da = c.persist(data[:100,:])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# compute for sklearn\n", | |
"X = data[:100,:].compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import time as time" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# to avoid deviding by zero\n", | |
"EPSILON = np.finfo(np.float32).eps" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Numpy Updates:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def update_H(M,H,W):\n", | |
"    \"\"\"Multiplicative update of H for the factorisation M ~ W.H (NumPy).\n", | |
"\n", | |
"    Returns H * (W.T M) / (W.T W H), element-wise. Entries of the\n", | |
"    denominator that are exactly zero are replaced by EPSILON first so\n", | |
"    the division is always defined.\n", | |
"    \"\"\"\n", | |
"    numerator = np.dot(W.T, M)\n", | |
"    reconstruction = np.dot(W, H)\n", | |
"    denominator = np.dot(W.T, reconstruction)\n", | |
"    zero_mask = denominator == 0\n", | |
"    denominator[zero_mask] = EPSILON\n", | |
"    return H * numerator / denominator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def update_W(M,H,W):\n", | |
"    \"\"\"Multiplicative update of W for the factorisation M ~ W.H (NumPy).\n", | |
"\n", | |
"    Returns W * (M H.T) / (W H H.T), element-wise. Entries of the\n", | |
"    denominator that are exactly zero are replaced by EPSILON first so\n", | |
"    the division is always defined.\n", | |
"    \"\"\"\n", | |
"    numerator = np.dot(M, H.T)\n", | |
"    gram_H = np.dot(H, H.T)\n", | |
"    denominator = np.dot(W, gram_H)\n", | |
"    zero_mask = denominator == 0\n", | |
"    denominator[zero_mask] = EPSILON\n", | |
"    return W * numerator / denominator" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Dask Updates:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def update_H_da(M,H,W):\n", | |
"    \"\"\"dask.array counterpart of update_H: H * (W.T M) / (W.T W H).\n", | |
"\n", | |
"    Exact zeros in the denominator are swapped for EPSILON with da.where,\n", | |
"    because the denominator is built from dask operations rather than\n", | |
"    patched in place.\n", | |
"    \"\"\"\n", | |
"    numerator = da.dot(W.T, M)\n", | |
"    raw_denominator = da.dot(W.T, da.dot(W, H))\n", | |
"    safe_denominator = da.where(raw_denominator == 0, EPSILON, raw_denominator)\n", | |
"    return H * numerator / safe_denominator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def update_W_da(M,H,W):\n", | |
"    \"\"\"dask.array counterpart of update_W: W * (M H.T) / (W H H.T).\n", | |
"\n", | |
"    Exact zeros in the denominator are swapped for EPSILON with da.where,\n", | |
"    because the denominator is built from dask operations rather than\n", | |
"    patched in place.\n", | |
"    \"\"\"\n", | |
"    numerator = da.dot(M, H.T)\n", | |
"    raw_denominator = da.dot(W, da.dot(H, H.T))\n", | |
"    safe_denominator = da.where(raw_denominator == 0, EPSILON, raw_denominator)\n", | |
"    return W * numerator / safe_denominator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#print(update_W(M,H,W).shape)\n", | |
"#print(update_H(M,H,W).shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# loss\n", | |
"def Frobenius_loss(M,H,W):\n", | |
"    \"\"\"Return the Frobenius norm of the residual M - W.H.\n", | |
"\n", | |
"    `linalg` is the scipy.linalg module, which is not callable; the norm\n", | |
"    has to be requested explicitly through linalg.norm.\n", | |
"    \"\"\"\n", | |
"    return linalg.norm(M - np.dot(W, H))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# test da.where\n", | |
"ones = da.zeros((3,3),chunks = (3,3))\n", | |
"ones = da.where(ones==0,EPSILON,ones)\n", | |
"ones.compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def initialize_da(X, k, init='random'):\n", | |
"    \"\"\"Random non-negative starting factors (W, H) as dask arrays.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    X : dask array of shape (n_samples, n_features)\n", | |
"    k : number of components\n", | |
"    init : only 'random' is implemented\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    W : (n_samples, k) dask array, a single chunk\n", | |
"    H : (k, n_features) dask array, chunked along the columns with the\n", | |
"        width of X's first column chunk\n", | |
"\n", | |
"    Raises\n", | |
"    ------\n", | |
"    ValueError if `init` is not 'random' (previously this ended in an\n", | |
"    unbound-variable error because W and H were never created).\n", | |
"    \"\"\"\n", | |
"    if init != 'random':\n", | |
"        raise ValueError('unsupported init: ' + repr(init))\n", | |
"\n", | |
"    n_components = k\n", | |
"    n_samples, n_features = X.shape\n", | |
"\n", | |
"    avg = da.sqrt(X.mean() / n_components)\n", | |
"    da.random.seed(42)\n", | |
"    H = avg * da.random.normal(0,1,size=(n_components, n_features),chunks=(n_components,X.chunks[1][0]))\n", | |
"    W = avg * da.random.normal(0,1,size=(n_samples, n_components),chunks=(n_samples,n_components))\n", | |
"\n", | |
"    # da.fabs returns a new array and leaves its argument untouched, so the\n", | |
"    # result has to be rebound. Calling da.fabs(H, H) and dropping the return\n", | |
"    # value left negative entries in the factors.\n", | |
"    H = da.fabs(H)\n", | |
"    W = da.fabs(W)\n", | |
"    return W, H" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# NNDSVD/A initialization from sklearn\n", | |
"def initialize(X,k,init):\n", | |
"    \"\"\"NNDSVD / NNDSVDA starting factors for NMF (NumPy).\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    X : 2-D array of shape (n_samples, n_features)\n", | |
"    k : number of components, 1 <= k <= min(X.shape)\n", | |
"    init : 'nndsvd' keeps the clipped entries at zero; 'nndsvda' replaces\n", | |
"        every zero with X.mean(). Any other value behaves like 'nndsvd'.\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    (W, H) with shapes (n_samples, k) and (k, n_features).\n", | |
"\n", | |
"    Raises\n", | |
"    ------\n", | |
"    ValueError if k is outside 1..min(X.shape).\n", | |
"    \"\"\"\n", | |
"    from scipy.linalg import svd\n", | |
"\n", | |
"    n_components = k\n", | |
"    if not 1 <= n_components <= min(X.shape):\n", | |
"        raise ValueError('k must be between 1 and min(X.shape), got ' + repr(k))\n", | |
"\n", | |
"    U, S, V = svd(X, full_matrices = False)\n", | |
"    # Keep only the k leading singular triplets; without this W and H were\n", | |
"    # returned with min(X.shape) components regardless of k.\n", | |
"    U, S, V = U[:, :n_components], S[:n_components], V[:n_components, :]\n", | |
"    W, H = np.zeros(U.shape), np.zeros(V.shape)\n", | |
"\n", | |
"    # The leading singular triplet is non-negative\n", | |
"    # so it can be used as is for initialization.\n", | |
"    W[:, 0] = np.sqrt(S[0]) * np.abs(U[:, 0])\n", | |
"    H[0, :] = np.sqrt(S[0]) * np.abs(V[0, :])\n", | |
"\n", | |
"    def norm(x):\n", | |
"        # Euclidean length; the square root was missing, which made this\n", | |
"        # the squared norm and broke the unit scaling of u and v below.\n", | |
"        x = x.ravel()\n", | |
"        return np.sqrt(np.dot(x, x))\n", | |
"\n", | |
"    for j in range(1, n_components):\n", | |
"        x, y = U[:, j], V[j, :]\n", | |
"\n", | |
"        # extract positive and negative parts of column vectors\n", | |
"        x_p, y_p = np.maximum(x, 0), np.maximum(y, 0)\n", | |
"        x_n, y_n = np.abs(np.minimum(x, 0)), np.abs(np.minimum(y, 0))\n", | |
"\n", | |
"        # and their norms\n", | |
"        x_p_nrm, y_p_nrm = norm(x_p), norm(y_p)\n", | |
"        x_n_nrm, y_n_nrm = norm(x_n), norm(y_n)\n", | |
"\n", | |
"        m_p, m_n = x_p_nrm * y_p_nrm, x_n_nrm * y_n_nrm\n", | |
"\n", | |
"        # choose update: keep whichever signed part carries more weight\n", | |
"        if m_p > m_n:\n", | |
"            u = x_p / x_p_nrm\n", | |
"            v = y_p / y_p_nrm\n", | |
"            sigma = m_p\n", | |
"        else:\n", | |
"            u = x_n / x_n_nrm\n", | |
"            v = y_n / y_n_nrm\n", | |
"            sigma = m_n\n", | |
"\n", | |
"        lbd = np.sqrt(S[j] * sigma)\n", | |
"        W[:, j] = lbd * u\n", | |
"        H[j, :] = lbd * v\n", | |
"\n", | |
"    # Clip numerical noise to exact zeros for both variants, as the\n", | |
"    # scikit-learn routine does; 'nndsvda' then fills those zeros.\n", | |
"    eps = 1e-6\n", | |
"    W[W < eps] = 0\n", | |
"    H[H < eps] = 0\n", | |
"\n", | |
"    if init == 'nndsvda':\n", | |
"        avg = X.mean()\n", | |
"        W[W == 0] = avg\n", | |
"        H[H == 0] = avg\n", | |
"\n", | |
"    return(W,H)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def initialize_random(X,k,init):\n", | |
"    \"\"\"Random non-negative starting factors (W, H) as NumPy arrays.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    X : 2-D array of shape (n_samples, n_features)\n", | |
"    k : number of components\n", | |
"    init : only 'random' is implemented\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    (W, H) with shapes (n_samples, k) and (k, n_features): absolute values\n", | |
"    of normal draws scaled by sqrt(X.mean() / k).\n", | |
"\n", | |
"    Raises\n", | |
"    ------\n", | |
"    ValueError if `init` is not 'random' (previously this ended in an\n", | |
"    unbound-variable error because W and H were never created).\n", | |
"\n", | |
"    Note: the global NumPy random state is re-seeded with 42 on every\n", | |
"    call, so repeated calls return identical factors.\n", | |
"    \"\"\"\n", | |
"    if init != 'random':\n", | |
"        raise ValueError('unsupported init: ' + repr(init))\n", | |
"\n", | |
"    n_components = k\n", | |
"    n_samples, n_features = X.shape\n", | |
"\n", | |
"    avg = np.sqrt(X.mean() / n_components)\n", | |
"    np.random.seed(42)\n", | |
"    # H is drawn before W; keep this order to reproduce the same stream.\n", | |
"    H = np.abs(avg * np.random.normal(0,1,size=(n_components, n_features)))\n", | |
"    W = np.abs(avg * np.random.normal(0,1,size=(n_samples, n_components)))\n", | |
"    return W, H" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"k = 5" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"W, H = initialize_da(X_da,k,init='random')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# fitting function\n", | |
"EPSILON = np.finfo(np.float32).eps\n", | |
"def fit(M,k,nofit):\n", | |
"    \"\"\"Factorise M ~ W.H with `nofit` rounds of multiplicative updates.\n", | |
"\n", | |
"    The factors start from initialize(M, k, init='nndsvda'). Each round\n", | |
"    updates W, then H, and prints the Frobenius norm of the residual.\n", | |
"\n", | |
"    Returns (W, H, err), where err holds the residual norm after every\n", | |
"    round.\n", | |
"    \"\"\"\n", | |
"    W, H = initialize(M, k, init='nndsvda')\n", | |
"\n", | |
"    err = []\n", | |
"    for it in range(nofit):\n", | |
"        # W first, so the H update already sees the refreshed W\n", | |
"        W = update_W(M, H, W)\n", | |
"        H = update_H(M, H, W)\n", | |
"        residual = linalg.norm(M - np.dot(W, H))\n", | |
"        err.append(residual)\n", | |
"        print('Iteration '+str(it)+': error = '+ str(residual))\n", | |
"    return W, H, err" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# fitting function\n", | |
"def fit_da(M,k,nofit):\n", | |
"    \"\"\"dask version of fit: `nofit` rounds of multiplicative updates on M.\n", | |
"\n", | |
"    The factors start from initialize_da(M, k, init='random'). Each round\n", | |
"    updates W, then H, and records da.linalg.norm of the residual.\n", | |
"\n", | |
"    Returns (W, H, err). Nothing is computed inside this function: W, H\n", | |
"    and the entries of err are dask objects that still have to be\n", | |
"    evaluated by the caller (e.g. with .compute()).\n", | |
"    \"\"\"\n", | |
"    W, H = initialize_da(M,k,init='random')\n", | |
"\n", | |
"    err = []\n", | |
"    for it in range(nofit):\n", | |
"        # W first, so the H update already sees the refreshed W\n", | |
"        W = update_W_da(M,H,W)\n", | |
"        H = update_H_da(M,H,W)\n", | |
"\n", | |
"        err.append(da.linalg.norm(M - da.dot(W,H)))\n", | |
"\n", | |
"    return(W,H,err)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X_da = X_da.rechunk((100,441))\n", | |
"X_da" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"nofit = 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# use the iteration count configured above instead of repeating the literal\n", | |
"W, H, err = fit(X,100,nofit)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"H" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"plt.plot(err)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"data_fitted = np.dot(W,H).reshape(-1,m,n)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# plot the results\n", | |
"plt.figure(figsize = (10,5))\n", | |
"plt.subplot(1,2,1)\n", | |
"plt.imshow(faces[1,::4,::4,0],cmap = 'gray')\n", | |
"plt.title('Raw')\n", | |
" \n", | |
"plt.subplot(1,2,2)\n", | |
"plt.imshow(data_fitted[1,:,:],cmap = 'gray')\n", | |
"plt.title('Fitted')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# compare to sklear multiplicative method \n", | |
"# only latest versions have the multiplicative method\n", | |
"# performance seems worse than the coordinate descent method\n", | |
" \n", | |
"from sklearn.decomposition import NMF\n", | |
"nmf = NMF(n_components = 100,init = 'nndsvda',solver = 'mu',max_iter = 200)\n", | |
"W = nmf.fit_transform(X)\n", | |
"H = nmf.components_\n", | |
"data_fitted_sk = np.dot(W,H)\n", | |
"\n", | |
"#plot the results\n", | |
"plt.figure(figsize = (10,5))\n", | |
"plt.subplot(1,2,1)\n", | |
"plt.imshow(faces[1,::4,::4,0],cmap = 'gray')\n", | |
"plt.title('Raw')\n", | |
" \n", | |
"plt.subplot(1,2,2)\n", | |
"plt.imshow(data_fitted_sk.reshape(-1,m,n)[1,:,:],cmap = 'gray')\n", | |
"plt.title('Fitted')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# plot the results\n", | |
"u = H.reshape((-1,m,n))\n", | |
"plt.figure(figsize = (10,5))\n", | |
"# show at most 10 modes, and never index past the components available\n", | |
"for i in range(min(10, u.shape[0])):\n", | |
"    plt.subplot(2,5,i+1)\n", | |
"    # no manual rescaling is done here: imshow maps each image's own\n", | |
"    # min..max range onto the gray colormap by default\n", | |
"    plt.imshow(u[i,:,:],cmap = 'gray')\n", | |
"    plt.title('Mode '+str(i+1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# use the iteration count configured above instead of repeating the literal\n", | |
"W, H, err = fit_da(X_da,100,nofit)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"H.compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# plot the results\n", | |
"u = H.compute().reshape((-1,m,n))\n", | |
"plt.figure(figsize = (10,5))\n", | |
"# show at most 10 modes, and never index past the components available\n", | |
"for i in range(min(10, u.shape[0])):\n", | |
"    plt.subplot(2,5,i+1)\n", | |
"    # no manual rescaling is done here: imshow maps each image's own\n", | |
"    # min..max range onto the gray colormap by default\n", | |
"    plt.imshow(u[i,:,:],cmap = 'gray')\n", | |
"    plt.title('Mode '+str(i+1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"plt.plot(err)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"da.random.seed(42)\n", | |
"da.random.RandomState(42).normal(0,1,(1,),(1,)).compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(42)\n", | |
"np.random.RandomState(42).normal(0,1,(1,))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# how do I get the same streams???\n", | |
"help(da.random.RandomState(42).normal)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment