Gist mrocklin/2d017346d46cdb1dbbe5fcc72a9949b8, created February 22, 2017 03:34
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://dask.readthedocs.io/en/latest/_images/dask_horizontal.svg\"\n",
    " align=\"right\"\n",
    " width=\"30%\"\n",
    " alt=\"Dask logo\">\n",
    "\n",
    "\n",
    "Model Parallelism with SKLearn and Dask\n",
    "=======================================\n",
    "\n",
    "\n",
    "<img src=\"https://avatars2.githubusercontent.com/u/365630?v=3&s=400\"\n",
    " align=\"right\"\n",
    " width=\"25%\"\n",
    " alt=\"SKLearn logo\">\n",
    "\n",
    "\n",
    "*How do we choose the right parameters for a machine learning pipeline?*\n",
    "\n",
    "This notebook takes a standard example from the Scikit-Learn documentation and parallelizes it using a Dask-powered `GridSearchCV` class, a drop-in replacement for Scikit-Learn's own. By changing just one import, we achieve a significant speedup on a cluster for this common and important task: hyperparameter search.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SKLearn example\n",
    "\n",
    "Adapted from: http://scikit-learn.org/stable/auto_examples/plot_digits_pipe.html#sphx-glr-auto-examples-plot-digits-pipe-py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Code source: Gaël Varoquaux\n",
    "# Modified for documentation by Jaques Grobler\n",
    "# License: BSD 3 clause\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn import linear_model, decomposition, datasets\n",
    "from sklearn.pipeline import Pipeline\n",
    "# The one changed line: a Dask-backed drop-in for sklearn.model_selection.GridSearchCV\n",
    "from dklearn.model_selection import DaskGridSearchCV as GridSearchCV\n",
    "\n",
    "logistic = linear_model.LogisticRegression()\n",
    "\n",
    "pca = decomposition.PCA()\n",
    "pipe = Pipeline(steps=[('pca', pca), ('logistic', logistic)])\n",
    "\n",
    "# 1797 8x8 images of handwritten digits, flattened to 64 features each\n",
    "digits = datasets.load_digits()\n",
    "X_digits = digits.data\n",
    "y_digits = digits.target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "n_components = [10, 20, 30, 40, 50, 64]\n",
    "Cs = np.logspace(-4, 4, 30)\n",
    "\n",
    "# Parameters of pipeline steps are set with '__'-separated names: <step name>__<parameter>\n",
    "\n",
    "estimator = GridSearchCV(pipe,\n",
    "                         dict(pca__n_components=n_components,\n",
    "                              logistic__C=Cs))\n",
    "estimator.fit(X_digits, y_digits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "estimator.best_params_"
   ]
  },
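  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from dask.distributed import Client\n",
    "\n",
    "# Connect to a dask-scheduler already running on this machine (8786 is its default port).\n",
    "# Creating a Client also makes it dask's default scheduler, so re-running the\n",
    "# grid search cell above should then execute on the cluster.\n",
    "c = Client('localhost:8786')"
   ]
  },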
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "c  # display the scheduler address and the workers/cores it manages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
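The grid searched in the notebook is the Cartesian product of the two parameter lists. Each `step__param` key is routed to the pipeline step named before the double underscore. Below is a minimal pure-Python sketch of that bookkeeping, which shows how much independent work the search creates. `expand_grid` and `route_params` are hypothetical helpers written for illustration. They mimic, but are not, scikit-learn's `ParameterGrid` and `Pipeline.set_params`. `n_folds = 3` is an assumption matching the `GridSearchCV` default of that era; newer scikit-learn releases default to 5.

```python
import itertools

# The grid from the notebook: 6 PCA sizes x 30 regularization strengths.
n_components = [10, 20, 30, 40, 50, 64]
# Same values as np.logspace(-4, 4, 30), up to rounding: 30 points evenly
# spaced in log10 between 1e-4 and 1e4 (pure Python, so no NumPy needed).
Cs = [10 ** (-4 + 8 * i / 29) for i in range(30)]

param_grid = {"pca__n_components": n_components, "logistic__C": Cs}


def expand_grid(grid):
    """Yield one dict per combination of values (hypothetical helper that
    mimics what sklearn's ParameterGrid enumerates)."""
    keys = sorted(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


def route_params(params):
    """Split '<step>__<param>' keys into per-step settings (hypothetical
    helper illustrating how a Pipeline interprets the '__' separator)."""
    routed = {}
    for name, value in params.items():
        step, _, param = name.partition("__")
        routed.setdefault(step, {})[param] = value
    return routed


candidates = list(expand_grid(param_grid))
n_folds = 3  # assumption: the 3-fold default of 2017-era GridSearchCV

print(len(candidates))            # 180
print(len(candidates) * n_folds)  # 540
print(route_params(candidates[0]))
```

Under these assumptions, the search performs 540 cross-validation fits plus one final refit on the full data. The cross-validation fits do not depend on one another, which is why the search spreads well across the workers of a cluster.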