Last active
February 13, 2020 20:57
-
-
Save gautam-e/b1132c92c690287ee9ea613c3add56bd to your computer and use it in GitHub Desktop.
proc_df for fastai_v1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "# Use fastai_v1 for (pre-)processing of tabular data" | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "... or in other words, a fastai v1 equivalent of `proc_df()`, so that the processed data can be used with e.g. scikit-learn models such as a Random Forest." | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "from fastai.tabular import * \nfrom sklearn.model_selection import train_test_split\nfrom sklearn.ensemble import RandomForestClassifier", | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Get the data" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "path = untar_data(URLs.ADULT_SAMPLE)\npath", | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "PosixPath('/home/gautam/.fastai/data/adult_sample')" | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "df = pd.read_csv(path/'adult.csv')\ndf.head()", | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>age</th>\n <th>workclass</th>\n <th>fnlwgt</th>\n <th>education</th>\n <th>education-num</th>\n <th>marital-status</th>\n <th>occupation</th>\n <th>relationship</th>\n <th>race</th>\n <th>sex</th>\n <th>capital-gain</th>\n <th>capital-loss</th>\n <th>hours-per-week</th>\n <th>native-country</th>\n <th>salary</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>49</td>\n <td>Private</td>\n <td>101320</td>\n <td>Assoc-acdm</td>\n <td>12.0</td>\n <td>Married-civ-spouse</td>\n <td>NaN</td>\n <td>Wife</td>\n <td>White</td>\n <td>Female</td>\n <td>0</td>\n <td>1902</td>\n <td>40</td>\n <td>United-States</td>\n <td>>=50k</td>\n </tr>\n <tr>\n <th>1</th>\n <td>44</td>\n <td>Private</td>\n <td>236746</td>\n <td>Masters</td>\n <td>14.0</td>\n <td>Divorced</td>\n <td>Exec-managerial</td>\n <td>Not-in-family</td>\n <td>White</td>\n <td>Male</td>\n <td>10520</td>\n <td>0</td>\n <td>45</td>\n <td>United-States</td>\n <td>>=50k</td>\n </tr>\n <tr>\n <th>2</th>\n <td>38</td>\n <td>Private</td>\n <td>96185</td>\n <td>HS-grad</td>\n <td>NaN</td>\n <td>Divorced</td>\n <td>NaN</td>\n <td>Unmarried</td>\n <td>Black</td>\n <td>Female</td>\n <td>0</td>\n <td>0</td>\n <td>32</td>\n <td>United-States</td>\n <td><50k</td>\n </tr>\n <tr>\n <th>3</th>\n <td>38</td>\n <td>Self-emp-inc</td>\n <td>112847</td>\n <td>Prof-school</td>\n <td>15.0</td>\n <td>Married-civ-spouse</td>\n <td>Prof-specialty</td>\n <td>Husband</td>\n <td>Asian-Pac-Islander</td>\n <td>Male</td>\n <td>0</td>\n <td>0</td>\n <td>40</td>\n <td>United-States</td>\n <td>>=50k</td>\n </tr>\n <tr>\n <th>4</th>\n <td>42</td>\n <td>Self-emp-not-inc</td>\n <td>82297</td>\n <td>7th-8th</td>\n 
<td>NaN</td>\n <td>Married-civ-spouse</td>\n <td>Other-service</td>\n <td>Wife</td>\n <td>Black</td>\n <td>Female</td>\n <td>0</td>\n <td>0</td>\n <td>50</td>\n <td>United-States</td>\n <td><50k</td>\n </tr>\n </tbody>\n</table>\n</div>", | |
"text/plain": " age workclass fnlwgt education education-num \\\n0 49 Private 101320 Assoc-acdm 12.0 \n1 44 Private 236746 Masters 14.0 \n2 38 Private 96185 HS-grad NaN \n3 38 Self-emp-inc 112847 Prof-school 15.0 \n4 42 Self-emp-not-inc 82297 7th-8th NaN \n\n marital-status occupation relationship race \\\n0 Married-civ-spouse NaN Wife White \n1 Divorced Exec-managerial Not-in-family White \n2 Divorced NaN Unmarried Black \n3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n4 Married-civ-spouse Other-service Wife Black \n\n sex capital-gain capital-loss hours-per-week native-country salary \n0 Female 0 1902 40 United-States >=50k \n1 Male 10520 0 45 United-States >=50k \n2 Female 0 0 32 United-States <50k \n3 Male 0 0 40 United-States >=50k \n4 Female 0 0 50 United-States <50k " | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "procs = [FillMissing, Categorify, Normalize]\nvalid_idx = range(len(df)-2000, len(df))\ndep_var = 'salary'\ncat_names = ['workclass', 'education', 'marital-status', 'occupation', \n 'relationship', 'race', 'sex', 'native-country']\n\ndata = TabularDataBunch.from_df(path, df, dep_var, valid_idx=valid_idx, procs=procs, cat_names=cat_names)\ndata.show_batch()", | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": "<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th>workclass</th>\n <th>education</th>\n <th>marital-status</th>\n <th>occupation</th>\n <th>relationship</th>\n <th>race</th>\n <th>sex</th>\n <th>native-country</th>\n <th>education-num_na</th>\n <th>capital-loss</th>\n <th>education-num</th>\n <th>fnlwgt</th>\n <th>age</th>\n <th>capital-gain</th>\n <th>hours-per-week</th>\n <th>target</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <td>Private</td>\n <td>Some-college</td>\n <td>Never-married</td>\n <td>Handlers-cleaners</td>\n <td>Not-in-family</td>\n <td>White</td>\n <td>Male</td>\n <td>United-States</td>\n <td>False</td>\n <td>-0.2168</td>\n <td>-0.0297</td>\n <td>0.2848</td>\n <td>-1.3638</td>\n <td>-0.1459</td>\n <td>-0.8437</td>\n <td><50k</td>\n </tr>\n <tr>\n <td>Private</td>\n <td>11th</td>\n <td>Never-married</td>\n <td>Other-service</td>\n <td>Own-child</td>\n <td>White</td>\n <td>Female</td>\n <td>United-States</td>\n <td>False</td>\n <td>-0.2168</td>\n <td>-1.2052</td>\n <td>0.4809</td>\n <td>-1.5102</td>\n <td>-0.1459</td>\n <td>-1.6516</td>\n <td><50k</td>\n </tr>\n <tr>\n <td>Private</td>\n <td>HS-grad</td>\n <td>Married-civ-spouse</td>\n <td>Sales</td>\n <td>Husband</td>\n <td>White</td>\n <td>Male</td>\n <td>United-States</td>\n <td>False</td>\n <td>-0.2168</td>\n <td>-0.4216</td>\n <td>0.0778</td>\n <td>-0.9244</td>\n <td>-0.1459</td>\n <td>1.9840</td>\n <td><50k</td>\n </tr>\n <tr>\n <td>Private</td>\n <td>Doctorate</td>\n <td>Married-civ-spouse</td>\n <td>Prof-specialty</td>\n <td>Husband</td>\n <td>White</td>\n <td>Male</td>\n <td>United-States</td>\n <td>False</td>\n <td>-0.2168</td>\n <td>2.3212</td>\n <td>0.5144</td>\n <td>1.7118</td>\n <td>1.8720</td>\n <td>0.7721</td>\n <td>>=50k</td>\n </tr>\n <tr>\n <td>Private</td>\n <td>Bachelors</td>\n <td>Divorced</td>\n <td>Prof-specialty</td>\n <td>Unmarried</td>\n <td>White</td>\n <td>Female</td>\n <td>Germany</td>\n <td>False</td>\n 
<td>-0.2168</td>\n <td>1.1457</td>\n <td>-0.5488</td>\n <td>-0.4118</td>\n <td>-0.0232</td>\n <td>-0.0358</td>\n <td><50k</td>\n </tr>\n </tbody>\n</table>", | |
"text/plain": "<IPython.core.display.HTML object>" | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "print(data.train_ds.cont_names) # `cont_names` defaults to: set(df)-set(cat_names)-{dep_var}", | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "['capital-loss', 'education-num', 'fnlwgt', 'age', 'capital-gain', 'hours-per-week']\n" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "data.train_ds.x.codes.shape, data.train_ds.x.conts.shape, data.train_ds.y.items.shape", | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "((30561, 9), (30561, 6), (30561,))" | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "X = np.concatenate((data.train_ds.x.codes, data.train_ds.x.conts), axis=1)\ny = data.train_ds.y.items\nX.shape, y.shape", | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "((30561, 15), (30561,))" | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Split into training and validation sets" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=None, random_state=42)\nX_train.shape, X_valid.shape, X_train.shape[0] + X_valid.shape[0]", | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "((22920, 15), (7641, 15), 30561)" | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Train RandomForest model" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "m = RandomForestClassifier(n_jobs=-1, n_estimators=10)\nm.fit(X_train, y_train)\nm.score(X_train,y_train), m.score(X_valid,y_valid)", | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "(0.9872600349040139, 0.8564324041355843)" | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
} | |
], | |
"metadata": { | |
"_draft": { | |
"nbviewer_url": "https://gist.github.com/b1132c92c690287ee9ea613c3add56bd" | |
}, | |
"gist": { | |
"id": "b1132c92c690287ee9ea613c3add56bd", | |
"data": { | |
"description": "proc_df for fastai_v1", | |
"public": true | |
} | |
}, | |
"kernelspec": { | |
"name": "conda-env-fastai-3.6-py", | |
"display_name": "Python [conda env:fastai-3.6] *", | |
"language": "python" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.6.9", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment