Skip to content

Instantly share code, notes, and snippets.

@MachineLearningIsEasy
Created March 30, 2020 13:42
Show Gist options
  • Save MachineLearningIsEasy/218837e7a154179906575222a88687b7 to your computer and use it in GitHub Desktop.
Save MachineLearningIsEasy/218837e7a154179906575222a88687b7 to your computer and use it in GitHub Desktop.
Sklearn and pandas example
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"colab": {
"name": "Titanic.ipynb",
"provenance": []
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "fdJ1ZC2ursQP",
"colab_type": "text"
},
"source": [
"### Обработка категориальных признаков"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4D98IQ01rsQV",
"colab_type": "code",
"outputId": "a6f3c227-f418-4cf4-e3aa-09c11cb42f38",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 663
}
},
"source": [
"from IPython.display import Image\n",
"\n",
"# Title illustration, rendered as an <img> tag that points at the remote URL\n",
"Image(url='https://static1.squarespace.com/static/5006453fe4b09ef2252ba068/5095eabce4b06cb305058603/5095eabce4b02d37bef4c24c/1352002236895/100_anniversary_titanic_sinking_by_esai8mellows-d4xbme8.jpg')"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<img src=\"https://static1.squarespace.com/static/5006453fe4b09ef2252ba068/5095eabce4b06cb305058603/5095eabce4b02d37bef4c24c/1352002236895/100_anniversary_titanic_sinking_by_esai8mellows-d4xbme8.jpg\"/>"
],
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"tags": []
},
"execution_count": 1
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "VYsR44p-rwFl",
"colab_type": "code",
"colab": {}
},
"source": [
"from google.colab import files\n",
"\n",
"# Colab-only: choose train.csv from the local disk.\n",
"# Bind the returned {filename: bytes} dict so the cell does not echo the raw file contents.\n",
"uploaded = files.upload()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "w_rCkruMrsQf",
"colab_type": "code",
"colab": {}
},
"source": [
"import pandas as pd\n",
"\n",
"# Kaggle Titanic training set (the file provided via files.upload() above)\n",
"data = pd.read_csv('train.csv')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5zRdCObZrsQi",
"colab_type": "code",
"outputId": "b8655769-b6ee-47ab-bdbf-069a3db8bede",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 443
}
},
"source": [
"# first ten passengers of the raw table\n",
"data.head(n=10)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PassengerId</th>\n",
" <th>Survived</th>\n",
" <th>Pclass</th>\n",
" <th>Name</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Ticket</th>\n",
" <th>Fare</th>\n",
" <th>Cabin</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Braund, Mr. Owen Harris</td>\n",
" <td>male</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>A/5 21171</td>\n",
" <td>7.2500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
" <td>female</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>PC 17599</td>\n",
" <td>71.2833</td>\n",
" <td>C85</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Heikkinen, Miss. Laina</td>\n",
" <td>female</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>STON/O2. 3101282</td>\n",
" <td>7.9250</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
" <td>female</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>113803</td>\n",
" <td>53.1000</td>\n",
" <td>C123</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Allen, Mr. William Henry</td>\n",
" <td>male</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>373450</td>\n",
" <td>8.0500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>6</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Moran, Mr. James</td>\n",
" <td>male</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>330877</td>\n",
" <td>8.4583</td>\n",
" <td>NaN</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>7</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>McCarthy, Mr. Timothy J</td>\n",
" <td>male</td>\n",
" <td>54.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>17463</td>\n",
" <td>51.8625</td>\n",
" <td>E46</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>8</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Palsson, Master. Gosta Leonard</td>\n",
" <td>male</td>\n",
" <td>2.0</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>349909</td>\n",
" <td>21.0750</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)</td>\n",
" <td>female</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>347742</td>\n",
" <td>11.1333</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>10</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Nasser, Mrs. Nicholas (Adele Achem)</td>\n",
" <td>female</td>\n",
" <td>14.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>237736</td>\n",
" <td>30.0708</td>\n",
" <td>NaN</td>\n",
" <td>C</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" PassengerId Survived Pclass ... Fare Cabin Embarked\n",
"0 1 0 3 ... 7.2500 NaN S\n",
"1 2 1 1 ... 71.2833 C85 C\n",
"2 3 1 3 ... 7.9250 NaN S\n",
"3 4 1 1 ... 53.1000 C123 S\n",
"4 5 0 3 ... 8.0500 NaN S\n",
"5 6 0 3 ... 8.4583 NaN Q\n",
"6 7 0 1 ... 51.8625 E46 S\n",
"7 8 0 3 ... 21.0750 NaN S\n",
"8 9 1 3 ... 11.1333 NaN S\n",
"9 10 1 2 ... 30.0708 NaN C\n",
"\n",
"[10 rows x 12 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "56SImkYwrsQm",
"colab_type": "code",
"outputId": "716f4e8a-fffa-4606-cee6-40883bffb495",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 308
}
},
"source": [
"# column dtypes and non-null counts\n",
"data.info(verbose=True)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 891 entries, 0 to 890\n",
"Data columns (total 12 columns):\n",
"PassengerId 891 non-null int64\n",
"Survived 891 non-null int64\n",
"Pclass 891 non-null int64\n",
"Name 891 non-null object\n",
"Sex 891 non-null object\n",
"Age 714 non-null float64\n",
"SibSp 891 non-null int64\n",
"Parch 891 non-null int64\n",
"Ticket 891 non-null object\n",
"Fare 891 non-null float64\n",
"Cabin 204 non-null object\n",
"Embarked 889 non-null object\n",
"dtypes: float64(2), int64(5), object(5)\n",
"memory usage: 83.7+ KB\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ffZevGEbrsQs",
"colab_type": "code",
"outputId": "146b8b82-25ca-4c4d-8a34-e0ebe959c5a8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 240
}
},
"source": [
"# number of missing values per column\n",
"data.isna().sum()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"PassengerId 0\n",
"Survived 0\n",
"Pclass 0\n",
"Name 0\n",
"Sex 0\n",
"Age 177\n",
"SibSp 0\n",
"Parch 0\n",
"Ticket 0\n",
"Fare 0\n",
"Cabin 687\n",
"Embarked 2\n",
"dtype: int64"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "RdIC7lBjrsQw",
"colab_type": "code",
"colab": {}
},
"source": [
"# drop columns not used as model features: identifiers, free text, and the mostly-empty Cabin\n",
"data = data.drop(columns=['Name', 'Ticket', 'Cabin', 'PassengerId'])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "1y2VEBrWrsQ0",
"colab_type": "text"
},
"source": [
"### Перекодирование категориальных признаков"
]
},
{
"cell_type": "code",
"metadata": {
"id": "GCVMK48rrsQ2",
"colab_type": "code",
"outputId": "2e028819-44fa-45b4-e2a4-8e79a5f55ce5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"data.Sex.value_counts()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"male 577\n",
"female 314\n",
"Name: Sex, dtype: int64"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "SzdDN58KrsQ-",
"colab_type": "code",
"colab": {}
},
"source": [
"# binary-encode Sex: male -> 0, female -> 1\n",
"sex_mapping = {'male': 0, 'female': 1}\n",
"data = data.assign(Sex=data['Sex'].map(sex_mapping))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1Pxf7VtursRC",
"colab_type": "code",
"outputId": "f4d4f915-e569-4e3d-fda9-1a539a1233a5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 357
}
},
"source": [
"data.head(n=10)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" <th>Pclass</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Fare</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>7.2500</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>71.2833</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>7.9250</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>53.1000</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.0500</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.4583</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>54.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>51.8625</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>2.0</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>21.0750</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>11.1333</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>14.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>30.0708</td>\n",
" <td>C</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Survived Pclass Sex Age SibSp Parch Fare Embarked\n",
"0 0 3 0 22.0 1 0 7.2500 S\n",
"1 1 1 1 38.0 1 0 71.2833 C\n",
"2 1 3 1 26.0 0 0 7.9250 S\n",
"3 1 1 1 35.0 1 0 53.1000 S\n",
"4 0 3 0 35.0 0 0 8.0500 S\n",
"5 0 3 0 NaN 0 0 8.4583 Q\n",
"6 0 1 0 54.0 0 0 51.8625 S\n",
"7 0 3 0 2.0 3 1 21.0750 S\n",
"8 1 3 1 27.0 0 2 11.1333 S\n",
"9 1 2 1 14.0 1 0 30.0708 C"
]
},
"metadata": {
"tags": []
},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TSMZ89HDrsRH",
"colab_type": "code",
"outputId": "e02d13fd-ac98-42d5-9f60-027334c7e56e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 86
}
},
"source": [
"data.Embarked.value_counts()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"S 644\n",
"C 168\n",
"Q 77\n",
"Name: Embarked, dtype: int64"
]
},
"metadata": {
"tags": []
},
"execution_count": 12
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "e6DhvwUhrsRM",
"colab_type": "code",
"colab": {}
},
"source": [
"# one-hot encode the port of embarkation; with dummy_na=False a missing port becomes an all-zero row\n",
"Embarked_dummies = pd.get_dummies(data['Embarked'], prefix='port', dummy_na=False)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4Zl4xJLIrsRR",
"colab_type": "code",
"outputId": "eb74e25f-6a46-4520-a12e-4053e99a3b74",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 203
}
},
"source": [
"Embarked_dummies.head(n=5)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>port_C</th>\n",
" <th>port_Q</th>\n",
" <th>port_S</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" port_C port_Q port_S\n",
"0 0 0 1\n",
"1 1 0 0\n",
"2 0 0 1\n",
"3 0 0 1\n",
"4 0 0 1"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Omv3fCOZrsRV",
"colab_type": "code",
"colab": {}
},
"source": [
"# append the port_* indicator columns next to the existing features\n",
"data = pd.concat([data, Embarked_dummies], axis='columns')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "rZv-dTcarsRY",
"colab_type": "code",
"colab": {}
},
"source": [
"# the original text column is now redundant with the port_* dummies\n",
"data = data.drop(columns=['Embarked'])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "POyeRHNIrsRe",
"colab_type": "code",
"outputId": "d7f06711-c8a0-4889-8af8-5562701efb6b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 203
}
},
"source": [
"data.head(n=5)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" <th>Pclass</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Fare</th>\n",
" <th>port_C</th>\n",
" <th>port_Q</th>\n",
" <th>port_S</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>7.2500</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>71.2833</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>7.9250</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>53.1000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.0500</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Survived Pclass Sex Age SibSp Parch Fare port_C port_Q port_S\n",
"0 0 3 0 22.0 1 0 7.2500 0 0 1\n",
"1 1 1 1 38.0 1 0 71.2833 1 0 0\n",
"2 1 3 1 26.0 0 0 7.9250 0 0 1\n",
"3 1 1 1 35.0 1 0 53.1000 0 0 1\n",
"4 0 3 0 35.0 0 0 8.0500 0 0 1"
]
},
"metadata": {
"tags": []
},
"execution_count": 17
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "vs5ZMyGRrsRk",
"colab_type": "code",
"colab": {}
},
"source": [
"# target vector and feature matrix (every column except the target)\n",
"y = data['Survived']\n",
"X = data.drop(columns=['Survived'])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "y2VpBXIersRo",
"colab_type": "code",
"outputId": "1773c061-879c-4f1e-e846-7230f5c9ade7",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 357
}
},
"source": [
"X.head(n=10)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pclass</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Fare</th>\n",
" <th>port_C</th>\n",
" <th>port_Q</th>\n",
" <th>port_S</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>7.2500</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>71.2833</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>7.9250</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>53.1000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.0500</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.4583</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>54.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>51.8625</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>2.0</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>21.0750</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>11.1333</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>14.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>30.0708</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pclass Sex Age SibSp Parch Fare port_C port_Q port_S\n",
"0 3 0 22.0 1 0 7.2500 0 0 1\n",
"1 1 1 38.0 1 0 71.2833 1 0 0\n",
"2 3 1 26.0 0 0 7.9250 0 0 1\n",
"3 1 1 35.0 1 0 53.1000 0 0 1\n",
"4 3 0 35.0 0 0 8.0500 0 0 1\n",
"5 3 0 NaN 0 0 8.4583 0 1 0\n",
"6 1 0 54.0 0 0 51.8625 0 0 1\n",
"7 3 0 2.0 3 1 21.0750 0 0 1\n",
"8 3 1 27.0 0 2 11.1333 0 0 1\n",
"9 2 1 14.0 1 0 30.0708 1 0 0"
]
},
"metadata": {
"tags": []
},
"execution_count": 19
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JZmhqOD4rsRt",
"colab_type": "text"
},
"source": [
"### Построение деревьев решений"
]
},
{
"cell_type": "code",
"metadata": {
"id": "8xmSY0K5rsRv",
"colab_type": "code",
"outputId": "456d3f1d-ebbb-437b-9ed5-e9a45b9cad8d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 188
}
},
"source": [
"# missing values per feature; Age is imputed after the train/test split below\n",
"X.isnull().sum()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Pclass 0\n",
"Sex 0\n",
"Age 177\n",
"SibSp 0\n",
"Parch 0\n",
"Fare 0\n",
"port_C 0\n",
"port_Q 0\n",
"port_S 0\n",
"dtype: int64"
]
},
"metadata": {
"tags": []
},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6FTFMYbtrsRy",
"colab_type": "code",
"colab": {}
},
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Split BEFORE imputing so that no test-set statistics leak into the training features\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "g4W9-KNKrsR3",
"colab_type": "code",
"colab": {}
},
"source": [
"# Fill missing ages with the TRAINING-set mean only, and reuse that same value for the test set\n",
"age_mean = X_train['Age'].mean()\n",
"X_train = X_train.assign(Age=X_train['Age'].fillna(age_mean))\n",
"X_test = X_test.assign(Age=X_test['Age'].fillna(age_mean))\n",
"assert X_train.notna().all().all() and X_test.notna().all().all()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "HFTk3NUf2P7Y",
"colab_type": "code",
"outputId": "0f390be0-2f53-4de3-ab68-aaf25397f184",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"X_train.shape[0]"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"596"
]
},
"metadata": {
"tags": []
},
"execution_count": 24
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "W2slve6m2TIH",
"colab_type": "code",
"outputId": "e02beba6-2a8f-482e-f9da-bca4704edd6c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"X_test.shape[0]"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"295"
]
},
"metadata": {
"tags": []
},
"execution_count": 25
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AMAIvF7rrsSB",
"colab_type": "code",
"outputId": "4ffd7cc5-e514-4d05-e5e5-530ecd1d3f4a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 120
}
},
"source": [
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Fix random_state: features are permuted at every split, so ties between equally good\n",
"# splits are otherwise broken differently on each run and the tree/accuracy can change.\n",
"model_tree = DecisionTreeClassifier(max_depth=3, random_state=42)\n",
"model_tree.fit(X_train, y_train)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,\n",
" max_features=None, max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, presort=False,\n",
" random_state=42, splitter='best')"
]
},
"metadata": {
"tags": []
},
"execution_count": 41
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "EyzCkeYirsSI",
"colab_type": "code",
"colab": {}
},
"source": [
"# export the fitted tree in Graphviz .dot format for visualisation\n",
"from sklearn.tree import export_graphviz\n",
"\n",
"# take feature names from the training frame so they always match the model's columns\n",
"export_graphviz(model_tree, out_file='tree.dot', filled=True,\n",
"                feature_names=list(X_train.columns),\n",
"                class_names=['Died', 'Survived'])  # labels for classes 0 and 1, in ascending order\n",
"# rendering to PNG uses the Graphviz `dot` command-line tool (not the pydot library)\n",
"!dot -Tpng 'tree.dot' -o 'tree.png'"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "UFHSdD453Mjw",
"colab_type": "code",
"colab": {}
},
"source": [
"# Colab-only: send the rendered tree image to the browser as a download\n",
"files.download('tree.png')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dLekdbKK28pG",
"colab_type": "code",
"outputId": "0ff8088f-9d78-4701-8d0e-4afd05c44359",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# list the working directory to confirm tree.dot and tree.png were written\n",
"!ls"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"sample_data train.csv\ttree.dot tree.png\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "U2ohXJbbrsSX",
"colab_type": "code",
"colab": {}
},
"source": [
"# predicted survival (0/1) for the held-out passengers\n",
"y_predict = model_tree.predict(X=X_test)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jrlHcuMOrsSa",
"colab_type": "code",
"outputId": "62c195aa-e878-4b87-e8e2-da9a3744caa8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
"# fraction of test passengers whose survival was predicted correctly\n",
"print(accuracy_score(y_true=y_test, y_pred=y_predict))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"0.8203389830508474\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment