Created
April 6, 2017 13:45
-
-
Save mrocklin/4e486064882cce630ffb4ee4e39bc333 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<Client: scheduler='tcp://localhost:8786' processes=7 cores=56>" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from dask import persist\n", | |
"from dask.distributed import Client, progress\n", | |
"# Connect to the distributed scheduler at the given address and show its summary.\n", | |
"SCHEDULER_ADDRESS = 'localhost:8786'\n", | |
"client = Client(SCHEDULER_ADDRESS)\n", | |
"client" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/anaconda/lib/python3.5/site-packages/dask/array/core.py:476: RuntimeWarning: overflow encountered in exp\n", | |
" o = func(*args, **kwargs)\n" | |
] | |
} | |
], | |
"source": [ | |
"import dask\n", | |
"import dask.array as da\n", | |
"import numpy as np\n", | |
"from scipy.optimize import fmin_l_bfgs_b\n", | |
"\n", | |
"def make_classification(n_samples=1000, n_features=100, n_informative=2, scale=1.0, chunksize=100):\n", | |
"    '''Generate a synthetic logistic-regression dataset as dask arrays.\n", | |
"\n", | |
"    Returns (X, y): X is a standard-normal matrix of shape\n", | |
"    (n_samples, n_features) chunked along rows; y is a boolean vector drawn\n", | |
"    as Bernoulli(sigmoid(z0)), where z0 depends only on n_informative\n", | |
"    distinct columns of X. Raises ValueError if n_informative > n_features.\n", | |
"    '''\n", | |
"    X = da.random.normal(0, 1, size=(n_samples, n_features),\n", | |
"                         chunks=(chunksize, n_features))\n", | |
"    # replace=False keeps the informative columns distinct; with replacement a\n", | |
"    # column could be drawn twice, silently lowering the number of informative\n", | |
"    # features and doubling that column's weight in z0.\n", | |
"    informative_idx = np.random.choice(n_features, n_informative, replace=False)\n", | |
"    # NOTE(review): for positive scale these coefficients lie in [-scale, 0),\n", | |
"    # i.e. they are all non-positive -- confirm a one-signed beta is intended.\n", | |
"    beta = (np.random.random(n_features) - 1) * scale\n", | |
"    z0 = X[:, informative_idx].dot(beta[informative_idx])\n", | |
"    # y is True with probability sigmoid(z0)\n", | |
"    y = da.random.random(z0.shape, chunks=(chunksize,)) < 1 / (1 + da.exp(-z0))\n", | |
"    return X, y\n", | |
"\n", | |
"\n", | |
"# Synthetic dataset: 10**8 rows x 50 features, 2 of them informative.\n", | |
"X, y = make_classification(\n", | |
"    n_samples=10 ** 8,\n", | |
"    n_features=50,\n", | |
"    n_informative=2,\n", | |
"    chunksize=5 * 10 ** 5)\n", | |
"\n", | |
"# Materialize both arrays on the cluster once so later computations reuse them.\n", | |
"X, y = dask.persist(X, y)\n", | |
"\n", | |
"# logistic-regression helpers\n", | |
"def sigmoid(x):\n", | |
"    '''Elementwise logistic function 1 / (1 + exp(-x)), built from dask.array ops.'''\n", | |
"    denominator = 1 + da.exp(-x)\n", | |
"    return 1 / denominator\n", | |
"\n", | |
"def compute_logistic_loss_grad(beta, X, y):\n", | |
"    '''Logistic-regression negative log-likelihood and its gradient.\n", | |
"\n", | |
"    Builds both dask expressions, evaluates them in a single dask.compute\n", | |
"    call and returns (loss, gradient), the pair fmin_l_bfgs_b expects from\n", | |
"    its objective when fprime is None.\n", | |
"    '''\n", | |
"    Xbeta = X.dot(beta)\n", | |
"    # loss = sum(log(1 + exp(Xbeta))) - y . Xbeta\n", | |
"    # logaddexp(0, z) equals log(1 + exp(z)) but stays finite for large z,\n", | |
"    # whereas log1p(exp(z)) overflows to inf once exp(z) exceeds float64 range.\n", | |
"    loss_fn = da.logaddexp(0, Xbeta).sum() - da.dot(y, Xbeta)\n", | |
"    # gradient = X^T (sigmoid(Xbeta) - y)\n", | |
"    p = sigmoid(Xbeta)\n", | |
"    gradient_fn = da.dot(X.T, p - y)\n", | |
"    loss, gradient = dask.compute(loss_fn, gradient_fn)\n", | |
"    # NOTE(review): copy() presumably hands the optimizer its own writable array.\n", | |
"    return loss, gradient.copy()\n", | |
"\n", | |
"# Optimizer starting point: one zero coefficient per feature column.\n", | |
"n, p = X.shape[0], X.shape[1]\n", | |
"beta = np.zeros(shape=p)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"dask.array<da.random.normal, shape=(100000000, 50), dtype=float64, chunksize=(500000, 50)>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X  # lazy dask array: the repr lists shape, dtype and chunk size" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"40.0" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X.nbytes / 10 ** 9  # in-memory size of X in GB" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/anaconda/lib/python3.5/site-packages/dask/array/core.py:476: RuntimeWarning: overflow encountered in true_divide\n", | |
" o = func(*args, **kwargs)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.86 s, sys: 35.1 ms, total: 1.89 s\n", | |
"Wall time: 8.82 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# fuse_ave_width=0 switches off a graph optimization that slows this workload down.\n", | |
"with dask.set_options(fuse_ave_width=0):\n", | |
"    new_beta_dask, loss_dask, info_dask = fmin_l_bfgs_b(\n", | |
"        compute_logistic_loss_grad, beta,\n", | |
"        fprime=None, args=(X, y),\n", | |
"        pgtol=1e-14, maxiter=10, iprint=0)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 3.36986804e-04, -2.58183289e-05, -8.57693197e-05,\n", | |
" 1.30980861e-04, -2.10530174e-04, -1.74379821e-04,\n", | |
" -2.58747778e-04, -7.79500575e-06, -2.34594893e-05,\n", | |
" 5.17511247e-05, 2.32281649e-04, 1.00570180e-04,\n", | |
" 5.58777107e-06, 2.28895987e-04, -1.89364329e-04,\n", | |
" 5.20665150e-05, 1.68109365e-04, 3.71687228e-04,\n", | |
" 1.14815910e-04, 3.12092343e-04, 3.11045287e-04,\n", | |
" 3.12777296e-04, -4.30556434e-01, -4.10634311e-05,\n", | |
" -2.95176733e-04, -7.64938745e-06, -9.16969603e-01,\n", | |
" -1.57075385e-04, 1.16032906e-04, -2.60510424e-04,\n", | |
" 2.29526165e-05, -3.78676990e-04, 7.99983434e-05,\n", | |
" 1.29415052e-04, 4.31178181e-04, 3.28448348e-04,\n", | |
" 2.06186917e-04, -1.37779101e-04, -1.74945595e-04,\n", | |
" -1.36139608e-04, -3.76378757e-05, -2.63300072e-04,\n", | |
" -1.76055108e-04, -5.29483078e-05, -1.50231453e-04,\n", | |
" 2.56707029e-04, -1.89528783e-04, 3.83597204e-04,\n", | |
" 1.27296870e-04, -1.28490118e-04])" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"new_beta_dask  # fitted coefficients; large magnitudes are expected only at the informative columns" | |
] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "Python [conda root]", | |
"language": "python", | |
"name": "conda-root-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment