Skip to content

Instantly share code, notes, and snippets.

@ogrisel
Created April 26, 2017 16:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ogrisel/5f2d31bc5e7df852b4ca63f5f6049f42 to your computer and use it in GitHub Desktop.
Save ogrisel/5f2d31bc5e7df852b4ca63f5f6049f42 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Comparison of dask_glm and scikit-learn on the [SUSY dataset](https://archive.ics.uci.edu/ml/datasets/SUSY)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import dask\n",
"from distributed import Client\n",
"import dask.array as da\n",
"from sklearn import linear_model\n",
"from dask_glm.estimators import LogisticRegression"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>10</th>\n",
" <th>11</th>\n",
" <th>12</th>\n",
" <th>13</th>\n",
" <th>14</th>\n",
" <th>15</th>\n",
" <th>16</th>\n",
" <th>17</th>\n",
" <th>18</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.0</td>\n",
" <td>0.972861</td>\n",
" <td>0.653855</td>\n",
" <td>1.176225</td>\n",
" <td>1.157156</td>\n",
" <td>-1.739873</td>\n",
" <td>-0.874309</td>\n",
" <td>0.567765</td>\n",
" <td>-0.175000</td>\n",
" <td>0.810061</td>\n",
" <td>-0.252552</td>\n",
" <td>1.921887</td>\n",
" <td>0.889637</td>\n",
" <td>0.410772</td>\n",
" <td>1.145621</td>\n",
" <td>1.932632</td>\n",
" <td>0.994464</td>\n",
" <td>1.367815</td>\n",
" <td>0.040714</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.0</td>\n",
" <td>1.667973</td>\n",
" <td>0.064191</td>\n",
" <td>-1.225171</td>\n",
" <td>0.506102</td>\n",
" <td>-0.338939</td>\n",
" <td>1.672543</td>\n",
" <td>3.475464</td>\n",
" <td>-1.219136</td>\n",
" <td>0.012955</td>\n",
" <td>3.775174</td>\n",
" <td>1.045977</td>\n",
" <td>0.568051</td>\n",
" <td>0.481928</td>\n",
" <td>0.000000</td>\n",
" <td>0.448410</td>\n",
" <td>0.205356</td>\n",
" <td>1.321893</td>\n",
" <td>0.377584</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.0</td>\n",
" <td>0.444840</td>\n",
" <td>-0.134298</td>\n",
" <td>-0.709972</td>\n",
" <td>0.451719</td>\n",
" <td>-1.613871</td>\n",
" <td>-0.768661</td>\n",
" <td>1.219918</td>\n",
" <td>0.504026</td>\n",
" <td>1.831248</td>\n",
" <td>-0.431385</td>\n",
" <td>0.526283</td>\n",
" <td>0.941514</td>\n",
" <td>1.587535</td>\n",
" <td>2.024308</td>\n",
" <td>0.603498</td>\n",
" <td>1.562374</td>\n",
" <td>1.135454</td>\n",
" <td>0.180910</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.0</td>\n",
" <td>0.381256</td>\n",
" <td>-0.976145</td>\n",
" <td>0.693152</td>\n",
" <td>0.448959</td>\n",
" <td>0.891753</td>\n",
" <td>-0.677328</td>\n",
" <td>2.033060</td>\n",
" <td>1.533041</td>\n",
" <td>3.046260</td>\n",
" <td>-1.005285</td>\n",
" <td>0.569386</td>\n",
" <td>1.015211</td>\n",
" <td>1.582217</td>\n",
" <td>1.551914</td>\n",
" <td>0.761215</td>\n",
" <td>1.715464</td>\n",
" <td>1.492257</td>\n",
" <td>0.090719</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.0</td>\n",
" <td>1.309996</td>\n",
" <td>-0.690089</td>\n",
" <td>-0.676259</td>\n",
" <td>1.589283</td>\n",
" <td>-0.693326</td>\n",
" <td>0.622907</td>\n",
" <td>1.087562</td>\n",
" <td>-0.381742</td>\n",
" <td>0.589204</td>\n",
" <td>1.365479</td>\n",
" <td>1.179295</td>\n",
" <td>0.968218</td>\n",
" <td>0.728563</td>\n",
" <td>0.000000</td>\n",
" <td>1.083158</td>\n",
" <td>0.043429</td>\n",
" <td>1.154854</td>\n",
" <td>0.094859</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 7 \\\n",
"0 0.0 0.972861 0.653855 1.176225 1.157156 -1.739873 -0.874309 0.567765 \n",
"1 1.0 1.667973 0.064191 -1.225171 0.506102 -0.338939 1.672543 3.475464 \n",
"2 1.0 0.444840 -0.134298 -0.709972 0.451719 -1.613871 -0.768661 1.219918 \n",
"3 1.0 0.381256 -0.976145 0.693152 0.448959 0.891753 -0.677328 2.033060 \n",
"4 1.0 1.309996 -0.690089 -0.676259 1.589283 -0.693326 0.622907 1.087562 \n",
"\n",
" 8 9 10 11 12 13 14 \\\n",
"0 -0.175000 0.810061 -0.252552 1.921887 0.889637 0.410772 1.145621 \n",
"1 -1.219136 0.012955 3.775174 1.045977 0.568051 0.481928 0.000000 \n",
"2 0.504026 1.831248 -0.431385 0.526283 0.941514 1.587535 2.024308 \n",
"3 1.533041 3.046260 -1.005285 0.569386 1.015211 1.582217 1.551914 \n",
"4 -0.381742 0.589204 1.365479 1.179295 0.968218 0.728563 0.000000 \n",
"\n",
" 15 16 17 18 \n",
"0 1.932632 0.994464 1.367815 0.040714 \n",
"1 0.448410 0.205356 1.321893 0.377584 \n",
"2 0.603498 1.562374 1.135454 0.180910 \n",
"3 0.761215 1.715464 1.492257 0.090719 \n",
"4 1.083158 0.043429 1.154854 0.094859 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(\"SUSY.csv.gz\", header=None)\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5000000"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
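"We have 5,000,000 rows of all-numeric data (column 0 is the class label, columns 1-18 are the features). We'll skip feature engineering; the only preprocessing is standardizing the features with `sklearn.preprocessing.scale`."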
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y = df[0].values\n",
"X = df.drop(0, axis=1).values"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"C = 10 # for scikit-learn\n",
"λ = 1 / C # for dask_glm"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.preprocessing import scale\n",
"\n",
"X = scale(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Scikit-learn\n",
"\n",
"First, we run scikit-learn's `LogisticRegression` (L1 penalty, `saga` solver) on the full dataset."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1min 8s, sys: 52 ms, total: 1min 8s\n",
"Wall time: 1min 8s\n"
]
}
],
"source": [
"%%time\n",
"lm = linear_model.LogisticRegression(penalty='l1', C=C, solver='saga')\n",
"lm.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 936 ms, sys: 316 ms, total: 1.25 s\n",
"Wall time: 780 ms\n"
]
},
{
"data": {
"text/plain": [
"0.78832060000000004"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"lm.score(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# %%time\n",
"# lm = linear_model.LogisticRegression(penalty='l1', C=C)\n",
"# lm.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# %%time\n",
"# lm.score(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1.59889253e+00, -2.16531507e-04, -2.07209382e-03,\n",
" 3.07578963e-01, 1.31631754e-03, -7.76999581e-05,\n",
" 4.08804094e+00, 3.57268409e-03, -3.64687752e-01,\n",
" 3.18132406e-01, 1.27980300e-01, -9.35147191e-01,\n",
" -8.07122679e-01, 8.41146611e-02, -1.26453868e+00,\n",
" 3.32235015e-01, -2.71631453e-01, 2.17774653e-01]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lm.coef_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dask GLM\n",
"\n",
"Now for the dask-glm version."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"distributed.deploy.local - INFO - To start diagnostics web server please install Bokeh\n"
]
}
],
"source": [
"client = Client()\n",
"\n",
"# Wrap the in-memory numpy arrays as dask arrays with K rows per chunk.\n",
"K = 100000\n",
"dX = da.from_array(X, chunks=(K, X.shape[-1]))\n",
"dy = da.from_array(y, chunks=(K,))\n",
"\n",
"# Persist the *dask* arrays. Passing the numpy X, y here would return them\n",
"# unchanged and silently discard the chunked arrays built above.\n",
"dX, dy = dask.persist(dX, dy)\n",
"client.rebalance([dX, dy])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converged! 6\n",
"CPU times: user 17.2 s, sys: 26 s, total: 43.1 s\n",
"Wall time: 6min 3s\n"
]
}
],
"source": [
"%%time\n",
"# Use the same model as the scikit-learn run above: L1 penalty with strength λ = 1 / C.\n",
"dk = LogisticRegression(regularizer='l1', lamduh=λ)\n",
"dk.fit(dX, dy)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 532 ms, sys: 328 ms, total: 860 ms\n",
"Wall time: 438 ms\n"
]
},
{
"data": {
"text/plain": [
"0.78832460000000004"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"dk.score(dX, dy)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 1.58671772e+00, -1.90417079e-04, -2.01987034e-03,\n",
" 3.04813020e-01, 1.38868265e-03, -1.10436676e-04,\n",
" 4.05569762e+00, 3.50754224e-03, -3.61962379e-01,\n",
" 3.15746332e-01, 1.26745609e-01, -9.27452934e-01,\n",
" -8.00558632e-01, 8.32899295e-02, -1.25427689e+00,\n",
" 3.29617974e-01, -2.69505213e-01, 2.15981990e-01])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dk.coef_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| Library | Training time | Score |\n",
"| -------------| ------------- | ----- |\n",
"| dask-glm | 6:03 | .788 |\n",
"| scikit-learn | 1:08 | .788 |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The scikit-learn `saga` fit does not exactly match dask-glm, though: its training accuracy is slightly lower (0.7883206 vs. 0.7883246) and the coefficients are not identical:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.032343321467148911"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.max(np.abs(dk.coef_ - lm.coef_))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1.21748057e-02, 2.61144277e-05, 5.22234856e-05,\n",
" 2.76594302e-03, 7.23651159e-05, 3.27367177e-05,\n",
" 3.23433215e-02, 6.51418546e-05, 2.72537334e-03,\n",
" 2.38607405e-03, 1.23469089e-03, 7.69425681e-03,\n",
" 6.56404753e-03, 8.24731672e-04, 1.02617860e-02,\n",
" 2.61704189e-03, 2.12624013e-03, 1.79266276e-03]])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.abs(dk.coef_ - lm.coef_)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment