Created March 30, 2020
Sklearn and pandas example
"cells": [
"cell_type": "markdown",
"source": [
### Обработка категориальных признаков
"cell_type": "code",
"source": [
"from IPython.display import Image\n",
"Image(url= \"\")"
"cell_type": "code",
"source": [
"from google.colab import files\n",
"cell_type": "code",
"source": [
"import pandas as pd\n",
"# Titanic passenger table (training split), read from the working directory\n",
"TRAIN_PATH = \"train.csv\"\n",
"data = pd.read_csv(TRAIN_PATH)"
"cell_type": "code",
"source": [
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Braund, Mr. Owen Harris</td>\n",
" <td>male</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>A/5 21171</td>\n",
" <td>7.2500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
" <td>female</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>PC 17599</td>\n",
" <td>71.2833</td>\n",
" <td>C85</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Heikkinen, Miss. Laina</td>\n",
" <td>female</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>STON/O2. 3101282</td>\n",
" <td>7.9250</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
" <td>female</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>113803</td>\n",
" <td>53.1000</td>\n",
" <td>C123</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Allen, Mr. William Henry</td>\n",
" <td>male</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>373450</td>\n",
" <td>8.0500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>6</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Moran, Mr. James</td>\n",
" <td>male</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>330877</td>\n",
" <td>8.4583</td>\n",
" <td>NaN</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>7</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>McCarthy, Mr. Timothy J</td>\n",
" <td>male</td>\n",
" <td>54.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>17463</td>\n",
" <td>51.8625</td>\n",
" <td>E46</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>8</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Palsson, Master. Gosta Leonard</td>\n",
" <td>male</td>\n",
" <td>2.0</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>349909</td>\n",
" <td>21.0750</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)</td>\n",
" <td>female</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>347742</td>\n",
" <td>11.1333</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>10</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Nasser, Mrs. Nicholas (Adele Achem)</td>\n",
" <td>female</td>\n",
" <td>14.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>237736</td>\n",
" <td>30.0708</td>\n",
" <td>NaN</td>\n",
" <td>C</td>\n",
" </tr>\n",
" </tbody>\n",
"text/plain": [
" PassengerId Survived Pclass ... Fare Cabin Embarked\n",
"0 1 0 3 ... 7.2500 NaN S\n",
"1 2 1 1 ... 71.2833 C85 C\n",
"2 3 1 3 ... 7.9250 NaN S\n",
"3 4 1 1 ... 53.1000 C123 S\n",
"4 5 0 3 ... 8.0500 NaN S\n",
"5 6 0 3 ... 8.4583 NaN Q\n",
"6 7 0 1 ... 51.8625 E46 S\n",
"7 8 0 3 ... 21.0750 NaN S\n",
"8 9 1 3 ... 11.1333 NaN S\n",
"9 10 1 2 ... 30.0708 NaN C\n",
"[10 rows x 12 columns]"
"cell_type": "code",
"source": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 891 entries, 0 to 890\n",
"Data columns (total 12 columns):\n",
"PassengerId 891 non-null int64\n",
"Survived 891 non-null int64\n",
"Pclass 891 non-null int64\n",
"Name 891 non-null object\n",
"Sex 891 non-null object\n",
"Age 714 non-null float64\n",
"SibSp 891 non-null int64\n",
"Parch 891 non-null int64\n",
"Ticket 891 non-null object\n",
"Fare 891 non-null float64\n",
"Cabin 204 non-null object\n",
"Embarked 889 non-null object\n",
"dtypes: float64(2), int64(5), object(5)\n",
"memory usage: 83.7+ KB\n"
"cell_type": "code",
"source": [
"PassengerId 0\n",
"Survived 0\n",
"Pclass 0\n",
"Name 0\n",
"Sex 0\n",
"Age 177\n",
"SibSp 0\n",
"Parch 0\n",
"Ticket 0\n",
"Fare 0\n",
"Cabin 687\n",
"Embarked 2\n",
"dtype: int64"
"cell_type": "code",
"source": [
"# drop free-text / ID columns and the mostly-empty Cabin column; none are used as features\n",
"data = data.drop(columns=['Name', 'Ticket', 'Cabin', 'PassengerId'])"
"cell_type": "markdown",
"source": [
"### Перекодирование категориальных признаков"
"cell_type": "code",
"source": [
"male 577\n",
"female 314\n",
"Name: Sex, dtype: int64"
"cell_type": "code",
"source": [
"# encode Sex numerically: male -> 0, female -> 1\n",
"# note: not re-runnable - a second run maps 0/1 (not keys of the dict) to NaN\n",
"sex_mapping = dict(male=0, female=1)\n",
"data['Sex'] = data['Sex'].map(sex_mapping)"
"cell_type": "code",
"source": [
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>7.2500</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>71.2833</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>7.9250</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>53.1000</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.0500</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.4583</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>54.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>51.8625</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>2.0</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>21.0750</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>11.1333</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>14.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>30.0708</td>\n",
" <td>C</td>\n",
" </tr>\n",
" </tbody>\n",
"text/plain": [
" Survived Pclass Sex Age SibSp Parch Fare Embarked\n",
"0 0 3 0 22.0 1 0 7.2500 S\n",
"1 1 1 1 38.0 1 0 71.2833 C\n",
"2 1 3 1 26.0 0 0 7.9250 S\n",
"3 1 1 1 35.0 1 0 53.1000 S\n",
"4 0 3 0 35.0 0 0 8.0500 S\n",
"5 0 3 0 NaN 0 0 8.4583 Q\n",
"6 0 1 0 54.0 0 0 51.8625 S\n",
"7 0 3 0 2.0 3 1 21.0750 S\n",
"8 1 3 1 27.0 0 2 11.1333 S\n",
"9 1 2 1 14.0 1 0 30.0708 C"
"cell_type": "code",
"source": [
"S 644\n",
"C 168\n",
"Q 77\n",
"Name: Embarked, dtype: int64"
"cell_type": "code",
"source": [
"# one-hot encode the port of embarkation into port_C / port_Q / port_S;\n",
"# rows with a missing Embarked value get all-zero dummies (dummy_na=False).\n",
"# dtype=int keeps 0/1 columns on pandas >= 2.0, which otherwise returns bool.\n",
"Embarked_dummies = pd.get_dummies(data[\"Embarked\"], prefix=\"port\", dummy_na=False, dtype=int)"
"cell_type": "code",
"source": [
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"text/plain": [
" port_C port_Q port_S\n",
"0 0 0 1\n",
"1 1 0 0\n",
"2 0 0 1\n",
"3 0 0 1\n",
"4 0 0 1"
"cell_type": "code",
"source": [
"# append the port_* dummy columns alongside the existing features (same row index)\n",
"data = pd.concat((data, Embarked_dummies), axis=\"columns\")"
"cell_type": "code",
"source": [
"# the raw Embarked column is now represented by the port_* dummies\n",
"data = data.drop(columns='Embarked')"
"cell_type": "code",
"source": [
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>7.2500</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>71.2833</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>7.9250</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>53.1000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.0500</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"text/plain": [
" Survived Pclass Sex Age SibSp Parch Fare port_C port_Q port_S\n",
"0 0 3 0 22.0 1 0 7.2500 0 0 1\n",
"1 1 1 1 38.0 1 0 71.2833 1 0 0\n",
"2 1 3 1 26.0 0 0 7.9250 0 0 1\n",
"3 1 1 1 35.0 1 0 53.1000 0 0 1\n",
"4 0 3 0 35.0 0 0 8.0500 0 0 1"
"cell_type": "code",
"source": [
"# target: the Survived label; features: every other column\n",
"y = data['Survived']\n",
"X = data.drop(columns=['Survived'])"
"cell_type": "code",
"source": [
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>7.2500</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>71.2833</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>7.9250</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>53.1000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.0500</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.4583</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>54.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>51.8625</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>2.0</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>21.0750</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>11.1333</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>14.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>30.0708</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"text/plain": [
" Pclass Sex Age SibSp Parch Fare port_C port_Q port_S\n",
"0 3 0 22.0 1 0 7.2500 0 0 1\n",
"1 1 1 38.0 1 0 71.2833 1 0 0\n",
"2 3 1 26.0 0 0 7.9250 0 0 1\n",
"3 1 1 35.0 1 0 53.1000 0 0 1\n",
"4 3 0 35.0 0 0 8.0500 0 0 1\n",
"5 3 0 NaN 0 0 8.4583 0 1 0\n",
"6 1 0 54.0 0 0 51.8625 0 0 1\n",
"7 3 0 2.0 3 1 21.0750 0 0 1\n",
"8 3 1 27.0 0 2 11.1333 0 0 1\n",
"9 2 1 14.0 1 0 30.0708 1 0 0"
"cell_type": "markdown",
"source": [
"### Построение деревьев решений"
"cell_type": "code",
"source": [
"Pclass 0\n",
"Sex 0\n",
"Age 0\n",
"SibSp 0\n",
"Parch 0\n",
"Fare 0\n",
"port_C 0\n",
"port_Q 0\n",
"port_S 0\n",
"dtype: int64"
"cell_type": "code",
"source": [
"# fill missing ages with the mean age\n",
"# NOTE(review): the mean is taken over all rows before train_test_split, so test-set ages\n",
"# leak into the imputation; consider computing it on X_train only\n",
"mean_age = X['Age'].mean()\n",
"X['Age'] = X['Age'].fillna(mean_age)"
"cell_type": "code",
"source": [
"from sklearn.model_selection import train_test_split\n",
"# hold out a third of the rows for evaluation; the fixed seed makes the split reproducible\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y, test_size=0.33, random_state=42\n",
")"
"cell_type": "code",
"source": [
"text/plain": [
"execution_count": 24
"source": [
"text/plain": [
"execution_count": 25
"source": [
"from sklearn.tree import DecisionTreeClassifier\n",
"# shallow tree (depth 3) so it stays readable when visualised;\n",
"# random_state pins the random feature permutation at each split for reproducible results\n",
"model_tree = DecisionTreeClassifier(max_depth=3, random_state=42)\n",
"model_tree.fit(X_train, y_train)"
"text/plain": [
"DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,\n",
" max_features=None, max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, presort=False,\n",
" random_state=None, splitter='best')"
"cell_type": "code",
"source": [
"# export the fitted tree in Graphviz .dot format so it can be rendered as an image\n",
"from sklearn.tree import export_graphviz\n",
"export_graphviz(model_tree,\n",
"                feature_names=list(X_train.columns),  # exactly the columns the tree was fitted on\n",
"                out_file='tree.dot', filled=True)\n",
"# rendering uses the Graphviz `dot` command-line tool (not the pydot package)\n",
"!dot -Tpng 'tree.dot' -o 'tree.png'"
"cell_type": "code",
"source": [
"cell_type": "code",
"source": [
"sample_data train.csv\ tree.png\n"
"cell_type": "code",
"source": [
"# predicted survival labels for the held-out passengers\n",
"y_predict = model_tree.predict(X=X_test)"
"cell_type": "code",
"source": [
"from sklearn.metrics import accuracy_score\n",
"# share of test passengers whose survival was predicted correctly;\n",
"# left as the last expression so Jupyter displays the value\n",
"accuracy_score(y_test, y_predict)"
