Skip to content

Instantly share code, notes, and snippets.

@RahulDas-dev
Last active February 23, 2024 12:11
Show Gist options
  • Save RahulDas-dev/3fda38233a3fa9ace94663ed4cdd2be5 to your computer and use it in GitHub Desktop.
Save RahulDas-dev/3fda38233a3fa9ace94663ed4cdd2be5 to your computer and use it in GitHub Desktop.
Scikit-learn pipeline building + OptunaSearchCV + joblib Memory caching
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "056ec550",
"metadata": {},
"outputs": [],
"source": [
"import tempfile\n",
"import time\n",
"import traceback\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from joblib import Memory\n",
"from optuna import create_study, samplers\n",
"from optuna.distributions import CategoricalDistribution, FloatDistribution, IntDistribution\n",
"from optuna.integration import OptunaSearchCV\n",
"from sklearn.base import clone\n",
"from sklearn.compose import ColumnTransformer, make_column_selector\n",
"# from mlxtend.feature_selection import ColumnSelector\n",
"from sklearn.datasets import fetch_openml\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import FunctionTransformer, OneHotEncoder, OrdinalEncoder, StandardScaler\n"
]
},
{
"cell_type": "markdown",
"id": "c298be18",
"metadata": {},
"source": [
"## Data loading"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "cb238cf7",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(48842, 14) (48842,)\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>education-num</th>\n",
" <th>marital-status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>sex</th>\n",
" <th>capitalgain</th>\n",
" <th>capitalloss</th>\n",
" <th>hoursperweek</th>\n",
" <th>native-country</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2</td>\n",
" <td>State-gov</td>\n",
" <td>77516</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Never-married</td>\n",
" <td>Adm-clerical</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>United-States</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3</td>\n",
" <td>Self-emp-not-inc</td>\n",
" <td>83311</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Exec-managerial</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>United-States</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>Private</td>\n",
" <td>215646</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Divorced</td>\n",
" <td>Handlers-cleaners</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>United-States</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3</td>\n",
" <td>Private</td>\n",
" <td>234721</td>\n",
" <td>11th</td>\n",
" <td>7</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Handlers-cleaners</td>\n",
" <td>Husband</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>United-States</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>Private</td>\n",
" <td>338409</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Prof-specialty</td>\n",
" <td>Wife</td>\n",
" <td>Black</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>Cuba</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education education-num marital-status \\\n",
"0 2 State-gov 77516 Bachelors 13 Never-married \n",
"1 3 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse \n",
"2 2 Private 215646 HS-grad 9 Divorced \n",
"3 3 Private 234721 11th 7 Married-civ-spouse \n",
"4 1 Private 338409 Bachelors 13 Married-civ-spouse \n",
"\n",
" occupation relationship race sex capitalgain capitalloss \\\n",
"0 Adm-clerical Not-in-family White Male 1 0 \n",
"1 Exec-managerial Husband White Male 0 0 \n",
"2 Handlers-cleaners Not-in-family White Male 0 0 \n",
"3 Handlers-cleaners Husband Black Male 0 0 \n",
"4 Prof-specialty Wife Black Female 0 0 \n",
"\n",
" hoursperweek native-country \n",
"0 2 United-States \n",
"1 0 United-States \n",
"2 2 United-States \n",
"3 2 United-States \n",
"4 2 Cuba "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def fetch_adult_data(data_id: int = 179) -> pd.DataFrame:\n",
"    \"\"\"Download an OpenML dataset (default id 179, the Adult census table) as one DataFrame.\n",
"\n",
"    The returned frame still contains the target column; the caller splits it off.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    data_id : int, default 179\n",
"        OpenML dataset id forwarded to ``fetch_openml``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    pd.DataFrame\n",
"        Features and target together (``fetch_openml(...)['frame']``).\n",
"    \"\"\"\n",
"    openml_ds = fetch_openml(data_id=data_id, as_frame=True, parser='pandas')\n",
"    return openml_ds['frame']\n",
"\n",
"\n",
"dataset = fetch_adult_data()\n",
"\n",
"# Split the label off; afterwards `dataset` holds only the feature columns.\n",
"target = dataset.pop('class')\n",
"\n",
"print(dataset.shape, target.shape)\n",
"dataset.head()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cff016e6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 48842 entries, 0 to 48841\n",
"Data columns (total 14 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 age 48842 non-null category\n",
" 1 workclass 46043 non-null category\n",
" 2 fnlwgt 48842 non-null int64 \n",
" 3 education 48842 non-null category\n",
" 4 education-num 48842 non-null int64 \n",
" 5 marital-status 48842 non-null category\n",
" 6 occupation 46033 non-null category\n",
" 7 relationship 48842 non-null category\n",
" 8 race 48842 non-null category\n",
" 9 sex 48842 non-null category\n",
" 10 capitalgain 48842 non-null category\n",
" 11 capitalloss 48842 non-null category\n",
" 12 hoursperweek 48842 non-null category\n",
" 13 native-country 47985 non-null category\n",
"dtypes: category(12), int64(2)\n",
"memory usage: 1.3 MB\n"
]
}
],
"source": [
"# Dtypes and non-null counts per column: workclass, occupation and native-country\n",
"# have missing values, which the preprocessor built below imputes.\n",
"dataset.info()"
]
},
{
"cell_type": "markdown",
"id": "16cbfa18",
"metadata": {},
"source": [
"## Preprocessor Builder"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "233f6925",
"metadata": {},
"outputs": [],
"source": [
"def bool_to_number(x: np.ndarray) -> np.ndarray:\n",
"    \"\"\"Return ``x`` with boolean values turned into numbers (True -> 1, False -> 0).\"\"\"\n",
"    ones_and_zeros = np.multiply(x, 1)  # multiplying by the integer 1 yields numeric values\n",
"    return ones_and_zeros\n",
"\n",
"\n",
"# FunctionTransformer wrapper so the cast can be used as a Pipeline step;\n",
"# 'one-to-one' keeps the input column names as the output feature names.\n",
"BooleanTransformer = FunctionTransformer(\n",
"    func=bool_to_number,\n",
"    feature_names_out='one-to-one',\n",
")\n",
"\n",
"def build_preprocessor_pipeline(dataset: pd.DataFrame, n_jobs_: int = -1, verbose_: bool = False) -> ColumnTransformer:\n",
"    \"\"\"Build an unfitted ColumnTransformer that imputes/encodes ``dataset``'s columns by dtype.\n",
"\n",
"    * numeric columns  -> mean imputation\n",
"    * category columns -> most-frequent imputation, then ordinal encoding\n",
"                          (categories unseen at fit time are encoded as -1)\n",
"    * bool columns     -> cast to 0/1, then most-frequent imputation\n",
"    Columns of any other dtype are dropped (``remainder='drop'``).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    dataset : pd.DataFrame\n",
"        Only inspected to decide which column goes to which branch; nothing is fitted here.\n",
"    n_jobs_ : int, default -1\n",
"        Forwarded to ``ColumnTransformer(n_jobs=...)``.\n",
"    verbose_ : bool, default False\n",
"        Forwarded to ``ColumnTransformer(verbose=...)``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    ColumnTransformer\n",
"        Configured with ``set_output(transform='pandas')`` so transforms return DataFrames.\n",
"    \"\"\"\n",
"    numerical_columns = make_column_selector(dtype_include=[np.number])(dataset)\n",
"    categorical_columns = make_column_selector(dtype_include=['category'])(dataset)\n",
"    boolean_columns = make_column_selector(dtype_include=['bool'])(dataset)\n",
"\n",
"    transformers_ = []\n",
"\n",
"    if numerical_columns:\n",
"        transformers_.append((\"transformer_n\", SimpleImputer(strategy=\"mean\"), numerical_columns))\n",
"\n",
"    if categorical_columns:\n",
"        # int8 keeps the encoded frame small, but codes above 127 would silently wrap\n",
"        # around (and can collide with the -1 'unknown' code). Widen the dtype when any\n",
"        # categorical column declares more categories than int8 can index.\n",
"        max_n_categories = max(len(dataset[col].cat.categories) for col in categorical_columns)\n",
"        code_dtype = np.int8 if max_n_categories - 1 <= np.iinfo(np.int8).max else np.int32\n",
"        transformer_c = Pipeline(\n",
"            steps=[\n",
"                (\"imputer_c\", SimpleImputer(missing_values=np.nan, strategy='most_frequent')),\n",
"                (\"encoder_c\", OrdinalEncoder(handle_unknown=\"use_encoded_value\",\n",
"                                             dtype=code_dtype,\n",
"                                             encoded_missing_value=-1,\n",
"                                             unknown_value=-1)),\n",
"            ],\n",
"            verbose=False,\n",
"            memory=None,\n",
"        )\n",
"        transformers_.append((\"transformer_c\", transformer_c, categorical_columns))\n",
"\n",
"    if boolean_columns:\n",
"        transformer_b = Pipeline(\n",
"            steps=[\n",
"                (\"to_int\", BooleanTransformer),\n",
"                (\"imputer_c\", SimpleImputer(missing_values=np.nan, strategy='most_frequent')),\n",
"            ],\n",
"            verbose=False,\n",
"            memory=None,\n",
"        )\n",
"        transformers_.append((\"transformer_b\", transformer_b, boolean_columns))\n",
"\n",
"    preprocessor = ColumnTransformer(\n",
"        transformers=transformers_,\n",
"        n_jobs=n_jobs_,\n",
"        remainder='drop',\n",
"        verbose_feature_names_out=False,\n",
"        verbose=verbose_,\n",
"    ).set_output(transform='pandas')\n",
"\n",
"    return preprocessor"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "8ea499d5",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-2 {color: black;}#sk-container-id-2 pre{padding: 0;}#sk-container-id-2 div.sk-toggleable {background-color: white;}#sk-container-id-2 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-2 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-2 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-2 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-2 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-2 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-2 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-2 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label 
{background-color: #d4ebff;}#sk-container-id-2 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-2 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-2 div.sk-item {position: relative;z-index: 1;}#sk-container-id-2 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-2 div.sk-item::before, #sk-container-id-2 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-2 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-2 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-2 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-2 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-2 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-2 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-2 div.sk-label-container {text-align: center;}#sk-container-id-2 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-2 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>ColumnTransformer(n_jobs=-1,\n",
" transformers=[(&#x27;transformer_n&#x27;, SimpleImputer(),\n",
" [&#x27;fnlwgt&#x27;, &#x27;education-num&#x27;]),\n",
" (&#x27;transformer_c&#x27;,\n",
" Pipeline(steps=[(&#x27;imputer_c&#x27;,\n",
" SimpleImputer(strategy=&#x27;most_frequent&#x27;)),\n",
" (&#x27;encoder_c&#x27;,\n",
" OrdinalEncoder(dtype=&lt;class &#x27;numpy.int8&#x27;&gt;,\n",
" encoded_missing_value=-1,\n",
" handle_unknown=&#x27;use_encoded_value&#x27;,\n",
" unknown_value=-1))]),\n",
" [&#x27;age&#x27;, &#x27;workclass&#x27;, &#x27;education&#x27;,\n",
" &#x27;marital-status&#x27;, &#x27;occupation&#x27;,\n",
" &#x27;relationship&#x27;, &#x27;race&#x27;, &#x27;sex&#x27;, &#x27;capitalgain&#x27;,\n",
" &#x27;capitalloss&#x27;, &#x27;hoursperweek&#x27;,\n",
" &#x27;native-country&#x27;])],\n",
" verbose_feature_names_out=False)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-10\" type=\"checkbox\" ><label for=\"sk-estimator-id-10\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">ColumnTransformer</label><div class=\"sk-toggleable__content\"><pre>ColumnTransformer(n_jobs=-1,\n",
" transformers=[(&#x27;transformer_n&#x27;, SimpleImputer(),\n",
" [&#x27;fnlwgt&#x27;, &#x27;education-num&#x27;]),\n",
" (&#x27;transformer_c&#x27;,\n",
" Pipeline(steps=[(&#x27;imputer_c&#x27;,\n",
" SimpleImputer(strategy=&#x27;most_frequent&#x27;)),\n",
" (&#x27;encoder_c&#x27;,\n",
" OrdinalEncoder(dtype=&lt;class &#x27;numpy.int8&#x27;&gt;,\n",
" encoded_missing_value=-1,\n",
" handle_unknown=&#x27;use_encoded_value&#x27;,\n",
" unknown_value=-1))]),\n",
" [&#x27;age&#x27;, &#x27;workclass&#x27;, &#x27;education&#x27;,\n",
" &#x27;marital-status&#x27;, &#x27;occupation&#x27;,\n",
" &#x27;relationship&#x27;, &#x27;race&#x27;, &#x27;sex&#x27;, &#x27;capitalgain&#x27;,\n",
" &#x27;capitalloss&#x27;, &#x27;hoursperweek&#x27;,\n",
" &#x27;native-country&#x27;])],\n",
" verbose_feature_names_out=False)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-11\" type=\"checkbox\" ><label for=\"sk-estimator-id-11\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">transformer_n</label><div class=\"sk-toggleable__content\"><pre>[&#x27;fnlwgt&#x27;, &#x27;education-num&#x27;]</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-12\" type=\"checkbox\" ><label for=\"sk-estimator-id-12\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">SimpleImputer</label><div class=\"sk-toggleable__content\"><pre>SimpleImputer()</pre></div></div></div></div></div></div><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-13\" type=\"checkbox\" ><label for=\"sk-estimator-id-13\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">transformer_c</label><div class=\"sk-toggleable__content\"><pre>[&#x27;age&#x27;, &#x27;workclass&#x27;, &#x27;education&#x27;, &#x27;marital-status&#x27;, &#x27;occupation&#x27;, &#x27;relationship&#x27;, &#x27;race&#x27;, &#x27;sex&#x27;, &#x27;capitalgain&#x27;, &#x27;capitalloss&#x27;, &#x27;hoursperweek&#x27;, &#x27;native-country&#x27;]</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-14\" type=\"checkbox\" ><label for=\"sk-estimator-id-14\" class=\"sk-toggleable__label 
sk-toggleable__label-arrow\">SimpleImputer</label><div class=\"sk-toggleable__content\"><pre>SimpleImputer(strategy=&#x27;most_frequent&#x27;)</pre></div></div></div><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-15\" type=\"checkbox\" ><label for=\"sk-estimator-id-15\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">OrdinalEncoder</label><div class=\"sk-toggleable__content\"><pre>OrdinalEncoder(dtype=&lt;class &#x27;numpy.int8&#x27;&gt;, encoded_missing_value=-1,\n",
" handle_unknown=&#x27;use_encoded_value&#x27;, unknown_value=-1)</pre></div></div></div></div></div></div></div></div></div></div></div></div>"
],
"text/plain": [
"ColumnTransformer(n_jobs=-1,\n",
" transformers=[('transformer_n', SimpleImputer(),\n",
" ['fnlwgt', 'education-num']),\n",
" ('transformer_c',\n",
" Pipeline(steps=[('imputer_c',\n",
" SimpleImputer(strategy='most_frequent')),\n",
" ('encoder_c',\n",
" OrdinalEncoder(dtype=<class 'numpy.int8'>,\n",
" encoded_missing_value=-1,\n",
" handle_unknown='use_encoded_value',\n",
" unknown_value=-1))]),\n",
" ['age', 'workclass', 'education',\n",
" 'marital-status', 'occupation',\n",
" 'relationship', 'race', 'sex', 'capitalgain',\n",
" 'capitalloss', 'hoursperweek',\n",
" 'native-country'])],\n",
" verbose_feature_names_out=False)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Column routing is decided from `dataset`'s dtypes; the transformer is not fitted yet.\n",
"preprocessor = build_preprocessor_pipeline(dataset, n_jobs_=-1, verbose_=False)\n",
"preprocessor"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "e696eae3",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>fnlwgt</th>\n",
" <th>education-num</th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>education</th>\n",
" <th>marital-status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>sex</th>\n",
" <th>capitalgain</th>\n",
" <th>capitalloss</th>\n",
" <th>hoursperweek</th>\n",
" <th>native-country</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>77516.0</td>\n",
" <td>13.0</td>\n",
" <td>2</td>\n",
" <td>6</td>\n",
" <td>9</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>38</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>83311.0</td>\n",
" <td>13.0</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>38</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>215646.0</td>\n",
" <td>9.0</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>11</td>\n",
" <td>0</td>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>38</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>234721.0</td>\n",
" <td>7.0</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>38</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>338409.0</td>\n",
" <td>13.0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>9</td>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" fnlwgt education-num age workclass education marital-status \\\n",
"0 77516.0 13.0 2 6 9 4 \n",
"1 83311.0 13.0 3 5 9 2 \n",
"2 215646.0 9.0 2 3 11 0 \n",
"3 234721.0 7.0 3 3 1 2 \n",
"4 338409.0 13.0 1 3 9 2 \n",
"\n",
" occupation relationship race sex capitalgain capitalloss \\\n",
"0 0 1 4 1 1 0 \n",
"1 3 0 4 1 0 0 \n",
"2 5 1 4 1 0 0 \n",
"3 5 0 2 1 0 0 \n",
"4 9 5 2 0 0 0 \n",
"\n",
" hoursperweek native-country \n",
"0 2 38 \n",
"1 0 38 \n",
"2 2 38 \n",
"3 2 38 \n",
"4 2 4 "
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset_trf = preprocessor.fit_transform(dataset)\n",
"\n",
"# Invariants: one output row per input row, and imputation left no missing values.\n",
"assert len(dataset_trf) == len(dataset)\n",
"assert not dataset_trf.isna().to_numpy().any(), \"preprocessor output still contains NaN\"\n",
"\n",
"dataset_trf.head()"
]
},
{
"cell_type": "markdown",
"id": "759ffe3f",
"metadata": {},
"source": [
"## Adding the estimator to the pipeline"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "3f86fc65",
"metadata": {},
"outputs": [],
"source": [
"args = dict(random_state=10, n_jobs=-1)  # fixed seed; n_jobs=-1 uses all cores\n",
"\n",
"# Preprocessing from the previous section followed by the random forest.\n",
"model = Pipeline(\n",
"    [\n",
"        (\"transformer\", preprocessor),\n",
"        (\"estimator\", RandomForestClassifier(**args)),\n",
"    ],\n",
"    memory=None,\n",
"    verbose=False,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "a955d486",
"metadata": {},
"source": [
"## Hyperparameter search (OptunaSearchCV with a joblib Memory cache)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "6cadad96",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[I 2024-01-15 22:04:51,540] A new study created in RDB with name: Randomforest Tuner\n",
"C:\\Users\\rdas6\\AppData\\Local\\Temp\\ipykernel_15780\\2964663329.py:38: ExperimentalWarning: OptunaSearchCV is experimental (supported from v0.17.0). The interface can change in the future.\n",
" optuna_search = OptunaSearchCV(model_,\n",
"[I 2024-01-15 22:04:51,579] Searching the best hyperparameters using 48842 samples...\n",
"[I 2024-01-15 22:05:09,718] Trial 1 finished with value: 0.7607182362198229 and parameters: {'estimator__n_estimators': 100, 'estimator__max_depth': 10, 'estimator__min_impurity_decrease': 0.3810441382224819, 'estimator__max_features': 'log2', 'estimator__bootstrap': False}. Best is trial 1 with value: 0.7607182362198229.\n",
"[I 2024-01-15 22:05:11,399] Trial 3 finished with value: 0.8387248805305925 and parameters: {'estimator__n_estimators': 170, 'estimator__max_depth': 6, 'estimator__min_impurity_decrease': 0.0006576437138114963, 'estimator__max_features': 'sqrt', 'estimator__bootstrap': True}. Best is trial 3 with value: 0.8387248805305925.\n",
"[I 2024-01-15 22:05:12,046] Trial 0 finished with value: 0.8497195138074449 and parameters: {'estimator__n_estimators': 150, 'estimator__max_depth': 9, 'estimator__min_impurity_decrease': 3.6984829700086176e-08, 'estimator__max_features': 'sqrt', 'estimator__bootstrap': False}. Best is trial 0 with value: 0.8497195138074449.\n",
"[I 2024-01-15 22:05:12,490] Trial 4 finished with value: 0.7607182362198229 and parameters: {'estimator__n_estimators': 260, 'estimator__max_depth': 6, 'estimator__min_impurity_decrease': 0.2013327427553822, 'estimator__max_features': 'sqrt', 'estimator__bootstrap': False}. Best is trial 0 with value: 0.8497195138074449.\n",
"[I 2024-01-15 22:05:12,814] Trial 2 finished with value: 0.8476516205761779 and parameters: {'estimator__n_estimators': 220, 'estimator__max_depth': 8, 'estimator__min_impurity_decrease': 4.4744726010704304e-05, 'estimator__max_features': 'sqrt', 'estimator__bootstrap': False}. Best is trial 0 with value: 0.8497195138074449.\n",
"[I 2024-01-15 22:05:12,814] Finished hyperparameter search!\n",
"[I 2024-01-15 22:05:12,833] Refitting the estimator using 48842 samples...\n",
"[I 2024-01-15 22:05:14,353] Finished refitting! (elapsed time: 1.520 sec.)\n",
"[Memory(location=C:\\Users\\rdas6\\AppData\\Local\\Temp\\tmp_q1q_jbu\\joblib)]: Flushing completely the cache\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"End 2 End Time - 22.852312326431274 secs\n"
]
}
],
"source": [
"# Search space: only estimator__* parameters, so every trial shares the same preprocessing.\n",
"# (Each key may appear only once - a repeated dict key silently overrides the earlier one.)\n",
"param_distributions = {\n",
"    \"estimator__n_estimators\": IntDistribution(10, 300, step=10),\n",
"    \"estimator__max_depth\": IntDistribution(1, 11),\n",
"    \"estimator__min_impurity_decrease\": FloatDistribution(1e-9, 0.5, log=True),\n",
"    \"estimator__max_features\": CategoricalDistribution([1.0, \"sqrt\", \"log2\"]),\n",
"    \"estimator__bootstrap\": CategoricalDistribution([True, False]),\n",
"}\n",
"\n",
"storage_string_ = \"sqlite:///./test_2.db\"  # optional: persists trials to a local SQLite file\n",
"sampler_ = samplers.TPESampler(seed=10)\n",
"# The study is persisted in SQLite, so without load_if_exists=True re-running this cell\n",
"# raises DuplicatedStudyError. Trials then accumulate across runs; delete test_2.db\n",
"# to start a fresh search.\n",
"study_ = create_study(storage=storage_string_,\n",
"                      study_name='Randomforest Tuner',\n",
"                      direction=\"maximize\",\n",
"                      sampler=sampler_,\n",
"                      load_if_exists=True)\n",
"\n",
"cv_result, best_params, best_model, best_score = None, None, None, None\n",
"st_time = time.perf_counter()\n",
"# The context manager removes the joblib cache directory however this block exits.\n",
"with tempfile.TemporaryDirectory() as cache_dir:\n",
"    try:\n",
"        # Tune a clone so `model` stays untouched; the on-disk Memory caches fitted\n",
"        # transformer steps so trials that only change estimator__* params can reuse them.\n",
"        model_ = clone(model)\n",
"        model_.set_params(memory=Memory(cache_dir, verbose=0), verbose=False)\n",
"        optuna_search = OptunaSearchCV(model_,\n",
"                                       param_distributions,\n",
"                                       cv=5,\n",
"                                       n_trials=5,\n",
"                                       n_jobs=-1,\n",
"                                       random_state=10,\n",
"                                       refit=True,\n",
"                                       verbose=10,\n",
"                                       timeout=60 * 60,\n",
"                                       study=study_)\n",
"        optuna_search.fit(dataset, target)\n",
"    except Exception:\n",
"        # Best effort: keep the notebook running, but show the full traceback.\n",
"        traceback.print_exc()\n",
"    else:\n",
"        cv_result = pd.DataFrame(optuna_search.cv_results_)\n",
"        best_score = optuna_search.best_score_\n",
"        best_model = optuna_search.best_estimator_\n",
"        best_params = optuna_search.best_params_\n",
"        # Detach the refit model from the temporary cache directory, which is deleted on exit.\n",
"        best_model.set_params(memory=None)\n",
"print(f'End 2 End Time - {time.perf_counter() - st_time} secs')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "addf1456",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment