Skip to content

Instantly share code, notes, and snippets.

@jakirkham
Created June 16, 2022 04:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jakirkham/a597ca5929679f6ccecdb1f03bab62f1 to your computer and use it in GitHub Desktop.
Save jakirkham/a597ca5929679f6ccecdb1f03bab62f1 to your computer and use it in GitHub Desktop.
Notebook loading Higgs dataset for use with Dask + XGBoost
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "19f93a66-538c-4f3e-9e74-079a556716e9",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"DASK_DISTRIBUTED__SCHEDULER__WORK_STEALING\"] = \"False\" # needs to be added to dask-scheduler\n",
"\n",
"from functools import partial\n",
"from itertools import starmap\n",
"from operator import attrgetter, getitem\n",
"from math import ceil\n",
"\n",
"from tlz import sliding_window\n",
"\n",
"from dask_cuda import LocalCUDACluster\n",
"from dask.distributed import Client, wait\n",
"from dask import delayed\n",
"import dask_cudf\n",
"import cudf\n",
"import distributed\n",
"import xgboost as xgb\n",
"import time\n",
"from dask.utils import stringify\n",
"\n",
"from sklearn.metrics import mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1a1e2ff-be48-44a6-8767-265d417de1b5",
"metadata": {},
"outputs": [],
"source": [
"def reproducible_persist_per_worker(df, client):\n",
"    \"\"\"Persist ``df`` with consecutive runs of partitions pinned to workers.\n",
"\n",
"    The partitions of ``df`` are cut into consecutive groups of\n",
"    ``ceil(df.npartitions / n_workers)`` partitions, so the groups are\n",
"    roughly equal in size. Groups are paired with the entries of\n",
"    ``client.cluster.workers`` in iteration order, and every group is\n",
"    persisted with ``workers=<that worker's address>`` and\n",
"    ``allow_other_workers=False``.\n",
"\n",
"    Returns the ``dask_cudf.concat`` of the persisted groups, in the same\n",
"    order as the original partitions.\n",
"    \"\"\"\n",
"    n_workers = len(client.cluster.workers)\n",
"\n",
"    # Number of consecutive partitions handed to a single worker.\n",
"    chunk = ceil(df.npartitions / n_workers)\n",
"    starts = range(0, df.npartitions, chunk)\n",
"\n",
"    pinned = []\n",
"    for start, worker in zip(starts, client.cluster.workers.values()):\n",
"        group = df.partitions[start:start + chunk]\n",
"        pinned.append(\n",
"            group.persist(workers=worker.worker_address, allow_other_workers=False)\n",
"        )\n",
"\n",
"    # Rebuild a single dataframe from the per-worker pieces.\n",
"    return dask_cudf.concat(pinned)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed98e366-1450-45d6-9c2f-a3897a36cdae",
"metadata": {},
"outputs": [],
"source": [
"# Start a local CUDA cluster, connect a client to it, and wait for all\n",
"# requested workers before continuing.\n",
"n_workers = 8\n",
"cluster = LocalCUDACluster(n_workers=n_workers)\n",
"client = Client(address=cluster)\n",
"client.wait_for_workers(n_workers=n_workers)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9dcf799-e12f-4704-b928-70e824536b99",
"metadata": {},
"outputs": [],
"source": [
"fname = 'HIGGS.csv'\n",
"# The file is read with header=None, so column names are supplied here:\n",
"# the label column followed by 28 numbered feature columns.\n",
"colnames = ['label'] + ['feature-%02d' % i for i in range(1, 29)]\n",
"df = dask_cudf.read_csv(fname, header=None, names=colnames)\n",
"# Persist the data with consecutive partitions pinned to individual workers.\n",
"# (The helper defined in this notebook is `reproducible_persist_per_worker`;\n",
"# the previous call used an undefined `...2` name and raised NameError.)\n",
"df = reproducible_persist_per_worker(df, client)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c96374b-9c16-4543-a4eb-61600894f638",
"metadata": {},
"outputs": [],
"source": [
"# Separate the target column from the remaining (feature) columns.\n",
"df_labels = df['label']\n",
"df_features = df.drop(columns=['label'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9b00774-c794-492c-8b0d-cf2d60f0b97a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"dmatrix = xgb.dask.DaskDeviceQuantileDMatrix(\n",
"    client=client, data=df_features, label=df_labels\n",
")\n",
"\n",
"xgb_params = {'verbosity': 0, 'tree_method': 'gpu_hist', 'seed': 123}\n",
"\n",
"# The training matrix is also passed as the only evaluation set, named\n",
"# 'dtrain', so its metric history is read back under that key below.\n",
"model = xgb.dask.train(\n",
"    client,\n",
"    xgb_params,\n",
"    dtrain=dmatrix,\n",
"    num_boost_round=3000,\n",
"    evals=[(dmatrix, 'dtrain')],\n",
")\n",
"\n",
"train_rmse = model['history']['dtrain']['rmse']\n",
"print(\"Final train loss: \", train_rmse[-1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "72ec3c6b-3536-49b0-aab7-5128479135ba",
"metadata": {},
"outputs": [],
"source": [
"# Predict on the training features, then compute the result into a\n",
"# single dataframe.\n",
"predictions = xgb.dask.predict(client, model, df_features)\n",
"y_pred = predictions.to_frame().compute()\n",
"\n",
"# The unnamed prediction column comes back labelled 0; give it a name.\n",
"y_pred = y_pred.rename(columns={0: 'score'})\n",
"\n",
"y_pred.to_parquet(\"result.parquet\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2bc2d19e-15b1-4bcc-94ab-f45d08cbd776",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment