Skip to content

Instantly share code, notes, and snippets.

@domvwt
Created February 2, 2021 17:55
Show Gist options
  • Save domvwt/078ac4da72ada3d2229747ad827bf47f to your computer and use it in GitHub Desktop.
Save domvwt/078ac4da72ada3d2229747ad827bf47f to your computer and use it in GitHub Desktop.
vaex-optuna.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "vaex-optuna.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyP0rg+FmL1BUJYEEpNoox/b",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/domvwt/078ac4da72ada3d2229747ad827bf47f/vaex-optuna.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ywOOu8YGu0Mj"
},
"source": [
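"# Hyperparameter Tuning with Vaex and Optuna\r\n",
"\r\n",
"Tune a CatBoost regressor that predicts `tip_amount` for the December 2019 NYC yellow-taxi trip records, using Vaex to load the data and Optuna to search for the hyperparameters that minimise validation MAE.\r\n"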
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mM2JrQ4VpJaY",
"outputId": "6e80cd03-8a33-4694-ae86-daaa2719decb"
},
"source": [
"%%shell\r\n",
"# Stop at the first failing command; otherwise a failed download or install\r\n",
"# would still end with \"Setup Complete!\".\r\n",
"set -e\r\n",
"echo \"Setting up environment...\"\r\n",
"if [ ! -e \"yellow_tripdata_2019-12.csv\" ]\r\n",
"  then\r\n",
"    echo \"...Downloading data...\"\r\n",
"    # Download under a temporary name and rename only after wget succeeds, so an\r\n",
"    # interrupted download is not mistaken for a complete file on the next run.\r\n",
"    # NOTE(review): confirm this URL still serves the CSV before relying on it.\r\n",
"    wget -q -O \"yellow_tripdata_2019-12.csv.part\" \\\r\n",
"      https://nyc-tlc.s3.amazonaws.com/trip+data/yellow_tripdata_2019-12.csv\r\n",
"    mv \"yellow_tripdata_2019-12.csv.part\" \"yellow_tripdata_2019-12.csv\"\r\n",
"fi\r\n",
"echo \"...Installing libraries...\"\r\n",
"# Discard pip's stdout only, so installation errors on stderr stay visible.\r\n",
"pip install -Uqq vaex-core vaex-hdf5 vaex-jupyter vaex-ml catboost optuna \"ipython>=7.0.0\" > /dev/null\r\n",
"echo \"Setup Complete!\"\r\n",
"echo \"Runtime restart may be required to load new packages.\""
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Setting up environment...\n",
"...Downloading data...\n",
"...Installing libraries...\n",
"Setup Complete!\n",
"Runtime restart may be required to load new packages.\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
""
]
},
"metadata": {
"tags": []
},
"execution_count": 1
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "tO6_1lOhugqt"
},
"source": [
"import gc\n",
"from pathlib import Path\n",
"\n",
"import nest_asyncio\n",
"import optuna\n",
"import vaex as vx\n",
"import vaex.jupyter as vj  # NOTE(review): vj is not referenced in this notebook\n",
"import vaex.ml.catboost as cb\n",
"from sklearn.metrics import mean_absolute_error\n",
"\n",
"# The Jupyter kernel already runs an asyncio event loop, and plain asyncio refuses\n",
"# to run a nested loop inside it. nest_asyncio patches asyncio so nested\n",
"# run_until_complete()/asyncio.run() calls can execute in this kernel.\n",
"# NOTE(review): presumably required by vaex's asyncio-based execution; confirm.\n",
"nest_asyncio.apply()"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "t13hPzuV0PfH"
},
"source": [
"DATA_PATH = \"yellow_tripdata_2019-12.csv\"\r\n",
"# vx.from_csv(..., convert=True) writes its HDF5 copy to \"<DATA_PATH>.hdf5\";\r\n",
"# derive the name from DATA_PATH so the two cannot drift apart.\r\n",
"DATA_NEW = f\"{DATA_PATH}.hdf5\"\r\n",
"\r\n",
"# Convert only once: later runs and the cells below open the HDF5 copy directly.\r\n",
"if not Path(DATA_NEW).is_file():\r\n",
"    vx.from_csv(DATA_PATH, convert=True)"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 885
},
"id": "c4mK0fm98r-v",
"outputId": "e0dc413c-49d7-4919-cbfe-d4e6e939ca24"
},
"source": [
"# Open the converted HDF5 file and summarise it (row count, column dtypes, head/tail preview).\r\n",
"df00 = vx.open(path=DATA_NEW)\r\n",
"df00.info()"
],
"execution_count": 3,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<style>.vaex-description pre {\n",
" max-width : 450px;\n",
" white-space : nowrap;\n",
" overflow : hidden;\n",
" text-overflow: ellipsis;\n",
" }\n",
"\n",
" .vex-description pre:hover {\n",
" max-width : initial;\n",
" white-space: pre;\n",
" }</style>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<div><h2>yellow_tripdata_2019-12.csv</h2> <b>rows</b>: 6,896,317</div><div><b>path</b>: <i>/content/yellow_tripdata_2019-12.csv.hdf5</i></div><div><b>Description</b>: file exported by vaex, by user root, on date 2021-02-02 13:32:50.698030, from source /has/no/path/arrays-/has/no/path/arrays</div><h2>Columns:</h2><table class='table-striped'><thead><tr><th>column</th><th>type</th><th>unit</th><th>description</th><th>expression</th></tr></thead><tr><td>VendorID</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>tpep_pickup_datetime</td><td>str</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>tpep_dropoff_datetime</td><td>str</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>passenger_count</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>trip_distance</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>RatecodeID</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>store_and_fwd_flag</td><td>str</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>PULocationID</td><td>int64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>DOLocationID</td><td>int64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>payment_type</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>fare_amount</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>extra</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>mta_tax</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>tip_amount</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>tolls_amount</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>improvement_surcharge</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>total_amount</td><td>float64</td><td></td><td ><pre></pre></td><td></td></tr><tr><td>congestion_surcharge</td><td>float64</td><td></td><td 
><pre></pre></td><td></td></tr></table><h2>Data:</h2><table>\n",
"<thead>\n",
"<tr><th># </th><th>VendorID </th><th>tpep_pickup_datetime </th><th>tpep_dropoff_datetime </th><th>passenger_count </th><th>trip_distance </th><th>RatecodeID </th><th>store_and_fwd_flag </th><th>PULocationID </th><th>DOLocationID </th><th>payment_type </th><th>fare_amount </th><th>extra </th><th>mta_tax </th><th>tip_amount </th><th>tolls_amount </th><th>improvement_surcharge </th><th>total_amount </th><th>congestion_surcharge </th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><td><i style='opacity: 0.6'>0</i> </td><td>1.0 </td><td>2019-12-01 00:26:58 </td><td>2019-12-01 00:41:45 </td><td>1.0 </td><td>4.2 </td><td>1.0 </td><td>N </td><td>142 </td><td>116 </td><td>2.0 </td><td>14.5 </td><td>3.0 </td><td>0.5 </td><td>0.0 </td><td>0.0 </td><td>0.3 </td><td>18.3 </td><td>2.5 </td></tr>\n",
"<tr><td><i style='opacity: 0.6'>1</i> </td><td>1.0 </td><td>2019-12-01 00:12:08 </td><td>2019-12-01 00:12:14 </td><td>1.0 </td><td>0.0 </td><td>1.0 </td><td>N </td><td>145 </td><td>145 </td><td>2.0 </td><td>2.5 </td><td>0.5 </td><td>0.5 </td><td>0.0 </td><td>0.0 </td><td>0.3 </td><td>3.8 </td><td>0.0 </td></tr>\n",
"<tr><td><i style='opacity: 0.6'>2</i> </td><td>1.0 </td><td>2019-12-01 00:25:53 </td><td>2019-12-01 00:26:04 </td><td>1.0 </td><td>0.0 </td><td>1.0 </td><td>N </td><td>145 </td><td>145 </td><td>2.0 </td><td>2.5 </td><td>0.5 </td><td>0.5 </td><td>0.0 </td><td>0.0 </td><td>0.3 </td><td>3.8 </td><td>0.0 </td></tr>\n",
"<tr><td><i style='opacity: 0.6'>3</i> </td><td>1.0 </td><td>2019-12-01 00:12:03 </td><td>2019-12-01 00:33:19 </td><td>2.0 </td><td>9.4 </td><td>1.0 </td><td>N </td><td>138 </td><td>25 </td><td>1.0 </td><td>28.5 </td><td>0.5 </td><td>0.5 </td><td>10.0 </td><td>0.0 </td><td>0.3 </td><td>39.8 </td><td>0.0 </td></tr>\n",
"<tr><td><i style='opacity: 0.6'>4</i> </td><td>1.0 </td><td>2019-12-01 00:05:27 </td><td>2019-12-01 00:16:32 </td><td>2.0 </td><td>1.6 </td><td>1.0 </td><td>N </td><td>161 </td><td>237 </td><td>2.0 </td><td>9.0 </td><td>3.0 </td><td>0.5 </td><td>0.0 </td><td>0.0 </td><td>0.3 </td><td>12.8 </td><td>2.5 </td></tr>\n",
"<tr><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td><td>... </td></tr>\n",
"<tr><td><i style='opacity: 0.6'>6,896,312</i></td><td>nan </td><td>2019-12-31 00:07:00 </td><td>2019-12-31 00:46:00 </td><td>nan </td><td>12.78 </td><td>nan </td><td>None </td><td>230 </td><td>72 </td><td>nan </td><td>32.32 </td><td>2.75 </td><td>0.5 </td><td>0.0 </td><td>6.12 </td><td>0.3 </td><td>41.99 </td><td>0.0 </td></tr>\n",
"<tr><td><i style='opacity: 0.6'>6,896,313</i></td><td>nan </td><td>2019-12-31 00:20:00 </td><td>2019-12-31 00:47:00 </td><td>nan </td><td>18.52 </td><td>nan </td><td>None </td><td>219 </td><td>32 </td><td>nan </td><td>51.63 </td><td>2.75 </td><td>0.5 </td><td>0.0 </td><td>6.12 </td><td>0.3 </td><td>61.3 </td><td>0.0 </td></tr>\n",
"<tr><td><i style='opacity: 0.6'>6,896,314</i></td><td>nan </td><td>2019-12-31 00:50:00 </td><td>2019-12-31 01:21:00 </td><td>nan </td><td>13.13 </td><td>nan </td><td>None </td><td>161 </td><td>76 </td><td>nan </td><td>38.02 </td><td>2.75 </td><td>0.5 </td><td>0.0 </td><td>6.12 </td><td>0.3 </td><td>47.69 </td><td>0.0 </td></tr>\n",
"<tr><td><i style='opacity: 0.6'>6,896,315</i></td><td>nan </td><td>2019-12-31 00:38:19 </td><td>2019-12-31 01:19:37 </td><td>nan </td><td>14.51 </td><td>nan </td><td>None </td><td>230 </td><td>21 </td><td>nan </td><td>41.86 </td><td>2.75 </td><td>0.0 </td><td>0.0 </td><td>6.12 </td><td>0.3 </td><td>51.03 </td><td>0.0 </td></tr>\n",
"<tr><td><i style='opacity: 0.6'>6,896,316</i></td><td>nan </td><td>2019-12-31 00:21:00 </td><td>2019-12-31 00:56:00 </td><td>nan </td><td>-17.16 </td><td>nan </td><td>None </td><td>193 </td><td>219 </td><td>nan </td><td>44.62 </td><td>2.75 </td><td>0.5 </td><td>0.0 </td><td>0.0 </td><td>0.3 </td><td>48.17 </td><td>0.0 </td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "-7vln7ogxvwJ"
},
"source": [
"target = \"tip_amount\"\n",
"\n",
"# Columns that must never be model inputs. total_amount is the sum of all charges\n",
"# *including* tip_amount (e.g. fare 28.5 + extra 0.5 + mta_tax 0.5 + tip 10.0 +\n",
"# surcharge 0.3 = total 39.8), so keeping it as a feature leaks the target.\n",
"excluded = {target, \"total_amount\"}\n",
"# NOTE(review): the pickup/dropoff datetimes and store_and_fwd_flag are string\n",
"# columns and no cat_features are passed to CatBoost; confirm they are handled as intended.\n",
"features = [x for x in df00.get_column_names() if x not in excluded]"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "TMO2QjHewgC5"
},
"source": [
"def objective(trial):\r\n",
"    \"\"\"Optuna objective: train CatBoost on a fresh sample and return its validation MAE.\r\n",
"\r\n",
"    Uses the module-level ``df00``, ``features`` and ``target`` from earlier cells.\r\n",
"\r\n",
"    Args:\r\n",
"        trial: Optuna trial used to sample this run's hyperparameters.\r\n",
"\r\n",
"    Returns:\r\n",
"        Mean absolute error on the validation split (the study minimises it).\r\n",
"    \"\"\"\r\n",
"    gc.collect()\r\n",
"\r\n",
"    # Take a random sample of 1,000,000 rows and split into train / validation.\r\n",
"    # Seeding with the trial number makes each trial reproducible while still\r\n",
"    # giving different trials different samples.\r\n",
"    df_sample = df00.sample(int(1e6), random_state=trial.number)\r\n",
"    df_train, df_valid = df_sample.ml.train_test_split(verbose=False)\r\n",
"\r\n",
"    # CatBoost model parameters\r\n",
"    cbm_params = dict(\r\n",
"        loss_function=\"MAE\",\r\n",
"        early_stopping_rounds=50,\r\n",
"        verbose=False,\r\n",
"        used_ram_limit=\"8gb\",\r\n",
"\r\n",
"        # NOTE(review): the best stored trial sits at the upper bound of both\r\n",
"        # colsample_bylevel (~0.099) and depth (12); consider widening these ranges.\r\n",
"        colsample_bylevel=trial.suggest_float(\"colsample_bylevel\", 0.01, 0.1),\r\n",
"        depth=trial.suggest_int(\"depth\", 1, 12),\r\n",
"        bootstrap_type=trial.suggest_categorical(\r\n",
"            \"bootstrap_type\", [\"Bayesian\", \"Bernoulli\", \"MVS\"]\r\n",
"        ),\r\n",
"    )\r\n",
"\r\n",
"    # Conditional search space: Bayesian and Bernoulli bootstrap each get one\r\n",
"    # extra parameter; MVS gets none.\r\n",
"    if cbm_params[\"bootstrap_type\"] == \"Bayesian\":\r\n",
"        cbm_params[\"bagging_temperature\"] = trial.suggest_float(\"bagging_temperature\", 0, 10)\r\n",
"    elif cbm_params[\"bootstrap_type\"] == \"Bernoulli\":\r\n",
"        cbm_params[\"subsample\"] = trial.suggest_float(\"subsample\", 0.1, 1)\r\n",
"\r\n",
"    # Vaex wrapper parameters\r\n",
"    # - Use prediction_type to determine between classification / regression\r\n",
"    # - Use pool_params to pass cat_features index\r\n",
"    vcb_params = dict(\r\n",
"        features=features,\r\n",
"        target=target,\r\n",
"        prediction_type=\"RawFormulaVal\",\r\n",
"        params=cbm_params,\r\n",
"        num_boost_round=200,\r\n",
"        # pool_params=pool_params\r\n",
"    )\r\n",
"\r\n",
"    cbm = cb.CatBoostModel(**vcb_params)\r\n",
"    cbm.fit(df_train, evals=[df_valid])\r\n",
"\r\n",
"    preds = cbm.predict(df_valid)\r\n",
"    # scikit-learn's signature is (y_true, y_pred).\r\n",
"    return mean_absolute_error(df_valid[target].values, preds)"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SM32jp47E6AF",
"outputId": "989f8c9d-a241-446a-c5a1-68c5c10cc05b"
},
"source": [
"# Total number of COMPLETE trials wanted, summed over all (resumed) sessions.\r\n",
"N_TRIALS = 200\r\n",
"\r\n",
"# The study is stored in SQLite and reloaded if it exists, so re-running this\r\n",
"# cell resumes the search instead of starting over.\r\n",
"study = optuna.create_study(\r\n",
"    storage=\"sqlite:///optuna.db\",\r\n",
"    study_name=\"catboost-tips\",\r\n",
"    direction=\"minimize\",\r\n",
"    load_if_exists=True,\r\n",
")\r\n",
"\r\n",
"# Count only COMPLETE trials: a trial whose kernel was killed mid-run can be left\r\n",
"# RUNNING in storage, and counting every non-FAIL trial would shrink the budget.\r\n",
"completed = sum(t.state == optuna.trial.TrialState.COMPLETE for t in study.trials)\r\n",
"n_trials = max(0, N_TRIALS - completed)\r\n",
"\r\n",
"print(f\"Remaining trials: {n_trials}\")\r\n",
"print(\"Starting hyperparameter search...\")\r\n",
"print(\"...Collecting garbage...\")\r\n",
"gc.collect()\r\n",
"print(\"...Running study...\")\r\n",
"study.optimize(objective, n_trials=n_trials, n_jobs=4)\r\n",
"\r\n",
"print(f\"Number of finished trials: {len(study.trials)}\")\r\n",
"\r\n",
"print(\"Best trial:\")\r\n",
"trial = study.best_trial\r\n",
"\r\n",
"print(f\" Value: {trial.value}\")\r\n",
"\r\n",
"print(\" Params: \")\r\n",
"for key, value in trial.params.items():\r\n",
"    print(f\" {key}: {value}\")"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[32m[I 2021-02-02 17:43:17,549]\u001b[0m Using an existing study with name 'catboost-tips' instead of creating a new one.\u001b[0m\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Remaining trials: 0\n",
"Starting hyperparameter search...\n",
"...Collecting garbage...\n",
"...Running study...\n",
"Number of finished trials: 200\n",
"Best trial:\n",
" Value: 0.7167503577102808\n",
" Params: \n",
" boostrap_type: Bernoulli\n",
" colsample_bylevel: 0.09915981893666519\n",
" depth: 12\n",
" subsample: 0.6726329766751142\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 379
},
"id": "uy9392xeIgaT",
"outputId": "888f31cd-af61-48fb-a091-be1275763935"
},
"source": [
"# Slice plot: one panel per tuned hyperparameter, showing each completed trial's\n",
"# objective value (validation MAE) against the value sampled for that parameter.\n",
"slice_axes = optuna.visualization.matplotlib.plot_slice(study)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:1: ExperimentalWarning:\n",
"\n",
"plot_slice is experimental (supported from v2.2.0). The interface can change in the future.\n",
"\n"
],
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x288 with 6 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "uMXSp8qFKzck"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment