Created
March 21, 2018 21:29
-
-
Save tonygentilcore/d41c8b631c96ed76c024983b73c8bece to your computer and use it in GitHub Desktop.
Titanic Kaggle w/ fast.ai
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%reload_ext autoreload\n", | |
"%autoreload 2\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import sys\n", | |
"sys.path.append(\"/home/tonygentilcore/fastai/courses/dl1/\")\n", | |
"from fastai.structured import *\n", | |
"from fastai.column_data import *\n", | |
"np.set_printoptions(threshold=50, edgeitems=20)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"PATH = '/home/tonygentilcore/.kaggle/competitions/titanic'" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Explore data**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"gender_submission.csv models test.csv tmp train.csv\r\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls {PATH}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train = pd.read_csv(f'{PATH}/train.csv')\n", | |
"test = pd.read_csv(f'{PATH}/test.csv')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>PassengerId</th>\n", | |
" <th>Survived</th>\n", | |
" <th>Pclass</th>\n", | |
" <th>Age</th>\n", | |
" <th>SibSp</th>\n", | |
" <th>Parch</th>\n", | |
" <th>Fare</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>count</th>\n", | |
" <td>891.000000</td>\n", | |
" <td>891.000000</td>\n", | |
" <td>891.000000</td>\n", | |
" <td>714.000000</td>\n", | |
" <td>891.000000</td>\n", | |
" <td>891.000000</td>\n", | |
" <td>891.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>446.000000</td>\n", | |
" <td>0.383838</td>\n", | |
" <td>2.308642</td>\n", | |
" <td>29.699118</td>\n", | |
" <td>0.523008</td>\n", | |
" <td>0.381594</td>\n", | |
" <td>32.204208</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>std</th>\n", | |
" <td>257.353842</td>\n", | |
" <td>0.486592</td>\n", | |
" <td>0.836071</td>\n", | |
" <td>14.526497</td>\n", | |
" <td>1.102743</td>\n", | |
" <td>0.806057</td>\n", | |
" <td>49.693429</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>1.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>0.420000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>25%</th>\n", | |
" <td>223.500000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>2.000000</td>\n", | |
" <td>20.125000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>7.910400</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>50%</th>\n", | |
" <td>446.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>3.000000</td>\n", | |
" <td>28.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>14.454200</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>75%</th>\n", | |
" <td>668.500000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>3.000000</td>\n", | |
" <td>38.000000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>31.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>891.000000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>3.000000</td>\n", | |
" <td>80.000000</td>\n", | |
" <td>8.000000</td>\n", | |
" <td>6.000000</td>\n", | |
" <td>512.329200</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" PassengerId Survived Pclass Age SibSp \\\n", | |
"count 891.000000 891.000000 891.000000 714.000000 891.000000 \n", | |
"mean 446.000000 0.383838 2.308642 29.699118 0.523008 \n", | |
"std 257.353842 0.486592 0.836071 14.526497 1.102743 \n", | |
"min 1.000000 0.000000 1.000000 0.420000 0.000000 \n", | |
"25% 223.500000 0.000000 2.000000 20.125000 0.000000 \n", | |
"50% 446.000000 0.000000 3.000000 28.000000 0.000000 \n", | |
"75% 668.500000 1.000000 3.000000 38.000000 1.000000 \n", | |
"max 891.000000 1.000000 3.000000 80.000000 8.000000 \n", | |
"\n", | |
" Parch Fare \n", | |
"count 891.000000 891.000000 \n", | |
"mean 0.381594 32.204208 \n", | |
"std 0.806057 49.693429 \n", | |
"min 0.000000 0.000000 \n", | |
"25% 0.000000 7.910400 \n", | |
"50% 0.000000 14.454200 \n", | |
"75% 0.000000 31.000000 \n", | |
"max 6.000000 512.329200 " | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train.describe()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>PassengerId</th>\n", | |
" <th>Survived</th>\n", | |
" <th>Pclass</th>\n", | |
" <th>Name</th>\n", | |
" <th>Sex</th>\n", | |
" <th>Age</th>\n", | |
" <th>SibSp</th>\n", | |
" <th>Parch</th>\n", | |
" <th>Ticket</th>\n", | |
" <th>Fare</th>\n", | |
" <th>Cabin</th>\n", | |
" <th>Embarked</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" <td>Braund, Mr. Owen Harris</td>\n", | |
" <td>male</td>\n", | |
" <td>22.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>A/5 21171</td>\n", | |
" <td>7.2500</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>2</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n", | |
" <td>female</td>\n", | |
" <td>38.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>PC 17599</td>\n", | |
" <td>71.2833</td>\n", | |
" <td>C85</td>\n", | |
" <td>C</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>3</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>Heikkinen, Miss. Laina</td>\n", | |
" <td>female</td>\n", | |
" <td>26.0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>STON/O2. 3101282</td>\n", | |
" <td>7.9250</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>4</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n", | |
" <td>female</td>\n", | |
" <td>35.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>113803</td>\n", | |
" <td>53.1000</td>\n", | |
" <td>C123</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>5</td>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" <td>Allen, Mr. William Henry</td>\n", | |
" <td>male</td>\n", | |
" <td>35.0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>373450</td>\n", | |
" <td>8.0500</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" PassengerId Survived Pclass \\\n", | |
"0 1 0 3 \n", | |
"1 2 1 1 \n", | |
"2 3 1 3 \n", | |
"3 4 1 1 \n", | |
"4 5 0 3 \n", | |
"\n", | |
" Name Sex Age SibSp \\\n", | |
"0 Braund, Mr. Owen Harris male 22.0 1 \n", | |
"1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", | |
"2 Heikkinen, Miss. Laina female 26.0 0 \n", | |
"3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", | |
"4 Allen, Mr. William Henry male 35.0 0 \n", | |
"\n", | |
" Parch Ticket Fare Cabin Embarked \n", | |
"0 0 A/5 21171 7.2500 NaN S \n", | |
"1 0 PC 17599 71.2833 C85 C \n", | |
"2 0 STON/O2. 3101282 7.9250 NaN S \n", | |
"3 0 113803 53.1000 C123 S \n", | |
"4 0 373450 8.0500 NaN S " | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>PassengerId</th>\n", | |
" <th>Pclass</th>\n", | |
" <th>Name</th>\n", | |
" <th>Sex</th>\n", | |
" <th>Age</th>\n", | |
" <th>SibSp</th>\n", | |
" <th>Parch</th>\n", | |
" <th>Ticket</th>\n", | |
" <th>Fare</th>\n", | |
" <th>Cabin</th>\n", | |
" <th>Embarked</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>892</td>\n", | |
" <td>3</td>\n", | |
" <td>Kelly, Mr. James</td>\n", | |
" <td>male</td>\n", | |
" <td>34.5</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>330911</td>\n", | |
" <td>7.8292</td>\n", | |
" <td>NaN</td>\n", | |
" <td>Q</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>893</td>\n", | |
" <td>3</td>\n", | |
" <td>Wilkes, Mrs. James (Ellen Needs)</td>\n", | |
" <td>female</td>\n", | |
" <td>47.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>363272</td>\n", | |
" <td>7.0000</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>894</td>\n", | |
" <td>2</td>\n", | |
" <td>Myles, Mr. Thomas Francis</td>\n", | |
" <td>male</td>\n", | |
" <td>62.0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>240276</td>\n", | |
" <td>9.6875</td>\n", | |
" <td>NaN</td>\n", | |
" <td>Q</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>895</td>\n", | |
" <td>3</td>\n", | |
" <td>Wirz, Mr. Albert</td>\n", | |
" <td>male</td>\n", | |
" <td>27.0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>315154</td>\n", | |
" <td>8.6625</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>896</td>\n", | |
" <td>3</td>\n", | |
" <td>Hirvonen, Mrs. Alexander (Helga E Lindqvist)</td>\n", | |
" <td>female</td>\n", | |
" <td>22.0</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>3101298</td>\n", | |
" <td>12.2875</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" PassengerId Pclass Name Sex \\\n", | |
"0 892 3 Kelly, Mr. James male \n", | |
"1 893 3 Wilkes, Mrs. James (Ellen Needs) female \n", | |
"2 894 2 Myles, Mr. Thomas Francis male \n", | |
"3 895 3 Wirz, Mr. Albert male \n", | |
"4 896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) female \n", | |
"\n", | |
" Age SibSp Parch Ticket Fare Cabin Embarked \n", | |
"0 34.5 0 0 330911 7.8292 NaN Q \n", | |
"1 47.0 1 0 363272 7.0000 NaN S \n", | |
"2 62.0 0 0 240276 9.6875 NaN Q \n", | |
"3 27.0 0 0 315154 8.6625 NaN S \n", | |
"4 22.0 1 1 3101298 12.2875 NaN S " | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test.head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Prepare data**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Column roles for the Titanic frames loaded above.\n", | |
"index = 'PassengerId'  # row id: kept as a plain column and excluded from the features via skip_flds\n", | |
"dep = 'Survived'       # dependent variable, i.e. the label column handed to proc_df\n", | |
"cat_vars = ['Pclass', 'Sex', 'SibSp', 'Parch', 'Embarked']\n", | |
"contin_vars = ['Age']\n", | |
"drop_vars = ['Name', 'Ticket', 'Cabin', 'Fare']\n", | |
"\n", | |
"# PassengerId intentionally stays a regular column: a bare `df.set_index(index)` returns a new\n", | |
"# frame and leaves `df` untouched, and proc_df below is told to skip the column by name.\n", | |
"\n", | |
"# Categorical features become ordered pandas categoricals.\n", | |
"for v in cat_vars:\n", | |
"    test[v] = test[v].astype('category').cat.as_ordered()\n", | |
"    train[v] = train[v].astype('category').cat.as_ordered()\n", | |
"\n", | |
"# Continuous features are stored as float32.\n", | |
"for v in contin_vars:\n", | |
"    test[v] = test[v].astype('float32')\n", | |
"    train[v] = train[v].astype('float32')\n", | |
"\n", | |
"# Drop unused columns; the per-frame membership check keeps the cell safe to re-run.\n", | |
"for v in drop_vars:\n", | |
"    for frame in (test, train):\n", | |
"        if v in frame:\n", | |
"            frame.drop(v, axis=1, inplace=True)\n", | |
"\n", | |
"# test.csv has no label column; add a placeholder so proc_df can be called with `dep` for both frames.\n", | |
"test[dep] = np.nan\n", | |
"\n", | |
"# NOTE(review): apply_cats presumably re-codes test's categoricals with train's categories -- confirm in fastai.structured.\n", | |
"apply_cats(test, train)\n", | |
"\n", | |
"# Reuse the NA dict and scaling mapper returned for train when processing test.\n", | |
"df, y, nas, mapper = proc_df(train, dep, do_scale=True, skip_flds=[index])\n", | |
"df_test, _, nas, mapper = proc_df(test, dep, do_scale=True, skip_flds=[index], mapper=mapper, na_dict=nas)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Pclass</th>\n", | |
" <th>Sex</th>\n", | |
" <th>Age</th>\n", | |
" <th>SibSp</th>\n", | |
" <th>Parch</th>\n", | |
" <th>Embarked</th>\n", | |
" <th>Age_na</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>3</td>\n", | |
" <td>2</td>\n", | |
" <td>-0.565736</td>\n", | |
" <td>2</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>-0.497895</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>0.663861</td>\n", | |
" <td>2</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>-0.497895</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>3</td>\n", | |
" <td>1</td>\n", | |
" <td>-0.258337</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>-0.497895</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>0.433312</td>\n", | |
" <td>2</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>-0.497895</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>3</td>\n", | |
" <td>2</td>\n", | |
" <td>0.433312</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>-0.497895</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" Pclass Sex Age SibSp Parch Embarked Age_na\n", | |
"0 3 2 -0.565736 2 1 3 -0.497895\n", | |
"1 1 1 0.663861 2 1 1 -0.497895\n", | |
"2 3 1 -0.258337 1 1 3 -0.497895\n", | |
"3 1 1 0.433312 2 1 3 -0.497895\n", | |
"4 3 2 0.433312 1 1 3 -0.497895" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Pclass</th>\n", | |
" <th>Sex</th>\n", | |
" <th>Age</th>\n", | |
" <th>SibSp</th>\n", | |
" <th>Parch</th>\n", | |
" <th>Embarked</th>\n", | |
" <th>Age_na</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>3</td>\n", | |
" <td>2</td>\n", | |
" <td>0.394887</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>2</td>\n", | |
" <td>-0.497895</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>3</td>\n", | |
" <td>1</td>\n", | |
" <td>1.355510</td>\n", | |
" <td>2</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>-0.497895</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>2</td>\n", | |
" <td>2</td>\n", | |
" <td>2.508257</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>2</td>\n", | |
" <td>-0.497895</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>3</td>\n", | |
" <td>2</td>\n", | |
" <td>-0.181487</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>-0.497895</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>3</td>\n", | |
" <td>1</td>\n", | |
" <td>-0.565736</td>\n", | |
" <td>2</td>\n", | |
" <td>2</td>\n", | |
" <td>3</td>\n", | |
" <td>-0.497895</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" Pclass Sex Age SibSp Parch Embarked Age_na\n", | |
"0 3 2 0.394887 1 1 2 -0.497895\n", | |
"1 3 1 1.355510 2 1 3 -0.497895\n", | |
"2 2 2 2.508257 1 1 2 -0.497895\n", | |
"3 3 2 -0.181487 1 1 3 -0.497895\n", | |
"4 3 1 -0.565736 2 2 3 -0.497895" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df_test.head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Create model/learner**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# One (column, size) pair per categorical feature; size is the number of categories plus one.\n", | |
"cat_sz = [(col, 1 + len(train[col].cat.categories)) for col in cat_vars]\n", | |
"# Pair each size with an embedding width: half the size rounded up, capped at 50.\n", | |
"emb_szs = [(size, min(50, (size + 1) // 2)) for _, size in cat_sz]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"n = len(train)\n", | |
"val_idxs = get_cv_idxs(n)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"md = ColumnarModelData.from_data_frame(PATH, val_idxs, df, y.astype(np.float32),\n", | |
" cat_flds=cat_vars, bs=128, test_df=df_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"m = md.get_learner(emb_szs, len(df.columns)-len(cat_vars),\n", | |
" 0.04, 1, [1000,500], [0.001,0.01], y_range=[0, 1])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Train**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "8b2f29c343ee4ceb9de09bac7e4074dd", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/html": [ | |
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n", | |
"<p>\n", | |
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", | |
" that the widgets JavaScript is still loading. If this message persists, it\n", | |
" likely means that the widgets JavaScript library is either not installed or\n", | |
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n", | |
" Widgets Documentation</a> for setup instructions.\n", | |
"</p>\n", | |
"<p>\n", | |
" If you're reading this message in another frontend (for example, a static\n", | |
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n", | |
" it may mean that your frontend doesn't currently support widgets.\n", | |
"</p>\n" | |
], | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch trn_loss val_loss \n", | |
" 0 0.26849 0.589888 \n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAEOCAYAAACuOOGFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAEz5JREFUeJzt3X2wXHV9x/H3B1KxVoyAASEhhhasE8enumCp2qFVIjrVqFCJbW1U2vSJanVqi7UdFZ2KT7U+1whWaqug0GIEa3hQqm0VcoNSCEpJox1SUKJBCqIw6Ld/7EGW2725m9zfvZubvF8zO3vO7/zOOd+998753N85u2dTVUiSNFP7jLsASdKewUCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDWxYNwFzKWHPexhtWzZsnGXIUnzysaNG79dVYum67dXBcqyZcuYmJgYdxmSNK8k+e9R+nnKS5LUhIEiSWrCQJEkNWGgSJKaMFAkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWrCQJEkNWGgSJKaMFAkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWrCQJEkNWGgSJKaMFAkSU0YKJKkJgwUSVITYw2UJCckuT7J5iSnDVm+X5Jzu+VXJFk2afnSJHck+eO5qlmSNNzYAiXJvsB7gWcCy4EXJlk+qdspwK1VdSTwDuDNk5a/A/jn2a5VkjS9cY5QjgE2V9WWqrobOAdYOanPSuDsbvo84GlJApDkucAWYNMc1StJ2oFxBspi4MaB+a1d29A+VXUPcBtwUJKfAv4UeP0c1ClJGsE4AyVD2mrEPq8H3lFVd0y7k2RNkokkE9u2bduFMiVJo1gwxn1vBQ4fmF8C3DRFn61JFgALge3Ak4CTkrwFeCjwoyQ/qKr3TN5JVa0F1gL0er3JgSVJamScgbIBOCrJEcD/AKuAX5vUZx2wGvgicBLw2aoq4Kn3dkjyOuCOYWEiSZo7YwuUqronyanAemBf4ENVtSnJ6cBEVa0DzgI+kmQz/ZHJqnHVK0nasfT/4d879Hq9mpiYGHcZkjSvJNlYVb3p+vlJeUlSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNjDVQkpyQ5Pokm5OcNmT5fknO7ZZfkWRZ1358ko1Jrumef3mua5ck3d/YAiXJvsB7gWcCy4EXJlk+qdspwK1VdSTwDuDNXfu3gWdX1WOA1cBH5qZqSdJUxjlCOQbYXFVbqupu4Bxg5aQ+K4Gzu+nzgKclSVV9uapu6to3AQ9Mst+cVC1JGmqcgbIYuHFgfmvXNrRPVd0D3AYcNKnPicCXq+quWapTkjSCBWPcd4a01c70SfJo+qfBVky5k2QNsAZg6dKlO1+lJGkk4xyhbAUOH5hfAtw0VZ8kC4CFwPZufgnwT8BvVtV/TbWTqlpbVb2q6i1atKhh+ZKkQeMMlA3AUUmOSPIAYBWwblKfdfQvugOcBHy2qirJQ4GLgFdX1b/NWcWSpCmNLVC6ayKnAuuBrwIfr6pNSU5P8pyu21nAQUk2A68E7n1r8anAkcBfJPlK9zh4jl+CJGlAqiZftthz9Xq9mpiYGHcZkjSvJNlYVb3p+vlJeUlSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1M
RIgZLk5Ukekr6zklyVZMVsFydJmj9GHaG8tKr+F1gBLAJeApwxa1VJkuadUQMl3fOzgL+tqqsH2iRJGjlQNia5mH6grE+yP/Cjme48yQlJrk+yOclpQ5bvl+TcbvkVSZYNLHt11359kmfMtBZJ0swsGLHfKcDjgS1VdWeSA+mf9tplSfYF3gscD2wFNiRZV1XXTdrvrVV1ZJJVwJuBk5MsB1YBjwYOAy5N8siq+uFMapIk7bpRRyjHAtdX1XeT/Abw58BtM9z3McDmqtpSVXcD5wArJ/VZCZzdTZ8HPC1JuvZzququqvo6sLnbniRpTEYNlPcDdyZ5HPAnwH8DfzfDfS8GbhyY39q1De1TVffQD7GDRlxXkjSHRg2Ue6qq6I8M3llV7wT2n+G+h13UrxH7jLJufwPJmiQTSSa2bdu2kyVKkkY1aqDcnuTVwIuAi7rrHz8xw31vBQ4fmF8C3DRVnyQLgIXA9hHXBaCq1lZVr6p6ixYtmmHJkqSpjBooJwN30f88yjfpn1566wz3vQE4KskRSR5A/yL7ukl91gGru+mTgM92I6V1wKruXWBHAEcBV86wHknSDIz0Lq+q+maSfwCOTvIrwJVVNaNrKFV1T5JTgfXAvsCHqmpTktOBiapaB5wFfCTJZvojk1XdupuSfBy4DrgH+APf4SVJ45X+P/zTdEpeQH9Ecjn96xdPBV5VVefNanWN9Xq9mpiYGHcZkjSvJNlYVb3p+o36OZTXAEdX1S3dxhcBl9J/K68kSSNfQ9nn3jDpfGcn1pUk7QVGHaF8Jsl64GPd/MnAp2enJEnSfDTqRflXJTkReDL9ayhrq+qfZrUySdK8MuoIhao6Hzh/FmuRJM1jOwyUJLcz/BPoAaqqHjIrVUmS5p0dBkpVzfT2KpKkvYTv1JIkNWGgSJKaMFAkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWrCQJEkNWGgSJKaMFAkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWrCQJEkNWGgSJKaMFAkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWpiLIGS5MAklyS5oXs+YIp+q7s+NyRZ3bU9KMlFSb6WZFOSM+a2eknSMOMaoZwGXFZVRwGXdfP3k+RA4LXAk4BjgNcOBM/bqupRwBOAJyd55tyULUmayrgCZSVwdjd9NvDcIX2eAVxSVdur6lbgEuCEqrqzqj4HUFV3A1cBS+agZknSDowrUA6pqpsBuueDh/RZDNw4ML+1a/uxJA8Fnk1/lCNJGqMFs7XhJJcCDx+y6DWjbmJIWw1sfwHwMeBdVbVlB3WsAdYALF26dMRdS5J21qwFSlU9faplSb6V5NCqujnJocAtQ7ptBY4bmF8CXD4wvxa4oar+epo61nZ96fV6taO+kqRdN65TXuuA1d30auCTQ/qsB1YkOaC7GL+iayPJG4GFwB/NQa2SpBGMK1DOAI5PcgNwfDdPkl6SMwGqajvwBmBD9zi9qrYnWUL/tNly4KokX0nyW+N4EZKk+6Rq7zkL1Ov1amJiYtxlSNK8kmRjVfWm6+cn5SVJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJamIsgZLkwCSXJLmhez5gin6ruz43JFk9ZPm6JNfOfsWSpOmMa4RyGnBZVR0FXNbN30+SA4HXAk8CjgFeOxg8SZ4P3DE35UqSpjOuQFkJnN1Nnw08d0ifZwCXVNX2qroVuAQ4ASDJg4FXAm+cg1olSSMYV6AcUlU3A3TPBw/psxi4cWB+a9cG8Abg7cCds1mkJGl0C2Zrw0kuBR4+ZNFrRt3EkLZK8njgyKp6RZJlI9SxBlgDsH
Tp0hF3LUnaWbMWKFX19KmWJflWkkOr6uYkhwK3DOm2FThuYH4JcDlwLPDEJN+gX//BSS6vquMYoqrWAmsBer1e7fwrkSSNYlynvNYB975razXwySF91gMrkhzQXYxfAayvqvdX1WFVtQx4CvCfU4WJJGnujCtQzgCOT3IDcHw3T5JekjMBqmo7/WslG7rH6V2bJGk3lKq95yxQr9eriYmJcZchSfNKko1V1Zuun5+UlyQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1kaoadw1zJsk24LvAbbuw+sOAb7etSDuwkF37Pe3OdtfXNK66Znu/rbffansz2c6urjvT49cjqmrRdJ32qkABSLK2qtbswnoTVdWbjZr0/+3q72l3tru+pnHVNdv7bb39VtubyXZ29+PX3njK61PjLkAj2RN/T7vraxpXXbO939bbb7W9mWxnd/0bAvbCEcqucoQiab5yhLL7WTvuAiRpF83J8csRiiSpCUcokqQmDBRJUhMGiiSpCQNlFyX5qSRnJ/lgkl8fdz2SNKokP53krCTntdyugTIgyYeS3JLk2kntJyS5PsnmJKd1zc8Hzquq3waeM+fFStKAnTl+VdWWqjqldQ0Gyv19GDhhsCHJvsB7gWcCy4EXJlkOLAFu7Lr9cA5rlKRhPszox69ZYaAMqKrPA9snNR8DbO4S/W7gHGAlsJV+qIA/R0ljtpPHr1nhgXB6i7lvJAL9IFkM/CNwYpL3s5vfDkHSXmvo8SvJQUn+BnhCkle32tmCVhvag2VIW1XV94CXzHUxkrQTpjp+fQf43dY7c4Qyva3A4QPzS4CbxlSLJO2MOT1+GSjT2wAcleSIJA8AVgHrxlyTJI1iTo9fBsqAJB8Dvgj8bJKtSU6pqnuAU4H1wFeBj1fVpnHWKUmT7Q7HL28OKUlqwhGKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCjabSW5Yw728ZyBrySYE0mOS/ILu7DeE5Kc2U2/OMl72le385Ism3zL9CF9FiX5zFzVpPEwULTH627hPVRVrauqM2Zhnzu6T95xwE4HCvBnwLt3qaAxq6ptwM1JnjzuWjR7DBTNC0lelWRDkv9I8vqB9guSbEyyKcmagfY7kpye5Arg2CTfSPL6JFcluSbJo7p+P/5PP8mHk7wryb8n2ZLkpK59nyTv6/ZxYZJP37tsUo2XJ/nLJP8CvDzJs5NckeTLSS5NckiSZfRvyveKJF9J8tTuv/fzu9e3YdhBN8n+wGOr6uohyx6R5LLuZ3NZkqVd+88k+VK3zdOHjfi6bx69KMnVSa5NcnLXfnT3c7g6yZVJ9u9GIl/ofoZXDRtlJdk3yVsHfle/M7D4AsBvN92TVZUPH7vlA7ije14BrKV/59R9gAuBX+yWHdg9/yRwLXBQN1/ACwa29Q3gD7vp3wfO7KZfDLynm/4w8IluH8vpf48EwEnAp7v2hwO3AicNqfdy4H0D8wdw390ofgt4ezf9OuCPB/p9FHhKN70U+OqQbf8ScP7A/GDdnwJWd9MvBS7opi8EXthN/+69P89J2z0R+ODA/ELgAcAW4Oiu7SH070z+IOCBXdtRwEQ3vQy4tpteA/x5N70fMAEc0c0vBq4Z99+Vj9l7ePt6zQcruseXu/kH0z+gfR54WZLnde2Hd+3fof8tmudP2s4/ds8b6X+F8zAXVNWPgOuSHNK1PQX4RNf+zSSf20Gt5w5MLwHOTXIo/YP016dY5+nA8uTHdxp/SJL9q+r2gT6HAtumWP/YgdfzEeAtA+3P7aY/CrxtyLrXAG9L8mbgwqr6QpLHADdX1QaAqvpf6I9mgPckeTz9n+8jh2xvBfDYgRHcQvq/k68DtwCHTfEatA
cwUDQfBHhTVX3gfo3JcfQPxsdW1Z1JLgce2C3+QVVN/mrmu7rnHzL13/5dA9OZ9DyK7w1Mvxv4q6pa19X6uinW2Yf+a/j+Drb7fe57bdMZ+QZ9VfWfSZ4IPAt4U5KL6Z+aGraNVwDfAh7X1fyDIX1CfyS4fsiyB9J/HdpDeQ1F88F64KVJHgyQZHGSg+n/93trFyaPAn5+lvb/r/S/nXOfbtRy3IjrLQT+p5tePdB+O7D/wPzF9O8IC0A3Apjsq8CRU+zn3+nflhz61yj+tZv+Ev1TWgwsv58khwF3VtXf0x/B/BzwNeCwJEd3ffbv3mSwkP7I5UfAi4Bhb3ZYD/xekp/o1n1kN7KB/ohmh+8G0/xmoGi3V1UX0z9l88Uk1wDn0T8gfwZYkOQ/gDfQP4DOhvPpf1HRtcAHgCuA20ZY73XAJ5J8Afj2QPungOfde1EeeBnQ6y5iX8eQb9Krqq8BC7uL85O9DHhJ93N4EfDyrv2PgFcmuZL+KbNhNT8GuDLJV4DXAG+s/nePnwy8O8nVwCX0RxfvA1Yn+RL9cPjekO2dCVwHXNW9lfgD3Dca/CXgoiHraA/h7eulESR5cFXdkeQg4ErgyVX1zTmu4RXA7VV15oj9HwR8v6oqySr6F+hXzmqRO67n88DKqrp1XDVodnkNRRrNhUkeSv/i+hvmOkw67wd+dSf6P5H+RfQA36X/DrCxSLKI/vUkw2QP5ghFktSE11AkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWri/wAuNq8iO7eWkwAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"m.lr_find()\n", | |
"m.sched.plot(100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"lr = 1e-3" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "04b6976ddcbd4800a815ff98ee348383", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/html": [ | |
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n", | |
"<p>\n", | |
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", | |
" that the widgets JavaScript is still loading. If this message persists, it\n", | |
" likely means that the widgets JavaScript library is either not installed or\n", | |
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n", | |
" Widgets Documentation</a> for setup instructions.\n", | |
"</p>\n", | |
"<p>\n", | |
" If you're reading this message in another frontend (for example, a static\n", | |
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n", | |
" it may mean that your frontend doesn't currently support widgets.\n", | |
"</p>\n" | |
], | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, description='Epoch', max=28), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch trn_loss val_loss \n", | |
" 0 0.29958 0.289071 \n", | |
" 1 0.266482 0.225575 \n", | |
" 2 0.243882 0.230261 \n", | |
" 3 0.230642 0.224803 \n", | |
" 4 0.220913 0.205747 \n", | |
" 5 0.210903 0.209014 \n", | |
" 6 0.202561 0.203815 \n", | |
" 7 0.194041 0.199481 \n", | |
" 8 0.18833 0.194539 \n", | |
" 9 0.182456 0.188606 \n", | |
" 10 0.177253 0.184244 \n", | |
" 11 0.173428 0.183238 \n", | |
" 12 0.169191 0.14376 \n", | |
" 13 0.164396 0.151117 \n", | |
" 14 0.159777 0.136738 \n", | |
" 15 0.155043 0.135023 \n", | |
" 16 0.151094 0.135569 \n", | |
" 17 0.147484 0.138293 \n", | |
" 18 0.144015 0.136753 \n", | |
" 19 0.141021 0.138247 \n", | |
" 20 0.137857 0.136863 \n", | |
" 21 0.135251 0.137623 \n", | |
" 22 0.133364 0.137839 \n", | |
" 23 0.13165 0.137135 \n", | |
" 24 0.129502 0.137729 \n", | |
" 25 0.127662 0.137821 \n", | |
" 26 0.126263 0.137893 \n", | |
" 27 0.124427 0.137869 \n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"[0.13786882]" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"m.fit(lr, 3, cycle_len=4, cycle_mult=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"m.save('val0')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"m.load('val0')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Prepare submission**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x,y=m.predict_with_targs()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"178" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"pred_test = m.predict(True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test[dep] = pred_test" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"csv_fn=f'{PATH}/tmp/sub.csv'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>PassengerId</th>\n", | |
" <th>Pclass</th>\n", | |
" <th>Sex</th>\n", | |
" <th>Age</th>\n", | |
" <th>SibSp</th>\n", | |
" <th>Parch</th>\n", | |
" <th>Embarked</th>\n", | |
" <th>Survived</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>892</td>\n", | |
" <td>3</td>\n", | |
" <td>male</td>\n", | |
" <td>34.5</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>Q</td>\n", | |
" <td>0.124961</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>893</td>\n", | |
" <td>3</td>\n", | |
" <td>female</td>\n", | |
" <td>47.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>S</td>\n", | |
" <td>0.326619</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>894</td>\n", | |
" <td>2</td>\n", | |
" <td>male</td>\n", | |
" <td>62.0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>Q</td>\n", | |
" <td>0.191005</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>895</td>\n", | |
" <td>3</td>\n", | |
" <td>male</td>\n", | |
" <td>27.0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>S</td>\n", | |
" <td>0.127235</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>896</td>\n", | |
" <td>3</td>\n", | |
" <td>female</td>\n", | |
" <td>22.0</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>S</td>\n", | |
" <td>0.355572</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" PassengerId Pclass Sex Age SibSp Parch Embarked Survived\n", | |
"0 892 3 male 34.5 0 0 Q 0.124961\n", | |
"1 893 3 female 47.0 1 0 S 0.326619\n", | |
"2 894 2 male 62.0 0 0 Q 0.191005\n", | |
"3 895 3 male 27.0 0 0 S 0.127235\n", | |
"4 896 3 female 22.0 1 1 S 0.355572" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>PassengerId</th>\n", | |
" <th>Survived</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>892</td>\n", | |
" <td>0.124961</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>893</td>\n", | |
" <td>0.326619</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>894</td>\n", | |
" <td>0.191005</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>895</td>\n", | |
" <td>0.127235</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>896</td>\n", | |
" <td>0.355572</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" PassengerId Survived\n", | |
"0 892 0.124961\n", | |
"1 893 0.326619\n", | |
"2 894 0.191005\n", | |
"3 895 0.127235\n", | |
"4 896 0.355572" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sub = test[[index, dep]]\n", | |
"sub.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<a href='/home/tonygentilcore/.kaggle/competitions/titanic/tmp/sub.csv' target='_blank'>/home/tonygentilcore/.kaggle/competitions/titanic/tmp/sub.csv</a><br>" | |
], | |
"text/plain": [ | |
"/home/tonygentilcore/.kaggle/competitions/titanic/tmp/sub.csv" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sub.to_csv(csv_fn, index=False)\n", | |
"FileLink(csv_fn)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment