Created
April 26, 2017 16:14
-
-
Save ogrisel/5f2d31bc5e7df852b4ca63f5f6049f42 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
    "Comparison of `dask_glm` and scikit-learn logistic regression on the [SUSY dataset](https://archive.ics.uci.edu/ml/datasets/SUSY)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"\n", | |
"import dask\n", | |
"from distributed import Client\n", | |
"import dask.array as da\n", | |
"from sklearn import linear_model\n", | |
"from dask_glm.estimators import LogisticRegression" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>0</th>\n", | |
" <th>1</th>\n", | |
" <th>2</th>\n", | |
" <th>3</th>\n", | |
" <th>4</th>\n", | |
" <th>5</th>\n", | |
" <th>6</th>\n", | |
" <th>7</th>\n", | |
" <th>8</th>\n", | |
" <th>9</th>\n", | |
" <th>10</th>\n", | |
" <th>11</th>\n", | |
" <th>12</th>\n", | |
" <th>13</th>\n", | |
" <th>14</th>\n", | |
" <th>15</th>\n", | |
" <th>16</th>\n", | |
" <th>17</th>\n", | |
" <th>18</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0.0</td>\n", | |
" <td>0.972861</td>\n", | |
" <td>0.653855</td>\n", | |
" <td>1.176225</td>\n", | |
" <td>1.157156</td>\n", | |
" <td>-1.739873</td>\n", | |
" <td>-0.874309</td>\n", | |
" <td>0.567765</td>\n", | |
" <td>-0.175000</td>\n", | |
" <td>0.810061</td>\n", | |
" <td>-0.252552</td>\n", | |
" <td>1.921887</td>\n", | |
" <td>0.889637</td>\n", | |
" <td>0.410772</td>\n", | |
" <td>1.145621</td>\n", | |
" <td>1.932632</td>\n", | |
" <td>0.994464</td>\n", | |
" <td>1.367815</td>\n", | |
" <td>0.040714</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1.0</td>\n", | |
" <td>1.667973</td>\n", | |
" <td>0.064191</td>\n", | |
" <td>-1.225171</td>\n", | |
" <td>0.506102</td>\n", | |
" <td>-0.338939</td>\n", | |
" <td>1.672543</td>\n", | |
" <td>3.475464</td>\n", | |
" <td>-1.219136</td>\n", | |
" <td>0.012955</td>\n", | |
" <td>3.775174</td>\n", | |
" <td>1.045977</td>\n", | |
" <td>0.568051</td>\n", | |
" <td>0.481928</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.448410</td>\n", | |
" <td>0.205356</td>\n", | |
" <td>1.321893</td>\n", | |
" <td>0.377584</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1.0</td>\n", | |
" <td>0.444840</td>\n", | |
" <td>-0.134298</td>\n", | |
" <td>-0.709972</td>\n", | |
" <td>0.451719</td>\n", | |
" <td>-1.613871</td>\n", | |
" <td>-0.768661</td>\n", | |
" <td>1.219918</td>\n", | |
" <td>0.504026</td>\n", | |
" <td>1.831248</td>\n", | |
" <td>-0.431385</td>\n", | |
" <td>0.526283</td>\n", | |
" <td>0.941514</td>\n", | |
" <td>1.587535</td>\n", | |
" <td>2.024308</td>\n", | |
" <td>0.603498</td>\n", | |
" <td>1.562374</td>\n", | |
" <td>1.135454</td>\n", | |
" <td>0.180910</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1.0</td>\n", | |
" <td>0.381256</td>\n", | |
" <td>-0.976145</td>\n", | |
" <td>0.693152</td>\n", | |
" <td>0.448959</td>\n", | |
" <td>0.891753</td>\n", | |
" <td>-0.677328</td>\n", | |
" <td>2.033060</td>\n", | |
" <td>1.533041</td>\n", | |
" <td>3.046260</td>\n", | |
" <td>-1.005285</td>\n", | |
" <td>0.569386</td>\n", | |
" <td>1.015211</td>\n", | |
" <td>1.582217</td>\n", | |
" <td>1.551914</td>\n", | |
" <td>0.761215</td>\n", | |
" <td>1.715464</td>\n", | |
" <td>1.492257</td>\n", | |
" <td>0.090719</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1.0</td>\n", | |
" <td>1.309996</td>\n", | |
" <td>-0.690089</td>\n", | |
" <td>-0.676259</td>\n", | |
" <td>1.589283</td>\n", | |
" <td>-0.693326</td>\n", | |
" <td>0.622907</td>\n", | |
" <td>1.087562</td>\n", | |
" <td>-0.381742</td>\n", | |
" <td>0.589204</td>\n", | |
" <td>1.365479</td>\n", | |
" <td>1.179295</td>\n", | |
" <td>0.968218</td>\n", | |
" <td>0.728563</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>1.083158</td>\n", | |
" <td>0.043429</td>\n", | |
" <td>1.154854</td>\n", | |
" <td>0.094859</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" 0 1 2 3 4 5 6 7 \\\n", | |
"0 0.0 0.972861 0.653855 1.176225 1.157156 -1.739873 -0.874309 0.567765 \n", | |
"1 1.0 1.667973 0.064191 -1.225171 0.506102 -0.338939 1.672543 3.475464 \n", | |
"2 1.0 0.444840 -0.134298 -0.709972 0.451719 -1.613871 -0.768661 1.219918 \n", | |
"3 1.0 0.381256 -0.976145 0.693152 0.448959 0.891753 -0.677328 2.033060 \n", | |
"4 1.0 1.309996 -0.690089 -0.676259 1.589283 -0.693326 0.622907 1.087562 \n", | |
"\n", | |
" 8 9 10 11 12 13 14 \\\n", | |
"0 -0.175000 0.810061 -0.252552 1.921887 0.889637 0.410772 1.145621 \n", | |
"1 -1.219136 0.012955 3.775174 1.045977 0.568051 0.481928 0.000000 \n", | |
"2 0.504026 1.831248 -0.431385 0.526283 0.941514 1.587535 2.024308 \n", | |
"3 1.533041 3.046260 -1.005285 0.569386 1.015211 1.582217 1.551914 \n", | |
"4 -0.381742 0.589204 1.365479 1.179295 0.968218 0.728563 0.000000 \n", | |
"\n", | |
" 15 16 17 18 \n", | |
"0 1.932632 0.994464 1.367815 0.040714 \n", | |
"1 0.448410 0.205356 1.321893 0.377584 \n", | |
"2 0.603498 1.562374 1.135454 0.180910 \n", | |
"3 0.761215 1.715464 1.492257 0.090719 \n", | |
"4 1.083158 0.043429 1.154854 0.094859 " | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df = pd.read_csv(\"SUSY.csv.gz\", header=None)\n", | |
"df.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"5000000" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(df)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
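    "We have 5,000,000 rows of all-numeric data: column 0 is the class label and columns 1-18 are the features. We'll skip feature engineering; the only preprocessing is standardizing the features with `scale` below." | |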
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"y = df[0].values\n", | |
"X = df.drop(0, axis=1).values" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"C = 10 # for scikit-learn\n", | |
"λ = 1 / C # for dask_glm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.preprocessing import scale\n", | |
"\n", | |
"X = scale(X)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Scikit-learn\n", | |
"\n", | |
"First, we run scikit-learn's `LogisticRegression` on the full dataset." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1min 8s, sys: 52 ms, total: 1min 8s\n", | |
"Wall time: 1min 8s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"lm = linear_model.LogisticRegression(penalty='l1', C=C, solver='saga')\n", | |
"lm.fit(X, y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 936 ms, sys: 316 ms, total: 1.25 s\n", | |
"Wall time: 780 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"0.78832060000000004" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"lm.score(X, y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# %%time\n", | |
"# lm = linear_model.LogisticRegression(penalty='l1', C=C)\n", | |
"# lm.fit(X, y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# %%time\n", | |
"# lm.score(X, y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 1.59889253e+00, -2.16531507e-04, -2.07209382e-03,\n", | |
" 3.07578963e-01, 1.31631754e-03, -7.76999581e-05,\n", | |
" 4.08804094e+00, 3.57268409e-03, -3.64687752e-01,\n", | |
" 3.18132406e-01, 1.27980300e-01, -9.35147191e-01,\n", | |
" -8.07122679e-01, 8.41146611e-02, -1.26453868e+00,\n", | |
" 3.32235015e-01, -2.71631453e-01, 2.17774653e-01]])" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"lm.coef_" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Dask GLM\n", | |
"\n", | |
"Now for the dask-glm version." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"distributed.deploy.local - INFO - To start diagnostics web server please install Bokeh\n" | |
] | |
} | |
], | |
"source": [ | |
"client = Client()\n", | |
"\n", | |
"# dask\n", | |
"K = 100000\n", | |
"dX = da.from_array(X, chunks=(K, X.shape[-1]))\n", | |
"dy = da.from_array(y, chunks=(K,))\n", | |
"\n", | |
"dX, dy = dask.persist(X, y)\n", | |
"client.rebalance([X, y])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Converged! 6\n", | |
"CPU times: user 17.2 s, sys: 26 s, total: 43.1 s\n", | |
"Wall time: 6min 3s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"dk = LogisticRegression()\n", | |
"dk.fit(dX, dy)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 532 ms, sys: 328 ms, total: 860 ms\n", | |
"Wall time: 438 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"0.78832460000000004" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"dk.score(dX, dy)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 1.58671772e+00, -1.90417079e-04, -2.01987034e-03,\n", | |
" 3.04813020e-01, 1.38868265e-03, -1.10436676e-04,\n", | |
" 4.05569762e+00, 3.50754224e-03, -3.61962379e-01,\n", | |
" 3.15746332e-01, 1.26745609e-01, -9.27452934e-01,\n", | |
" -8.00558632e-01, 8.32899295e-02, -1.25427689e+00,\n", | |
" 3.29617974e-01, -2.69505213e-01, 2.15981990e-01])" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dk.coef_" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
    "| Library | Training time (wall clock) | Score |\n", | |
    "| -------------| ------------- | ----- |\n", | |
    "| dask-glm | 6:03 | .788 |\n", | |
    "| scikit-learn | 1:08 | .788 |" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
    "The saga fit is not perfect, though: its accuracy is slightly lower and its coefficients are not identical to dask-glm's:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.032343321467148911" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.max(np.abs(dk.coef_ - lm.coef_))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 1.21748057e-02, 2.61144277e-05, 5.22234856e-05,\n", | |
" 2.76594302e-03, 7.23651159e-05, 3.27367177e-05,\n", | |
" 3.23433215e-02, 6.51418546e-05, 2.72537334e-03,\n", | |
" 2.38607405e-03, 1.23469089e-03, 7.69425681e-03,\n", | |
" 6.56404753e-03, 8.24731672e-04, 1.02617860e-02,\n", | |
" 2.61704189e-03, 2.12624013e-03, 1.79266276e-03]])" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.abs(dk.coef_ - lm.coef_)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment