  • Save cordon-thiago/38aed8ebd06bfa484c177405eb9c974c to your computer and use it in GitHub Desktop.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# XGBoost with an imbalanced dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split dataset into Train / Test sets"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(338377, 9)\n",
"(145019, 9)\n",
"(338377,)\n",
"(145019,)\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hold out 30% of the rows for testing; X and y are assumed to be defined in earlier cells\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=123)\n",
"\n",
"print(X_train.shape)\n",
"print(X_test.shape)\n",
"\n",
"print(y_train.shape)\n",
"print(y_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Extreme Gradient Boosting (XGBoost)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import xgboost as xgb\n",
"from sklearn import metrics"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/thiago/.local/lib/python3.7/site-packages/xgboost/core.py:587: FutureWarning: Series.base is deprecated and will be removed in a future version\n",
" if getattr(data, 'base', None) is not None and \\\n"
]
}
],
"source": [
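"# Native XGBoost DMatrix objects (the XGBClassifier.fit call below takes X_train/y_train directly and does not use them)\n",
"D_train = xgb.DMatrix(X_train, label=y_train)\n",
"D_test = xgb.DMatrix(X_test, label=y_test)"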
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
" colsample_bynode=1, colsample_bytree=1, gamma=0,\n",
" learning_rate=0.1, max_delta_step=0, max_depth=5,\n",
" min_child_weight=1, missing=None, n_estimators=100, n_jobs=1,\n",
" nthread=None, objective='reg:logistic', random_state=123,\n",
" reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,\n",
" silent=None, subsample=1, verbosity=1)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train an XGBoost classifier on the training set (no class rebalancing: scale_pos_weight stays at its default of 1)\n",
"\n",
"xgb_classif = xgb.XGBClassifier(\n",
"    max_depth=5,\n",
"    objective='reg:logistic',\n",
"    random_state=123)\n",
"\n",
"xgb_classif.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
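"import pickle\n",
"\n",
"# A context manager closes the file after writing; the 'models/' directory must already exist\n",
"with open('models/xgb-classifier-withImbalance.sav', 'wb') as f:\n",
"    pickle.dump(xgb_classif, f)"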
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
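The shapes printed by the split cell follow from how scikit-learn rounds a float `test_size`. The test set gets `ceil(test_size * n_samples)` rows, and the training set gets the rest. The stdlib-only sketch below reproduces that arithmetic for the 338377 + 145019 = 483396 rows in the printed output.

`simple_train_test_split` is a hypothetical helper. It shuffles indices with a seeded `random.Random`, so it yields the same sizes as `train_test_split(..., random_state=123)` but not the same rows.

```python
import math
import random

def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) using the rounding scikit-learn applies to a float test_size."""
    n_test = math.ceil(test_size * n_samples)  # test set is rounded up
    n_train = n_samples - n_test               # train set takes the remainder
    return n_train, n_test

def simple_train_test_split(X, y, test_size=0.30, seed=123):
    """Hypothetical stdlib-only split: same sizes as sklearn, but a different permutation."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same length")
    _, n_test = split_sizes(len(X), test_size)
    indices = list(range(len(X)))
    random.Random(seed).shuffle(indices)
    # The first n_test shuffled indices form the test set; the rest form the train set
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    X_train = [X[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_train = [y[i] for i in train_idx]
    y_test = [y[i] for i in test_idx]
    return X_train, X_test, y_train, y_test

# 338377 + 145019 rows, matching the shapes printed by the notebook
print(split_sizes(483396, 0.30))  # (338377, 145019)
```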
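The split cell does not pass `stratify`, so the share of positive labels can differ slightly between `y_train` and `y_test`. `train_test_split` accepts `stratify=y` to keep class proportions equal across both sets.

The sketch below is a simplified, stdlib-only illustration of that idea: it shuffles each class separately and rounds with `ceil`. It is not scikit-learn's exact allocation rule. `stratified_split_indices` and `positive_rate` are hypothetical helpers, and the toy labels are made up.

```python
import math
import random
from collections import defaultdict

def stratified_split_indices(y, test_size=0.30, seed=123):
    """Hypothetical stratified split: each class sends ceil(test_size * n_class) rows to test."""
    by_class = defaultdict(list)
    for i, label in enumerate(y):
        by_class[label].append(i)
    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for label in sorted(by_class):  # deterministic class order
        idx = by_class[label]
        rng.shuffle(idx)
        n_test = math.ceil(test_size * len(idx))
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return train_idx, test_idx

def positive_rate(labels):
    """Share of rows whose label is 1."""
    return sum(1 for v in labels if v == 1) / len(labels)

# Toy imbalanced labels: 90 negatives, 10 positives
y = [0] * 90 + [1] * 10
train_idx, test_idx = stratified_split_indices(y)
print(positive_rate([y[i] for i in train_idx]), positive_rate([y[i] for i in test_idx]))
```

In the notebook itself, the equivalent change would be `train_test_split(X, y, test_size=0.30, random_state=123, stratify=y)`.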
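The fitted model's repr shows `scale_pos_weight=1`, so positive and negative rows are weighted equally during training. XGBoost's parameter documentation suggests `sum(negative instances) / sum(positive instances)` as a typical value to consider for imbalanced binary problems.

A minimal stdlib sketch of that ratio follows. It assumes 0/1 labels with 1 as the positive class, and `scale_pos_weight_from_labels` is a hypothetical helper.

```python
def scale_pos_weight_from_labels(labels):
    """Ratio of negative (0) to positive (1) labels; assumes binary 0/1 labels."""
    n_pos = sum(1 for v in labels if v == 1)
    n_neg = sum(1 for v in labels if v == 0)
    if n_pos == 0:
        raise ValueError("no positive labels; the ratio is undefined")
    return n_neg / n_pos

print(scale_pos_weight_from_labels([0, 0, 0, 1]))  # 3.0
```

The result could be passed as `xgb.XGBClassifier(..., scale_pos_weight=ratio)`. Computing it from `y_train` only keeps the test labels out of training.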
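The training cell uses `objective='reg:logistic'`. For binary classification with `XGBClassifier`, the conventional choice is `binary:logistic`. Both objectives apply the logistic (sigmoid) function to the model's raw margin, and the scikit-learn wrapper's `predict` converts the resulting probability into a 0/1 label with a 0.5 cut-off.

The sketch below illustrates that final step in plain Python. It assumes hypothetical per-tree leaf values and the `base_score=0.5` shown in the repr, whose starting margin, logit(0.5), is 0. `predict_proba_and_label` is a hypothetical helper, not an XGBoost API.

```python
import math

def sigmoid(margin):
    """Numerically stable logistic transform, as applied by XGBoost's logistic objectives."""
    if margin >= 0:
        return 1.0 / (1.0 + math.exp(-margin))
    z = math.exp(margin)
    return z / (1.0 + z)

def logit(p):
    """Inverse of sigmoid; converts base_score into a starting margin."""
    return math.log(p / (1.0 - p))

def predict_proba_and_label(leaf_values, base_score=0.5, threshold=0.5):
    """Hypothetical helper: add per-tree leaf outputs to the base margin, then threshold."""
    margin = logit(base_score) + sum(leaf_values)
    prob = sigmoid(margin)
    return prob, int(prob > threshold)

# Three hypothetical trees leaning slightly towards the positive class
prob, label = predict_proba_and_label([0.2, -0.1, 0.05])
print(round(prob, 4), label)
```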
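The save cell pickles the fitted classifier to `models/xgb-classifier-withImbalance.sav`. A save-and-reload round trip can be sketched with the standard library alone. The sketch creates the parent directory if it is missing and closes files via context managers. It uses a plain dict as a stand-in for the fitted model, an assumption made so the example runs without XGBoost installed. `save_pickle` and `load_pickle` are hypothetical helpers.

Only unpickle files from a trusted source, since `pickle.load` can execute arbitrary code. XGBoost's documentation also recommends its own `save_model`/`load_model` when a model may later be loaded by a different XGBoost version.

```python
import os
import pickle
import tempfile

def save_pickle(obj, path):
    """Pickle obj to path, creating the parent directory first if needed."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "wb") as f:  # file is closed even if pickling fails
        pickle.dump(obj, f)

def load_pickle(path):
    """Load a pickled object; only use this on files you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)

# Hypothetical stand-in for the fitted classifier, so the example runs without XGBoost
model_stub = {"max_depth": 5, "objective": "reg:logistic", "random_state": 123}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "models", "xgb-classifier-withImbalance.sav")
    save_pickle(model_stub, path)
    restored = load_pickle(path)

print(restored == model_stub)  # True
```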