Created
March 9, 2019 11:19
-
-
Save regonn/511038337a145810794bbd1f9fdc8079 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from ludwig import LudwigModel\n", | |
"import polar_bear as pb\n", | |
"import pandas as pd" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>PassengerId</th>\n", | |
" <th>Survived</th>\n", | |
" <th>Pclass</th>\n", | |
" <th>Name</th>\n", | |
" <th>Sex</th>\n", | |
" <th>Age</th>\n", | |
" <th>SibSp</th>\n", | |
" <th>Parch</th>\n", | |
" <th>Ticket</th>\n", | |
" <th>Fare</th>\n", | |
" <th>Cabin</th>\n", | |
" <th>Embarked</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" <td>Braund, Mr. Owen Harris</td>\n", | |
" <td>male</td>\n", | |
" <td>22.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>A/5 21171</td>\n", | |
" <td>7.2500</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>2</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n", | |
" <td>female</td>\n", | |
" <td>38.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>PC 17599</td>\n", | |
" <td>71.2833</td>\n", | |
" <td>C85</td>\n", | |
" <td>C</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>3</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>Heikkinen, Miss. Laina</td>\n", | |
" <td>female</td>\n", | |
" <td>26.0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>STON/O2. 3101282</td>\n", | |
" <td>7.9250</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>4</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n", | |
" <td>female</td>\n", | |
" <td>35.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>113803</td>\n", | |
" <td>53.1000</td>\n", | |
" <td>C123</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>5</td>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" <td>Allen, Mr. William Henry</td>\n", | |
" <td>male</td>\n", | |
" <td>35.0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>373450</td>\n", | |
" <td>8.0500</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" PassengerId Survived Pclass \\\n", | |
"0 1 0 3 \n", | |
"1 2 1 1 \n", | |
"2 3 1 3 \n", | |
"3 4 1 1 \n", | |
"4 5 0 3 \n", | |
"\n", | |
" Name Sex Age SibSp \\\n", | |
"0 Braund, Mr. Owen Harris male 22.0 1 \n", | |
"1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", | |
"2 Heikkinen, Miss. Laina female 26.0 0 \n", | |
"3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", | |
"4 Allen, Mr. William Henry male 35.0 0 \n", | |
"\n", | |
" Parch Ticket Fare Cabin Embarked \n", | |
"0 0 A/5 21171 7.2500 NaN S \n", | |
"1 0 PC 17599 71.2833 C85 C \n", | |
"2 0 STON/O2. 3101282 7.9250 NaN S \n", | |
"3 0 113803 53.1000 C123 S \n", | |
"4 0 373450 8.0500 NaN S " | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_df = pd.read_csv('train_titanic.csv')\n", | |
"test_df = pd.read_csv('test_titanic.csv')\n", | |
"train_df.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'output_features': [{'name': 'Survived', 'type': 'binary'}], 'input_features': [{'name': 'Pclass', 'type': 'category'}, {'name': 'Sex', 'type': 'category'}, {'name': 'Age', 'type': 'numerical'}, {'name': 'SibSp', 'type': 'category'}, {'name': 'Parch', 'type': 'category'}, {'name': 'Fare', 'type': 'numerical'}, {'name': 'Cabin', 'type': 'category'}, {'name': 'Embarked', 'type': 'category'}]}\n" | |
] | |
} | |
], | |
"source": [ | |
"model_definition = pb.make_model_definition_for_ludwig(train_df, 'Survived')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'output_features': [{'name': 'Survived', 'type': 'binary'}],\n", | |
" 'input_features': [{'name': 'Pclass', 'type': 'category'},\n", | |
" {'name': 'Sex', 'type': 'category'},\n", | |
" {'name': 'Age', 'type': 'numerical'},\n", | |
" {'name': 'SibSp', 'type': 'category'},\n", | |
" {'name': 'Parch', 'type': 'category'},\n", | |
" {'name': 'Fare', 'type': 'numerical'},\n", | |
" {'name': 'Cabin', 'type': 'category'},\n", | |
" {'name': 'Embarked', 'type': 'category'}]}" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model_definition" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = LudwigModel(model_definition)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/regonn/.local/share/virtualenvs/regonn-667eHjES/lib/python3.7/site-packages/ludwig/features/numerical_feature.py:63: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n", | |
" np.float32).as_matrix()\n", | |
"/home/regonn/.local/share/virtualenvs/regonn-667eHjES/lib/python3.7/site-packages/ludwig/features/binary_feature.py:62: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n", | |
" np.bool_).as_matrix()\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From /home/regonn/.local/share/virtualenvs/regonn-667eHjES/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Colocations handled automatically by placer.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From /home/regonn/.local/share/virtualenvs/regonn-667eHjES/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Colocations handled automatically by placer.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From /home/regonn/.local/share/virtualenvs/regonn-667eHjES/lib/python3.7/site-packages/ludwig/features/binary_feature.py:180: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use tf.cast instead.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From /home/regonn/.local/share/virtualenvs/regonn-667eHjES/lib/python3.7/site-packages/ludwig/features/binary_feature.py:180: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use tf.cast instead.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From /home/regonn/.local/share/virtualenvs/regonn-667eHjES/lib/python3.7/site-packages/tensorflow/python/ops/array_grad.py:425: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use tf.cast instead.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From /home/regonn/.local/share/virtualenvs/regonn-667eHjES/lib/python3.7/site-packages/tensorflow/python/ops/array_grad.py:425: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use tf.cast instead.\n" | |
] | |
} | |
], | |
"source": [ | |
"train_stats = model.train(train_df)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"predictions = model.predict(test_df)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Survived_predictions</th>\n", | |
" <th>Survived_probabilities</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>False</td>\n", | |
" <td>0.285479</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>True</td>\n", | |
" <td>0.531202</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>False</td>\n", | |
" <td>0.437962</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>False</td>\n", | |
" <td>0.123611</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>True</td>\n", | |
" <td>0.614788</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" Survived_predictions Survived_probabilities\n", | |
"0 False 0.285479\n", | |
"1 True 0.531202\n", | |
"2 False 0.437962\n", | |
"3 False 0.123611\n", | |
"4 True 0.614788" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"predictions.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>PassengerId</th>\n", | |
" <th>Survived</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>892</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>893</td>\n", | |
" <td>True</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>894</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>895</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>896</td>\n", | |
" <td>True</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" PassengerId Survived\n", | |
"0 892 False\n", | |
"1 893 True\n", | |
"2 894 False\n", | |
"3 895 False\n", | |
"4 896 True" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sub = pd.DataFrame({'PassengerId': test_df['PassengerId'], 'Survived': predictions['Survived_predictions']})\n", | |
"sub.to_csv('ludwig.csv', index = False, float_format='%1d')\n", | |
"sub.head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Titanic Score: 0.76076" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment