Created
May 10, 2022 10:36
-
-
Save chetanambi/980dff2636c54086985c3b13706a5b2c to your computer and use it in GitHub Desktop.
OneHotEncoder_vs_get_dummies
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "e905e1ce", | |
"metadata": {}, | |
"source": [ | |
"# OneHotEncoder Vs get_dummies" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "19b73275", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"import seaborn as sns\n", | |
"from sklearn.pipeline import Pipeline\n", | |
"from sklearn.compose import ColumnTransformer\n", | |
"from sklearn.preprocessing import OneHotEncoder\n", | |
"from sklearn.preprocessing import MinMaxScaler, StandardScaler\n", | |
"from sklearn.linear_model import LinearRegression\n", | |
"from sklearn.model_selection import train_test_split" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "5e919c45", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"df = sns.load_dataset('tips')\n", | |
"df = df[['total_bill', 'tip', 'day', 'size']]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "8d2caae5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>total_bill</th>\n", | |
" <th>tip</th>\n", | |
" <th>day</th>\n", | |
" <th>size</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>16.99</td>\n", | |
" <td>1.01</td>\n", | |
" <td>Sun</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>10.34</td>\n", | |
" <td>1.66</td>\n", | |
" <td>Sun</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>21.01</td>\n", | |
" <td>3.50</td>\n", | |
" <td>Sun</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>23.68</td>\n", | |
" <td>3.31</td>\n", | |
" <td>Sun</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>24.59</td>\n", | |
" <td>3.61</td>\n", | |
" <td>Sun</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" total_bill tip day size\n", | |
"0 16.99 1.01 Sun 2\n", | |
"1 10.34 1.66 Sun 3\n", | |
"2 21.01 3.50 Sun 3\n", | |
"3 23.68 3.31 Sun 2\n", | |
"4 24.59 3.61 Sun 4" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df.head(5)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e2244d85", | |
"metadata": {}, | |
"source": [ | |
"## Scikit-learn OneHotEncoder" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "0cecdd2c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X = df.drop('tip', axis=1)\n", | |
"y = df['tip']\n", | |
"\n", | |
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "c19f3095", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ohe = OneHotEncoder(handle_unknown='ignore', sparse=False, dtype='int')\n", | |
"ohe.fit(X_train[['day']])\n", | |
"\n", | |
"def get_ohe(df):\n", | |
" temp_df = pd.DataFrame(data=ohe.transform(df[['day']]), columns=ohe.get_feature_names_out())\n", | |
" df.drop(columns=['day'], axis=1, inplace=True)\n", | |
" df = pd.concat([df.reset_index(drop=True), temp_df], axis=1)\n", | |
" return df" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "ebea7235", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X_train = get_ohe(X_train)\n", | |
"X_test = get_ohe(X_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "a4c6a11f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>total_bill</th>\n", | |
" <th>size</th>\n", | |
" <th>day_Fri</th>\n", | |
" <th>day_Sat</th>\n", | |
" <th>day_Sun</th>\n", | |
" <th>day_Thur</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>26.88</td>\n", | |
" <td>4</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>32.68</td>\n", | |
" <td>2</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>17.89</td>\n", | |
" <td>2</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>20.49</td>\n", | |
" <td>2</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>48.17</td>\n", | |
" <td>6</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" total_bill size day_Fri day_Sat day_Sun day_Thur\n", | |
"0 26.88 4 0 0 1 0\n", | |
"1 32.68 2 0 0 0 1\n", | |
"2 17.89 2 0 0 1 0\n", | |
"3 20.49 2 0 1 0 0\n", | |
"4 48.17 6 0 0 1 0" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_train.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "9c32e5fd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>total_bill</th>\n", | |
" <th>size</th>\n", | |
" <th>day_Fri</th>\n", | |
" <th>day_Sat</th>\n", | |
" <th>day_Sun</th>\n", | |
" <th>day_Thur</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>25</td>\n", | |
" <td>2</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>45</td>\n", | |
" <td>4</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" total_bill size day_Fri day_Sat day_Sun day_Thur\n", | |
"0 25 2 0 0 1 0\n", | |
"1 45 4 0 0 0 0" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_test_new = pd.DataFrame( {'total_bill': [25, 45], 'day': ['Sun', 'Mon'], 'size': [2, 4]} )\n", | |
"get_ohe(X_test_new)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "148f3359", | |
"metadata": {}, | |
"source": [ | |
"### One hot encoding in pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "49542fbd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X = df.drop('tip', axis=1)\n", | |
"y = df['tip']\n", | |
"\n", | |
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "1e620741", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"numeric_preprocessor = Pipeline(steps=[\n", | |
" (\"scaler\", MinMaxScaler()) \n", | |
"])\n", | |
"\n", | |
"categorical_preprocessor = Pipeline(steps=[ \n", | |
" (\"onehot\", OneHotEncoder(handle_unknown=\"ignore\")) \n", | |
"])\n", | |
"\n", | |
"preprocessor = ColumnTransformer([\n", | |
" (\"categorical\", categorical_preprocessor, [\"day\"]),\n", | |
" (\"numerical\", numeric_preprocessor, [\"total_bill\", \"size\"])\n", | |
"])\n", | |
"\n", | |
"pipe = Pipeline(steps=[\n", | |
" (\"preprocessor\", preprocessor), \n", | |
" (\"classifier\", LinearRegression())\n", | |
"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "3f9e210c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.5706168878130049" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pipe.fit(X_train, y_train)\n", | |
"pipe.score(X_test, y_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "dea80c6d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>total_bill</th>\n", | |
" <th>day</th>\n", | |
" <th>size</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>25</td>\n", | |
" <td>Sun</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>45</td>\n", | |
" <td>Mon</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" total_bill day size\n", | |
"0 25 Sun 2\n", | |
"1 45 Mon 4" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_test_new = pd.DataFrame( {'total_bill': [25, 45], 'day': ['Sun', 'Mon'], 'size': [2, 4]} )\n", | |
"X_test_new" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "7770b3a2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([3.27325822, 5.436413 ])" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pipe.predict(X_test_new)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ddb6ac19", | |
"metadata": {}, | |
"source": [ | |
"## Pandas get_dummies " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "26e6c5eb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X = df.drop('tip', axis=1)\n", | |
"y = df['tip']\n", | |
"\n", | |
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "248ff42a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X_train = pd.get_dummies(X_train, columns=['day'])\n", | |
"X_test = pd.get_dummies(X_test, columns=['day'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "0fd105d0", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>total_bill</th>\n", | |
" <th>size</th>\n", | |
" <th>day_Thur</th>\n", | |
" <th>day_Fri</th>\n", | |
" <th>day_Sat</th>\n", | |
" <th>day_Sun</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>26.88</td>\n", | |
" <td>4</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>83</th>\n", | |
" <td>32.68</td>\n", | |
" <td>2</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>176</th>\n", | |
" <td>17.89</td>\n", | |
" <td>2</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>106</th>\n", | |
" <td>20.49</td>\n", | |
" <td>2</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>156</th>\n", | |
" <td>48.17</td>\n", | |
" <td>6</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" total_bill size day_Thur day_Fri day_Sat day_Sun\n", | |
"7 26.88 4 0 0 0 1\n", | |
"83 32.68 2 1 0 0 0\n", | |
"176 17.89 2 0 0 0 1\n", | |
"106 20.49 2 0 0 1 0\n", | |
"156 48.17 6 0 0 0 1" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_train.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "0c048ce6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"cols = X_test.columns.tolist()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "015e66a3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>total_bill</th>\n", | |
" <th>size</th>\n", | |
" <th>day_Mon</th>\n", | |
" <th>day_Sun</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>25</td>\n", | |
" <td>2</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>45</td>\n", | |
" <td>4</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" total_bill size day_Mon day_Sun\n", | |
"0 25 2 0 1\n", | |
"1 45 4 1 0" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_test_new = pd.DataFrame( {'total_bill': [25, 45], 'day': ['Sun', 'Mon'], 'size': [2, 4]} )\n", | |
"pd.get_dummies(X_test_new, columns=['day']).head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "75cd1330", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>total_bill</th>\n", | |
" <th>day</th>\n", | |
" <th>size</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>25</td>\n", | |
" <td>Sun</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>45</td>\n", | |
" <td>Mon</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" total_bill day size\n", | |
"0 25 Sun 2\n", | |
"1 45 Mon 4" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_test_new" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "6abea7c5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>total_bill</th>\n", | |
" <th>size</th>\n", | |
" <th>day_Thur</th>\n", | |
" <th>day_Fri</th>\n", | |
" <th>day_Sat</th>\n", | |
" <th>day_Sun</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>25</td>\n", | |
" <td>2</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>45</td>\n", | |
" <td>4</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" total_bill size day_Thur day_Fri day_Sat day_Sun\n", | |
"0 25 2 0.0 0.0 0.0 1\n", | |
"1 45 4 0.0 0.0 0.0 0" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_test_new = pd.get_dummies(X_test_new, columns=['day'])\n", | |
"X_test_new.reindex(columns=cols).fillna(0) " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8c3be6c6", | |
"metadata": {}, | |
"source": [ | |
"### get_dummies in Pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "72a6c0fc", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X = df.drop('tip', axis=1)\n", | |
"y = df['tip']\n", | |
"\n", | |
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "391df006", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"cols = [\"total_bill\", \"size\", \"day_Fri\", \"day_Sat\", \"day_Sun\", \"day_Thur\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "fafc567b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.base import BaseEstimator, TransformerMixin\n", | |
"\n", | |
"class PreprocessorTransformer(BaseEstimator, TransformerMixin):\n", | |
" def __init__(self, cols):\n", | |
" self.cols = cols\n", | |
" \n", | |
" def fit(self, X, y = None):\n", | |
" return self\n", | |
" \n", | |
" def transform(self, X, y = None):\n", | |
" X = pd.get_dummies(X, columns=['day'])\n", | |
" X = X.reindex(columns=self.cols).fillna(0) \n", | |
" return X[self.cols]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "ae0ec1a3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"preprocessor = Pipeline(steps=[ \n", | |
" (\"preprocessor\", PreprocessorTransformer(cols)) \n", | |
"])\n", | |
"\n", | |
"pipe = Pipeline(steps=[\n", | |
" (\"preprocessor\", preprocessor), \n", | |
" (\"classifier\", LinearRegression())\n", | |
"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "4e363485", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"pipe.fit(X_train, y_train);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"id": "612e202f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.5706168878130053" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pipe.score(X_test, y_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "05518bcc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>total_bill</th>\n", | |
" <th>day</th>\n", | |
" <th>size</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>25</td>\n", | |
" <td>Sun</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>45</td>\n", | |
" <td>Mon</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" total_bill day size\n", | |
"0 25 Sun 2\n", | |
"1 45 Mon 4" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_test_new = pd.DataFrame( {'total_bill': [25, 45], 'day': ['Sun', 'Mon'], 'size': [2, 4]} )\n", | |
"X_test_new" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "e475f927", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([3.27325822, 5.436413 ])" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pipe.predict(X_test_new)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a49f3eaf", | |
"metadata": {}, | |
"source": [ | |
"## Example" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "2930f647", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"data = {\"Airline\": [\"American Airlines\", \"Delta Air Lines\", \"United Airlines\"]}\n", | |
"df = pd.DataFrame(data=data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"id": "f0fd2cd9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Airline_American Airlines</th>\n", | |
" <th>Airline_Delta Air Lines</th>\n", | |
" <th>Airline_United Airlines</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" Airline_American Airlines Airline_Delta Air Lines Airline_United Airlines\n", | |
"0 1 0 0\n", | |
"1 0 1 0\n", | |
"2 0 0 1" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pd.get_dummies(df)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "1f34ab08", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead tr th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th></th>\n", | |
" <th>American Airlines</th>\n", | |
" <th>Delta Air Lines</th>\n", | |
" <th>United Airlines</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" American Airlines Delta Air Lines United Airlines\n", | |
"0 1 0 0\n", | |
"1 0 1 0\n", | |
"2 0 0 1" | |
] | |
}, | |
"execution_count": 31, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ohe = OneHotEncoder(sparse=False)\n", | |
"pd.DataFrame(ohe.fit_transform(df), columns=ohe.categories_, dtype='int8')" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment