@gautam-e
Last active February 13, 2020 20:57
proc_df for fastai_v1
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "# Use fastai_v1 for (pre-)processing of tabular data"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "In other words, a fastai v1 equivalent of `proc_df()`, so that the processed data can be used with scikit-learn models such as Random Forest."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from fastai.tabular import * \nfrom sklearn.model_selection import train_test_split\nfrom sklearn.ensemble import RandomForestClassifier",
"execution_count": 1,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Get the data"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "path = untar_data(URLs.ADULT_SAMPLE)\npath",
"execution_count": 2,
"outputs": [
{
"data": {
"text/plain": "PosixPath('/home/gautam/.fastai/data/adult_sample')"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "df = pd.read_csv(path/'adult.csv')\ndf.head()",
"execution_count": 3,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>age</th>\n <th>workclass</th>\n <th>fnlwgt</th>\n <th>education</th>\n <th>education-num</th>\n <th>marital-status</th>\n <th>occupation</th>\n <th>relationship</th>\n <th>race</th>\n <th>sex</th>\n <th>capital-gain</th>\n <th>capital-loss</th>\n <th>hours-per-week</th>\n <th>native-country</th>\n <th>salary</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>49</td>\n <td>Private</td>\n <td>101320</td>\n <td>Assoc-acdm</td>\n <td>12.0</td>\n <td>Married-civ-spouse</td>\n <td>NaN</td>\n <td>Wife</td>\n <td>White</td>\n <td>Female</td>\n <td>0</td>\n <td>1902</td>\n <td>40</td>\n <td>United-States</td>\n <td>&gt;=50k</td>\n </tr>\n <tr>\n <th>1</th>\n <td>44</td>\n <td>Private</td>\n <td>236746</td>\n <td>Masters</td>\n <td>14.0</td>\n <td>Divorced</td>\n <td>Exec-managerial</td>\n <td>Not-in-family</td>\n <td>White</td>\n <td>Male</td>\n <td>10520</td>\n <td>0</td>\n <td>45</td>\n <td>United-States</td>\n <td>&gt;=50k</td>\n </tr>\n <tr>\n <th>2</th>\n <td>38</td>\n <td>Private</td>\n <td>96185</td>\n <td>HS-grad</td>\n <td>NaN</td>\n <td>Divorced</td>\n <td>NaN</td>\n <td>Unmarried</td>\n <td>Black</td>\n <td>Female</td>\n <td>0</td>\n <td>0</td>\n <td>32</td>\n <td>United-States</td>\n <td>&lt;50k</td>\n </tr>\n <tr>\n <th>3</th>\n <td>38</td>\n <td>Self-emp-inc</td>\n <td>112847</td>\n <td>Prof-school</td>\n <td>15.0</td>\n <td>Married-civ-spouse</td>\n <td>Prof-specialty</td>\n <td>Husband</td>\n <td>Asian-Pac-Islander</td>\n <td>Male</td>\n <td>0</td>\n <td>0</td>\n <td>40</td>\n <td>United-States</td>\n <td>&gt;=50k</td>\n </tr>\n <tr>\n <th>4</th>\n <td>42</td>\n <td>Self-emp-not-inc</td>\n <td>82297</td>\n <td>7th-8th</td>\n 
<td>NaN</td>\n <td>Married-civ-spouse</td>\n <td>Other-service</td>\n <td>Wife</td>\n <td>Black</td>\n <td>Female</td>\n <td>0</td>\n <td>0</td>\n <td>50</td>\n <td>United-States</td>\n <td>&lt;50k</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " age workclass fnlwgt education education-num \\\n0 49 Private 101320 Assoc-acdm 12.0 \n1 44 Private 236746 Masters 14.0 \n2 38 Private 96185 HS-grad NaN \n3 38 Self-emp-inc 112847 Prof-school 15.0 \n4 42 Self-emp-not-inc 82297 7th-8th NaN \n\n marital-status occupation relationship race \\\n0 Married-civ-spouse NaN Wife White \n1 Divorced Exec-managerial Not-in-family White \n2 Divorced NaN Unmarried Black \n3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n4 Married-civ-spouse Other-service Wife Black \n\n sex capital-gain capital-loss hours-per-week native-country salary \n0 Female 0 1902 40 United-States >=50k \n1 Male 10520 0 45 United-States >=50k \n2 Female 0 0 32 United-States <50k \n3 Male 0 0 40 United-States >=50k \n4 Female 0 0 50 United-States <50k "
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "procs = [FillMissing, Categorify, Normalize]\nvalid_idx = range(len(df)-2000, len(df))\ndep_var = 'salary'\ncat_names = ['workclass', 'education', 'marital-status', 'occupation', \n 'relationship', 'race', 'sex', 'native-country']\n\ndata = TabularDataBunch.from_df(path, df, dep_var, valid_idx=valid_idx, procs=procs, cat_names=cat_names)\ndata.show_batch()",
"execution_count": 4,
"outputs": [
{
"data": {
"text/html": "<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th>workclass</th>\n <th>education</th>\n <th>marital-status</th>\n <th>occupation</th>\n <th>relationship</th>\n <th>race</th>\n <th>sex</th>\n <th>native-country</th>\n <th>education-num_na</th>\n <th>capital-loss</th>\n <th>education-num</th>\n <th>fnlwgt</th>\n <th>age</th>\n <th>capital-gain</th>\n <th>hours-per-week</th>\n <th>target</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <td>Private</td>\n <td>Some-college</td>\n <td>Never-married</td>\n <td>Handlers-cleaners</td>\n <td>Not-in-family</td>\n <td>White</td>\n <td>Male</td>\n <td>United-States</td>\n <td>False</td>\n <td>-0.2168</td>\n <td>-0.0297</td>\n <td>0.2848</td>\n <td>-1.3638</td>\n <td>-0.1459</td>\n <td>-0.8437</td>\n <td>&lt;50k</td>\n </tr>\n <tr>\n <td>Private</td>\n <td>11th</td>\n <td>Never-married</td>\n <td>Other-service</td>\n <td>Own-child</td>\n <td>White</td>\n <td>Female</td>\n <td>United-States</td>\n <td>False</td>\n <td>-0.2168</td>\n <td>-1.2052</td>\n <td>0.4809</td>\n <td>-1.5102</td>\n <td>-0.1459</td>\n <td>-1.6516</td>\n <td>&lt;50k</td>\n </tr>\n <tr>\n <td>Private</td>\n <td>HS-grad</td>\n <td>Married-civ-spouse</td>\n <td>Sales</td>\n <td>Husband</td>\n <td>White</td>\n <td>Male</td>\n <td>United-States</td>\n <td>False</td>\n <td>-0.2168</td>\n <td>-0.4216</td>\n <td>0.0778</td>\n <td>-0.9244</td>\n <td>-0.1459</td>\n <td>1.9840</td>\n <td>&lt;50k</td>\n </tr>\n <tr>\n <td>Private</td>\n <td>Doctorate</td>\n <td>Married-civ-spouse</td>\n <td>Prof-specialty</td>\n <td>Husband</td>\n <td>White</td>\n <td>Male</td>\n <td>United-States</td>\n <td>False</td>\n <td>-0.2168</td>\n <td>2.3212</td>\n <td>0.5144</td>\n <td>1.7118</td>\n <td>1.8720</td>\n <td>0.7721</td>\n <td>&gt;=50k</td>\n </tr>\n <tr>\n <td>Private</td>\n <td>Bachelors</td>\n <td>Divorced</td>\n <td>Prof-specialty</td>\n <td>Unmarried</td>\n <td>White</td>\n <td>Female</td>\n <td>Germany</td>\n <td>False</td>\n 
<td>-0.2168</td>\n <td>1.1457</td>\n <td>-0.5488</td>\n <td>-0.4118</td>\n <td>-0.0232</td>\n <td>-0.0358</td>\n <td>&lt;50k</td>\n </tr>\n </tbody>\n</table>",
"text/plain": "<IPython.core.display.HTML object>"
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "print(data.train_ds.cont_names) # `cont_names` defaults to: set(df)-set(cat_names)-{dep_var}",
"execution_count": 5,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "['capital-loss', 'education-num', 'fnlwgt', 'age', 'capital-gain', 'hours-per-week']\n"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "data.train_ds.x.codes.shape, data.train_ds.x.conts.shape, data.train_ds.y.items.shape",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": "((30561, 9), (30561, 6), (30561,))"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
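"source": "# Stack the categorical codes (8 categoricals + `education-num_na`) and the 6 normalized continuous columns\nX = np.concatenate((data.train_ds.x.codes, data.train_ds.x.conts), axis=1)\ny = data.train_ds.y.items\nX.shape, y.shape",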
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "((30561, 15), (30561,))"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Split into training and validation sets"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
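"source": "# `data.train_ds` already excludes the 2000 `valid_idx` rows, so this re-splits only fastai's training portion\nX_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.25, random_state=42)\nX_train.shape, X_valid.shape, X_train.shape[0] + X_valid.shape[0]",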
"execution_count": 8,
"outputs": [
{
"data": {
"text/plain": "((22920, 15), (7641, 15), 30561)"
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Train RandomForest model"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "m = RandomForestClassifier(n_jobs=-1, n_estimators=10)\nm.fit(X_train, y_train)\nm.score(X_train,y_train), m.score(X_valid,y_valid)",
"execution_count": 9,
"outputs": [
{
"data": {
"text/plain": "(0.9872600349040139, 0.8564324041355843)"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
]
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/b1132c92c690287ee9ea613c3add56bd"
},
"gist": {
"id": "b1132c92c690287ee9ea613c3add56bd",
"data": {
"description": "proc_df for fastai_v1",
"public": true
}
},
"kernelspec": {
"name": "conda-env-fastai-3.6-py",
"display_name": "Python [conda env:fastai-3.6] *",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.6.9",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
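For readers without fastai v1 installed, the three procs used in the notebook can be approximated in plain pandas and NumPy. The sketch below is not fastai code: `proc_like_fastai` and the toy frame are hypothetical, and the behaviour it mimics (median fill plus a `<col>_na` flag treated as a categorical, integer codes with 0 reserved for missing, standardization with training statistics) is an assumption inferred from the notebook's outputs, where `education-num_na` shows up as a ninth categorical column.

```python
import numpy as np
import pandas as pd


def proc_like_fastai(train, cat_names, cont_names):
    """Toy stand-in for FillMissing -> Categorify -> Normalize (hypothetical helper)."""
    train = train.copy()
    cat_names = list(cat_names)

    # FillMissing (assumed behaviour): median-fill continuous columns that
    # contain NaNs and record where they were in a boolean `<col>_na` column.
    for c in cont_names:
        if train[c].isna().any():
            train[c + "_na"] = train[c].isna()
            train[c] = train[c].fillna(train[c].median())
            cat_names.append(c + "_na")

    # Categorify (assumed behaviour): integer codes shifted by 1,
    # so that 0 stands for a missing category.
    codes = np.stack(
        [train[c].astype("category").cat.codes.to_numpy() + 1 for c in cat_names],
        axis=1,
    )

    # Normalize: zero mean and unit variance, computed on this (training) frame.
    conts = train[cont_names].to_numpy(dtype=float)
    conts = (conts - conts.mean(axis=0)) / (conts.std(axis=0) + 1e-7)

    # Same layout as the notebook's X: codes first, then continuous columns.
    return np.concatenate([codes, conts], axis=1), cat_names


toy = pd.DataFrame({
    "workclass": ["Private", "Self-emp-inc", None, "Private"],
    "education-num": [12.0, np.nan, 14.0, 9.0],
    "age": [49, 44, 38, 42],
})
X_toy, used_cats = proc_like_fastai(toy, ["workclass"], ["education-num", "age"])
print(X_toy.shape, used_cats)
```

The resulting array has the same layout as `X` in the notebook and could be passed to `RandomForestClassifier.fit` in the same way. For a real validation set, the medians, category mappings, and normalization statistics would need to be reused from the training frame rather than recomputed.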