Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save radekosmulski/57cac4d5eab2934ec0a93d2556d46beb to your computer and use it in GitHub Desktop.
Save radekosmulski/57cac4d5eab2934ec0a93d2556d46beb to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c35b0b87-33c7-4935-a38e-f94dd22f15f7",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.ensemble import HistGradientBoostingRegressor\n",
"from sklearn.model_selection import train_test_split, ParameterGrid"
]
},
{
"cell_type": "markdown",
"id": "a147955f",
"metadata": {},
"source": [
"A `HistGradientBoostingRegressor` trained on only a subset of columns (all numerical) with hyperparameter tuning.\n",
"\n",
"We begin by loading in the data and follow with deterministically sampling a train-val-test split (70-10-20).\n",
"\n",
"Subsequently, we find a promising set of parameters using the validation set and proceed to train and evaluate our model."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8daab967",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>UNITID</th>\n",
" <th>school_name</th>\n",
" <th>city</th>\n",
" <th>state</th>\n",
" <th>zip</th>\n",
" <th>school_webpage</th>\n",
" <th>latitude</th>\n",
" <th>longitude</th>\n",
" <th>admission_rate</th>\n",
" <th>sat_verbal_midrange</th>\n",
" <th>...</th>\n",
" <th>carnegie_undergraduate</th>\n",
" <th>carnegie_size</th>\n",
" <th>religious_affiliation</th>\n",
" <th>percent_female</th>\n",
" <th>agege24</th>\n",
" <th>faminc</th>\n",
" <th>mean_earnings_6_years</th>\n",
" <th>median_earnings_6_years</th>\n",
" <th>mean_earnings_10_years</th>\n",
" <th>median_earnings_10_years</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>100654</td>\n",
" <td>'Alabama A &amp; M University'</td>\n",
" <td>Normal</td>\n",
" <td>AL</td>\n",
" <td>35762</td>\n",
" <td>www.aamu.edu/</td>\n",
" <td>34.7834</td>\n",
" <td>-86.5685</td>\n",
" <td>0.8989</td>\n",
" <td>410.0</td>\n",
" <td>...</td>\n",
" <td>'Full-time four-year inclusive'</td>\n",
" <td>'Medium 4-year highly residential (3000 to 9999)'</td>\n",
" <td>?</td>\n",
" <td>0.52999997138977</td>\n",
" <td>0.07999999821186</td>\n",
" <td>40211.22</td>\n",
" <td>26100.0</td>\n",
" <td>22800.0</td>\n",
" <td>35300.0</td>\n",
" <td>31400.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>100663</td>\n",
" <td>'University of Alabama at Birmingham'</td>\n",
" <td>Birmingham</td>\n",
" <td>AL</td>\n",
" <td>35294-0110</td>\n",
" <td>www.uab.edu</td>\n",
" <td>33.5022</td>\n",
" <td>-86.8092</td>\n",
" <td>0.8673</td>\n",
" <td>580.0</td>\n",
" <td>...</td>\n",
" <td>'Medium full-time four-year selective higher t...</td>\n",
" <td>'Large 4-year primarily nonresidential (over 9...</td>\n",
" <td>?</td>\n",
" <td>0.64999997615814</td>\n",
" <td>0.25999999046325</td>\n",
" <td>49894.65</td>\n",
" <td>37400.0</td>\n",
" <td>33200.0</td>\n",
" <td>46300.0</td>\n",
" <td>40300.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>100690</td>\n",
" <td>'Amridge University'</td>\n",
" <td>Montgomery</td>\n",
" <td>AL</td>\n",
" <td>36117-3553</td>\n",
" <td>www.amridgeuniversity.edu</td>\n",
" <td>32.3626</td>\n",
" <td>-86.17399999999999</td>\n",
" <td>?</td>\n",
" <td>?</td>\n",
" <td>...</td>\n",
" <td>'Medium full-time four-year inclusivestudents ...</td>\n",
" <td>'Very small 4-year primarily nonresidential (l...</td>\n",
" <td>'Churches of Christ'</td>\n",
" <td>0.50999999046325</td>\n",
" <td>0.82999998331069</td>\n",
" <td>38712.18</td>\n",
" <td>38500.0</td>\n",
" <td>32800.0</td>\n",
" <td>42100.0</td>\n",
" <td>38100.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>100706</td>\n",
" <td>'University of Alabama in Huntsville'</td>\n",
" <td>Huntsville</td>\n",
" <td>AL</td>\n",
" <td>35899</td>\n",
" <td>www.uah.edu</td>\n",
" <td>34.7228</td>\n",
" <td>-86.6384</td>\n",
" <td>0.8062</td>\n",
" <td>575.0</td>\n",
" <td>...</td>\n",
" <td>'Medium full-time four-year selective higher t...</td>\n",
" <td>'Medium 4-year primarily nonresidential (3000 ...</td>\n",
" <td>?</td>\n",
" <td>0.55000001192092</td>\n",
" <td>0.28999999165534</td>\n",
" <td>54155.4</td>\n",
" <td>39300.0</td>\n",
" <td>36700.0</td>\n",
" <td>52700.0</td>\n",
" <td>46600.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>100724</td>\n",
" <td>'Alabama State University'</td>\n",
" <td>Montgomery</td>\n",
" <td>AL</td>\n",
" <td>36104-0271</td>\n",
" <td>www.alasu.edu/email/index.aspx</td>\n",
" <td>32.3643</td>\n",
" <td>-86.2957</td>\n",
" <td>0.5125</td>\n",
" <td>430.0</td>\n",
" <td>...</td>\n",
" <td>'Full-time four-year inclusive'</td>\n",
" <td>'Medium 4-year primarily residential (3000 to ...</td>\n",
" <td>?</td>\n",
" <td>0.56999999284744</td>\n",
" <td>0.10999999940395</td>\n",
" <td>31846.99</td>\n",
" <td>21200.0</td>\n",
" <td>19300.0</td>\n",
" <td>30700.0</td>\n",
" <td>27800.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 48 columns</p>\n",
"</div>"
],
"text/plain": [
" UNITID school_name city state \\\n",
"0 100654 'Alabama A & M University' Normal AL \n",
"1 100663 'University of Alabama at Birmingham' Birmingham AL \n",
"2 100690 'Amridge University' Montgomery AL \n",
"3 100706 'University of Alabama in Huntsville' Huntsville AL \n",
"4 100724 'Alabama State University' Montgomery AL \n",
"\n",
" zip school_webpage latitude longitude \\\n",
"0 35762 www.aamu.edu/ 34.7834 -86.5685 \n",
"1 35294-0110 www.uab.edu 33.5022 -86.8092 \n",
"2 36117-3553 www.amridgeuniversity.edu 32.3626 -86.17399999999999 \n",
"3 35899 www.uah.edu 34.7228 -86.6384 \n",
"4 36104-0271 www.alasu.edu/email/index.aspx 32.3643 -86.2957 \n",
"\n",
" admission_rate sat_verbal_midrange ... \\\n",
"0 0.8989 410.0 ... \n",
"1 0.8673 580.0 ... \n",
"2 ? ? ... \n",
"3 0.8062 575.0 ... \n",
"4 0.5125 430.0 ... \n",
"\n",
" carnegie_undergraduate \\\n",
"0 'Full-time four-year inclusive' \n",
"1 'Medium full-time four-year selective higher t... \n",
"2 'Medium full-time four-year inclusivestudents ... \n",
"3 'Medium full-time four-year selective higher t... \n",
"4 'Full-time four-year inclusive' \n",
"\n",
" carnegie_size religious_affiliation \\\n",
"0 'Medium 4-year highly residential (3000 to 9999)' ? \n",
"1 'Large 4-year primarily nonresidential (over 9... ? \n",
"2 'Very small 4-year primarily nonresidential (l... 'Churches of Christ' \n",
"3 'Medium 4-year primarily nonresidential (3000 ... ? \n",
"4 'Medium 4-year primarily residential (3000 to ... ? \n",
"\n",
" percent_female agege24 faminc mean_earnings_6_years \\\n",
"0 0.52999997138977 0.07999999821186 40211.22 26100.0 \n",
"1 0.64999997615814 0.25999999046325 49894.65 37400.0 \n",
"2 0.50999999046325 0.82999998331069 38712.18 38500.0 \n",
"3 0.55000001192092 0.28999999165534 54155.4 39300.0 \n",
"4 0.56999999284744 0.10999999940395 31846.99 21200.0 \n",
"\n",
" median_earnings_6_years mean_earnings_10_years median_earnings_10_years \n",
"0 22800.0 35300.0 31400.0 \n",
"1 33200.0 46300.0 40300.0 \n",
"2 32800.0 42100.0 38100.0 \n",
"3 36700.0 52700.0 46600.0 \n",
"4 19300.0 30700.0 27800.0 \n",
"\n",
"[5 rows x 48 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"colleges = pd.read_csv('../input/colleges/colleges.csv')\n",
"colleges.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e57e1d1d",
"metadata": {},
"outputs": [],
"source": [
"numeric_columns = [\n",
" 'undergrad_size',\n",
" 'spend_per_student',\n",
" 'admission_rate',\n",
" 'mean_earnings_6_years',\n",
" 'median_earnings_6_years',\n",
" 'sat_math_midrange',\n",
" 'sat_verbal_midrange',\n",
" 'act_math_midrange',\n",
" 'act_writing_midrange',\n",
" 'latitude',\n",
" 'longitude',\n",
" 'completion_rate',\n",
" 'tuition_(out_of_state)',\n",
" 'tuition_(instate)',\n",
" 'percent_white',\n",
" 'percent_black',\n",
" 'percent_hispanic',\n",
" 'percent_asian',\n",
" 'median_earnings_10_years',\n",
" 'sat_total_average',\n",
" 'completion_rate',\n",
" 'faculty_salary',\n",
" 'percent_female',\n",
" 'percent_part_time',\n",
" 'agege24',\n",
" 'faminc',\n",
" 'average_cost_program_year',\n",
" 'average_cost_academic_year']"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2d65494d",
"metadata": {},
"outputs": [],
"source": [
"for col in numeric_columns:\n",
" colleges[col] = pd.to_numeric(colleges[col], errors='coerce')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8028014c",
"metadata": {},
"outputs": [],
"source": [
"targets = colleges['percent_pell_grant'].values"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e33b607a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((7063, 28), (7063,))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"colleges[numeric_columns].values.shape, targets.shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "739c7b47",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"seed: 0 test set error: 22.171433294089248 {'learning_rate': 0.075, 'max_depth': 5, 'max_iter': 700, 'max_leaf_nodes': 21, 'min_samples_leaf': 30}\n",
"seed: 1 test set error: 21.351643043646817 {'learning_rate': 0.075, 'max_depth': 4, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"seed: 2 test set error: 20.613522657178827 {'learning_rate': 0.075, 'max_depth': 5, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 30}\n",
"seed: 3 test set error: 19.90627888191498 {'learning_rate': 0.075, 'max_depth': 5, 'max_iter': 700, 'max_leaf_nodes': 21, 'min_samples_leaf': 27}\n",
"seed: 4 test set error: 21.06215235443008 {'learning_rate': 0.075, 'max_depth': 5, 'max_iter': 650, 'max_leaf_nodes': 24, 'min_samples_leaf': 27}\n",
"seed: 5 test set error: 21.571807351331653 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 700, 'max_leaf_nodes': 24, 'min_samples_leaf': 33}\n",
"seed: 6 test set error: 20.603285022715113 {'learning_rate': 0.075, 'max_depth': 5, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"seed: 7 test set error: 22.72010604593044 {'learning_rate': 0.075, 'max_depth': 5, 'max_iter': 650, 'max_leaf_nodes': 21, 'min_samples_leaf': 27}\n",
"seed: 8 test set error: 21.077589895483865 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 30}\n",
"seed: 9 test set error: 18.8056142569569 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"CPU times: user 15h 11min 36s, sys: 1min 43s, total: 15h 13min 19s\n",
"Wall time: 38min 22s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"for seed in range(10):\n",
" np.random.seed(seed)\n",
" \n",
" X_train, X_val_and_test, y_train, y_val_and_test = train_test_split(colleges[numeric_columns].values, targets, test_size=0.3)\n",
"\n",
" X_val, X_test, y_val, y_test = train_test_split(X_val_and_test, y_val_and_test, test_size=0.66)\n",
"\n",
" param_grid = ParameterGrid({\n",
" 'max_depth': [4, 5, 6],\n",
" 'max_iter': [650, 700, 750],\n",
" 'learning_rate': [0.7, 0.075, 0.8],\n",
" 'min_samples_leaf': [27, 30, 33], \n",
" 'max_leaf_nodes': [18, 21, 24]}\n",
" )\n",
"\n",
"\n",
" smallest_error = np.inf\n",
" best_params = None\n",
" for params in param_grid:\n",
" clf = HistGradientBoostingRegressor(**params, validation_fraction=None)\n",
" clf.fit(X_train, y_train)\n",
" preds = clf.predict(X_val)\n",
" error = mean_squared_error(y_val, preds, squared=True)\n",
" if error < smallest_error:\n",
" smallest_error = error\n",
" best_params = params\n",
"\n",
" clf = HistGradientBoostingRegressor(**best_params, validation_fraction=None)\n",
" clf.fit(X_train, y_train)\n",
" preds = clf.predict(X_test)\n",
" test_set_error = mean_squared_error(y_test, preds, squared=True)\n",
"\n",
" print(\"seed:\", seed, \"test set error:\", test_set_error*1000, best_params)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1c389201",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"seed: 0 test set error: 22.950182801521365 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"seed: 1 test set error: 20.63139959734668 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"seed: 2 test set error: 20.074620809704737 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"seed: 3 test set error: 20.08524739678061 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"seed: 4 test set error: 20.71665020455993 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"seed: 5 test set error: 22.02020075905716 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"seed: 6 test set error: 20.99264937391348 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"seed: 7 test set error: 23.186960525595023 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"seed: 8 test set error: 20.889422557128814 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"seed: 9 test set error: 19.339077384591423 {'learning_rate': 0.075, 'max_depth': 6, 'max_iter': 650, 'max_leaf_nodes': 18, 'min_samples_leaf': 33}\n",
"CPU times: user 2min 39s, sys: 244 ms, total: 2min 39s\n",
"Wall time: 6.74 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"for seed in range(10):\n",
" np.random.seed(seed)\n",
" \n",
" X_train, X_val_and_test, y_train, y_val_and_test = train_test_split(colleges[numeric_columns].values, targets, test_size=0.3)\n",
"\n",
" X_val, X_test, y_val, y_test = train_test_split(X_val_and_test, y_val_and_test, test_size=0.66)\n",
"\n",
" clf = HistGradientBoostingRegressor(max_depth=5, max_iter=400, learning_rate=0.05, min_samples_leaf=15, max_leaf_nodes=31, validation_fraction=None)\n",
" clf.fit(X_train, y_train)\n",
" preds = clf.predict(X_test)\n",
" test_set_error = mean_squared_error(y_test, preds, squared=True)\n",
"\n",
" print(\"seed:\", seed, \"test set error:\", test_set_error*1000, best_params)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment