@XinyueZ
Created August 22, 2022 11:34
notebook.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/XinyueZ/1dc6dbf6f6978c315a1073700d11bf9d/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3OGDXVVxU-qw"
},
"source": [
"Comparing the performance of Decision Trees, Bagging, Random Forest and Neural Networks\n",
"==\n",
"\n",
"All models are built on two different datasets for classification: Dataset 1 has five classes and Dataset 2 has two. Machine learning experiments are partly randomized, so the results should be treated as reference values only. Both datasets consist mainly of \"category\" (categorical) features.\n",
"\n",
"Models:\n",
"\n",
"- Decision tree (`fit_DT()`) \n",
"- Bagging (`fit_bagging()`) \n",
"- Random Forest (`fit_RF()`) \n",
"- Neural network with 3 hidden layers (`fit_nn()`) \n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f5eexd70VVoz"
},
"source": [
"> Notebook: https://colab.research.google.com/drive/1yA5j8AHcLfy_Vr3pGSlGVwj07fnF5hXW?usp=sharing\n",
">\n",
"> Contact: chris.at.de@gmail.com"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dRfdYorMVlUU"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from sklearn import metrics\n",
"from sklearn import preprocessing\n",
"from sklearn.ensemble import BaggingClassifier\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow.keras as keras \n",
"\n",
"\n",
"\n",
"np.random.seed(1024)\n",
"tf.random.set_seed(1024)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qeXgvw_62MtQ"
},
"source": [
"## Metric helpers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o9Xllml4W5Wd"
},
"outputs": [],
"source": [
"def get_accuracy_nn(X_train, X_test, y_train, y_test, model):\n",
"    # The network outputs one softmax probability per class; argmax picks the class index.\n",
"    return pd.Series({\n",
"        \"test Accuracy\": round(metrics.accuracy_score(y_test, np.argmax(model.predict(X_test), axis=1)), 3),\n",
"        \"train Accuracy\": round(metrics.accuracy_score(y_train, np.argmax(model.predict(X_train), axis=1)), 3)}, name=\"NN\")\n",
"\n",
"def get_accuracy(X_train, X_test, y_train, y_test, model, name):\n",
"    # scikit-learn classifiers return class labels directly from predict().\n",
"    return pd.Series({\n",
"        \"test Accuracy\": round(metrics.accuracy_score(y_test, model.predict(X_test)), 3),\n",
"        \"train Accuracy\": round(metrics.accuracy_score(y_train, model.predict(X_train)), 3)}, name=name)\n",
"\n",
"def print_plot_result(*res):\n",
"    # One column per model, one row per metric; display() is provided by the notebook environment.\n",
"    res_df = pd.concat(list(res), axis=1)\n",
"    display(res_df)\n",
"    res_df.plot.bar()"
]
},
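{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two helpers above differ only in how predictions become class labels. As a rough plain-Python illustration (the probabilities and labels are made-up values, and `argmax_rows` and `accuracy` are hypothetical names, not part of this notebook), argmax picks the most probable class per row and accuracy is the fraction of matching labels:\n",
"\n",
"```python\n",
"def argmax_rows(probabilities):\n",
"    # Index of the largest probability in each row (the role np.argmax(..., axis=1) plays above).\n",
"    return [max(range(len(row)), key=row.__getitem__) for row in probabilities]\n",
"\n",
"def accuracy(y_true, y_pred):\n",
"    # Fraction of predictions that equal the true label.\n",
"    matches = sum(1 for t, p in zip(y_true, y_pred) if t == p)\n",
"    return matches / len(y_true)\n",
"\n",
"# Hypothetical softmax outputs for 4 samples and 3 classes.\n",
"probs = [[0.7, 0.2, 0.1],\n",
"         [0.1, 0.8, 0.1],\n",
"         [0.3, 0.3, 0.4],\n",
"         [0.6, 0.3, 0.1]]\n",
"y_true = [0, 1, 1, 0]\n",
"\n",
"y_pred = argmax_rows(probs)\n",
"print(y_pred, round(accuracy(y_true, y_pred), 3))\n",
"```"
]
},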
{
"cell_type": "markdown",
"metadata": {
"id": "UaxUAyJJy3pO"
},
"source": [
"## Models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "db-Q0lr1y6RT"
},
"source": [
"### DT\n",
"\n",
"Decision trees split the data using impurity measures. They are greedy algorithms and do not rely on statistical assumptions about the data. The most common impurity measures are entropy and the Gini index. Decision trees tend to overfit and are very sensitive to changes in the training data, so they often perform poorly on new data.\n",
"Cross-validation and pruning can mitigate some of this.\n",
"\n",
"The main advantages of decision trees are that they:\n",
"- are easy to interpret \n",
"- require little data preprocessing (no feature scaling is needed, although scikit-learn trees still expect categorical features to be encoded as numbers)"
]
},
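{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the two impurity measures concrete, here is a minimal plain-Python sketch. The helper names `gini` and `entropy` and the toy label lists are illustrative assumptions; scikit-learn computes these internally when `criterion` is set.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"from math import log2\n",
"\n",
"def gini(labels):\n",
"    # Gini index: 1 - sum(p_k^2) over the class proportions p_k.\n",
"    n = len(labels)\n",
"    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())\n",
"\n",
"def entropy(labels):\n",
"    # Entropy in bits: -sum(p_k * log2(p_k)).\n",
"    n = len(labels)\n",
"    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())\n",
"\n",
"pure = ['drugY'] * 4                           # a node holding a single class\n",
"mixed = ['drugY', 'drugY', 'drugX', 'drugX']   # a node split evenly between two classes\n",
"print(gini(pure), entropy(pure))\n",
"print(gini(mixed), entropy(mixed))\n",
"```\n",
"\n",
"A pure node scores zero on both measures, while an evenly mixed node scores the maximum for two classes; a split is preferred when it lowers the weighted impurity of the child nodes."
]
},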
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RoSVeyPzzGNb"
},
"outputs": [],
"source": [
"def fit_DT(X_train, X_test, y_train, y_test):\n",
"    # Fit an unconstrained tree first; its depth and feature count bound the search grid.\n",
"    tree_classifier = DecisionTreeClassifier().fit(X_train, y_train)\n",
"\n",
"    param_grid = {'criterion': [\"entropy\", \"gini\"],\n",
"                  'max_depth': range(1, tree_classifier.tree_.max_depth+1, 2),\n",
"                  'max_features': range(1, len(tree_classifier.feature_importances_)+1)}\n",
"\n",
"    # GridSearchCV uses 5-fold cross-validation by default (no cv argument is passed here).\n",
"    search = GridSearchCV(DecisionTreeClassifier(random_state=42),\n",
"                          param_grid=param_grid,\n",
"                          scoring='accuracy',\n",
"                          n_jobs=-1)\n",
"    search.fit(X_train, y_train)\n",
"    # Return the tuned tree's scores plus the unconstrained tree, which is later passed to fit_bagging() as the base tree.\n",
"    return get_accuracy(X_train, X_test, y_train, y_test, search.best_estimator_, \"DT\"), tree_classifier"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EqsirX6ky8l9"
},
"source": [
"### Bagging\n",
"\n",
"A model that **averages the predictions of multiple models** (e.g. decision trees) has lower variance than a single model and is more likely to generalize well to new data. Bagging (bootstrap aggregating) is an ensemble that combines the predictions of several trees, each trained on a bootstrap sample of the data.\n",
"\n",
"A bootstrap sample can be drawn with the function `resample`: the sample has the same size as the original dataset, but because rows are drawn with replacement, some rows are repeated.\n",
"\n",
"```\n",
"from sklearn.utils import resample\n",
"\n",
"df  # some dataframe\n",
"\n",
"for n in range(5):\n",
"    sample = resample(df[0:5])\n",
"    print(sample)\n",
"    print(\"len \", len(sample))\n",
"```"
]
},
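{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two ingredients of bagging can be sketched in plain Python. This is a simplification under stated assumptions: `bootstrap_sample` and `majority_vote` are hypothetical helper names, the predictions are made up, and scikit-learn's `BaggingClassifier` may average class probabilities rather than count hard votes.\n",
"\n",
"```python\n",
"import random\n",
"from collections import Counter\n",
"\n",
"def bootstrap_sample(rows, rng):\n",
"    # Draw len(rows) rows with replacement, so some rows repeat and others are left out.\n",
"    return [rng.choice(rows) for _ in rows]\n",
"\n",
"def majority_vote(predictions):\n",
"    # Most common label among the per-model predictions for one sample.\n",
"    return Counter(predictions).most_common(1)[0][0]\n",
"\n",
"rng = random.Random(0)\n",
"rows = list(range(10))\n",
"sample = bootstrap_sample(rows, rng)\n",
"# Same size as the original, but usually fewer distinct rows.\n",
"print(len(sample), len(set(sample)))\n",
"\n",
"# Hypothetical predictions of three trees for one patient.\n",
"print(majority_vote(['drugX', 'drugY', 'drugX']))\n",
"```"
]
},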
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4FH3gKnyzJhk"
},
"outputs": [],
"source": [
"def fit_bagging(base_tree, X_train, X_test, y_train, y_test):\n",
"    # Odd ensemble sizes from 1 to 39.\n",
"    param_grid = {'n_estimators': [2*n+1 for n in range(20)]}\n",
"    # The base model is passed as `estimator` in scikit-learn >= 1.2; older releases name this parameter `base_estimator`.\n",
"    bagging = BaggingClassifier(estimator=base_tree, random_state=0, bootstrap=True)\n",
"    search = GridSearchCV(estimator=bagging, param_grid=param_grid, scoring=\"accuracy\", cv=3)\n",
"    search.fit(X_train, y_train)\n",
"    return get_accuracy(X_train, X_test, y_train, y_test, search.best_estimator_, name=\"Bagging\")"
]
},
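{
"cell_type": "markdown",
"metadata": {},
"source": [
"A grid search fits every parameter combination once per cross-validation fold, so it helps to estimate the cost up front. The sketch below counts the fits for the `n_estimators` grid used in `fit_bagging`; `count_fits` is a hypothetical helper, and the count leaves out the final refit of the best model on the whole training set (`refit=True` is the `GridSearchCV` default).\n",
"\n",
"```python\n",
"from itertools import product\n",
"\n",
"def count_fits(param_grid, cv):\n",
"    # Every parameter combination is fitted once per fold.\n",
"    combos = list(product(*param_grid.values()))\n",
"    return len(combos), len(combos) * cv\n",
"\n",
"bagging_grid = {'n_estimators': [2 * n + 1 for n in range(20)]}\n",
"\n",
"print(bagging_grid['n_estimators'][:3], bagging_grid['n_estimators'][-1])\n",
"print(count_fits(bagging_grid, cv=3))\n",
"```\n",
"\n",
"Adding a second or third parameter multiplies the number of combinations, which is why larger grids take noticeably longer to search."
]
},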
{
"cell_type": "markdown",
"metadata": {
"id": "aGjips_Zy_4R"
},
"source": [
"### RF\n",
"\n",
"A random forest is essentially bagging with one extra source of randomness: in addition to bootstrapping the rows and aggregating the trees, the subset of features (columns) considered is also random. In general, a random forest can be considered a special case of bagging, and it tends to have better out-of-sample accuracy.\n",
"\n",
"Random forests are **a combination of trees** such that **each tree depends on a random subset of the features and data**. As a result, each tree in the forest is different, and the forest usually performs better than Bagging. The most important parameters are **the number of trees** and **the number of features** to sample.\n",
"As with Bagging, increasing the number of trees improves results and in most cases does not lead to overfitting, but the improvements plateau as more trees are added. In scikit-learn, the default number of trees in the forest is 100.\n",
"\n",
"\n",
"Like Bagging, RF trains each tree on an independent bootstrap sample of the training data. In addition, 𝑚 variables are selected at random out of all 𝑀 possible variables for each node to split on.\n",
"\n",
"The example below illustrates the idea by drawing 𝑚 random columns from each bootstrap sample.\n",
"```\n",
"import random\n",
"from sklearn.utils import resample\n",
"\n",
"X = df  # some dataframe\n",
"M = X.shape[1]\n",
"feature_index = range(M)\n",
"m = max(1, int(M ** 0.5))  # assumption: m is not defined in the original; sqrt(M) is a common choice\n",
"\n",
"for n in range(5):\n",
"    print(\"sample {}\".format(n))\n",
"    print(resample(X[0:5]).iloc[:, random.sample(feature_index, m)])\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "a4YU4xnZzNDz"
},
"outputs": [],
"source": [
"def fit_RF(X_train, X_test, y_train, y_test):\n",
"    # \"auto\" is omitted from max_features: for classifiers it was an alias of \"sqrt\", and newer scikit-learn releases no longer accept it.\n",
"    param_grid = {'n_estimators': [2*n+1 for n in range(20)],\n",
"                  'max_depth': [2*n+1 for n in range(10)],\n",
"                  'max_features': [\"sqrt\", \"log2\"]}\n",
"    RF = RandomForestClassifier()\n",
"    search = GridSearchCV(estimator=RF, param_grid=param_grid, scoring=\"accuracy\", cv=3)\n",
"    search.fit(X_train, y_train)\n",
"\n",
"    return get_accuracy(X_train, X_test, y_train, y_test, search.best_estimator_, name=\"RF\")"
]
},
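{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `max_features` options in the grid translate into a number 𝑚 of candidate features per split. The sketch below assumes scikit-learn rounds down and uses at least one feature; `features_per_split` is a hypothetical helper, not a library function.\n",
"\n",
"```python\n",
"from math import log2, sqrt\n",
"\n",
"def features_per_split(n_features, rule):\n",
"    # Number of candidate features examined at each split, at least 1.\n",
"    if rule == 'sqrt':\n",
"        return max(1, int(sqrt(n_features)))\n",
"    if rule == 'log2':\n",
"        return max(1, int(log2(n_features)))\n",
"    raise ValueError('unknown rule: ' + rule)\n",
"\n",
"for M in (5, 9):  # feature counts of Dataset 1 and Dataset 2\n",
"    print(M, features_per_split(M, 'sqrt'), features_per_split(M, 'log2'))\n",
"```\n",
"\n",
"Under this assumption both rules give the same 𝑚 for 5 features and for 9 features, so searching over `max_features` is unlikely to change much on these two datasets."
]
},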
{
"cell_type": "markdown",
"metadata": {
"id": "tWId8qRszCmT"
},
"source": [
"### NN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RVyJF6aHzTqk"
},
"outputs": [],
"source": [
"def fit_nn(X_train_, X_test_, y_train_, y_test_):\n",
"    input_shape = X_train_.shape\n",
"\n",
"    n = input_shape[1]\n",
"    # Labels are integer-encoded as 0..K-1, so K is derived from the arguments instead of the global `y`.\n",
"    n_classes = int(max(np.max(y_train_), np.max(y_test_))) + 1\n",
"\n",
"    inputs = keras.Input(shape=[n])\n",
"    x = keras.layers.Dense(units=n**2, activation=keras.activations.relu)(inputs)\n",
"    # ** is right-associative, so n**2**2 equals n**4.\n",
"    x = keras.layers.Dense(units=n**2**2, activation=keras.activations.relu)(x)\n",
"    x = keras.layers.Dense(units=n**2, activation=keras.activations.relu)(x)\n",
"    outputs = keras.layers.Dense(units=n_classes, activation=keras.activations.softmax)(x)\n",
"\n",
"    model = keras.models.Model(inputs=[inputs], outputs=[outputs])\n",
"\n",
"    # Sparse categorical cross-entropy expects integer class labels rather than one-hot vectors.\n",
"    model.compile(optimizer=\"adam\", loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"])\n",
"\n",
"    model.fit(epochs=50, x=X_train_, y=y_train_, shuffle=True, batch_size=8, verbose=0)\n",
"\n",
"    return get_accuracy_nn(X_train_, X_test_, y_train_, y_test_, model)"
]
},
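{
"cell_type": "markdown",
"metadata": {},
"source": [
"The hidden-layer widths in `fit_nn` grow quickly with the number of input features, because `n**2**2` evaluates as `n**(2**2)`. The sketch below recomputes the widths and the number of trainable parameters for both datasets; `dense_params` and `network_summary` are hypothetical helpers that only do arithmetic and do not build a Keras model.\n",
"\n",
"```python\n",
"def dense_params(n_inputs, n_units):\n",
"    # A Dense layer has one weight per input-unit pair plus one bias per unit.\n",
"    return (n_inputs + 1) * n_units\n",
"\n",
"def network_summary(n, n_classes):\n",
"    # Layer widths as written in fit_nn; ** is right-associative, so n**2**2 is n**(2**2).\n",
"    widths = [n**2, n**2**2, n**2, n_classes]\n",
"    total, previous = 0, n\n",
"    for width in widths:\n",
"        total += dense_params(previous, width)\n",
"        previous = width\n",
"    return widths, total\n",
"\n",
"print(network_summary(5, 5))  # Dataset 1: 5 features, 5 drugs\n",
"print(network_summary(9, 2))  # Dataset 2: 9 features, 2 classes\n",
"```\n",
"\n",
"With 9 features the middle layer has 6561 units and the network has over a million parameters, which is large relative to a few hundred training rows."
]
},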
{
"cell_type": "markdown",
"metadata": {
"id": "Ydl-2oueVQPN"
},
"source": [
"# Dataset 1\n",
"\n",
"Drug200.csv\n",
"\n",
"During a course of treatment, each patient responded to one of 5 medications: \n",
"\n",
"Drug A, Drug B, Drug C, Drug X and Drug Y.\n",
"\n",
"The goal is to build a model that predicts which drug might be appropriate for a future patient with the same illness.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "DotZSnAoU62-",
"outputId": "e157ae28-025d-4976-a8db-fdbd5110a97a"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div id=\"df-c87c2bcc-f346-43c6-b011-78cdf53aaddc\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Age</th>\n",
" <th>Sex</th>\n",
" <th>BP</th>\n",
" <th>Cholesterol</th>\n",
" <th>Na_to_K</th>\n",
" <th>Drug</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>23</td>\n",
" <td>F</td>\n",
" <td>HIGH</td>\n",
" <td>HIGH</td>\n",
" <td>25.355</td>\n",
" <td>drugY</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>47</td>\n",
" <td>M</td>\n",
" <td>LOW</td>\n",
" <td>HIGH</td>\n",
" <td>13.093</td>\n",
" <td>drugC</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>47</td>\n",
" <td>M</td>\n",
" <td>LOW</td>\n",
" <td>HIGH</td>\n",
" <td>10.114</td>\n",
" <td>drugC</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>28</td>\n",
" <td>F</td>\n",
" <td>NORMAL</td>\n",
" <td>HIGH</td>\n",
" <td>7.798</td>\n",
" <td>drugX</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>61</td>\n",
" <td>F</td>\n",
" <td>LOW</td>\n",
" <td>HIGH</td>\n",
" <td>18.043</td>\n",
" <td>drugY</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-c87c2bcc-f346-43c6-b011-78cdf53aaddc')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-c87c2bcc-f346-43c6-b011-78cdf53aaddc button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-c87c2bcc-f346-43c6-b011-78cdf53aaddc');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
],
"text/plain": [
" Age Sex BP Cholesterol Na_to_K Drug\n",
"0 23 F HIGH HIGH 25.355 drugY\n",
"1 47 M LOW HIGH 13.093 drugC\n",
"2 47 M LOW HIGH 10.114 drugC\n",
"3 28 F NORMAL HIGH 7.798 drugX\n",
"4 61 F LOW HIGH 18.043 drugY"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(\"https://dl.dropbox.com/s/hglnf8sdlhcyinm/drug200.csv\", delimiter=\",\")\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3_rV6v6qVym7",
"outputId": "161fa0cb-1b5a-492b-d89a-5d1984256088"
},
"outputs": [
{
"data": {
"text/plain": [
"(array([[23, 'F', 'HIGH', 'HIGH', 25.355],\n",
" [47, 'M', 'LOW', 'HIGH', 13.093],\n",
" [47, 'M', 'LOW', 'HIGH', 10.114],\n",
" [28, 'F', 'NORMAL', 'HIGH', 7.798]], dtype=object), 0 drugY\n",
" 1 drugC\n",
" 2 drugC\n",
" 3 drugX\n",
" Name: Drug, dtype: object)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = df[['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K']].values\n",
"y = df[\"Drug\"]\n",
"X[:4], y[:4]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wrhyV99wWQ5p"
},
"source": [
"## Preprocessing and train, test split"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uPpi0d9HV26a",
"outputId": "a0cfb5ab-a60c-4d55-dbde-39c47d989e49"
},
"outputs": [
{
"data": {
"text/plain": [
"array([[23, 0, 0, 0, 25.355],\n",
" [47, 1, 1, 0, 13.093],\n",
" [47, 1, 1, 0, 10.114],\n",
" [28, 0, 2, 0, 7.798]], dtype=object)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"le_sex = preprocessing.LabelEncoder()\n",
"le_sex.fit(['F','M'])\n",
"X[:,1] = le_sex.transform(X[:,1]) \n",
"\n",
"\n",
"le_BP = preprocessing.LabelEncoder()\n",
"le_BP.fit([ 'LOW', 'NORMAL', 'HIGH'])\n",
"X[:,2] = le_BP.transform(X[:,2])\n",
"\n",
"\n",
"le_Chol = preprocessing.LabelEncoder()\n",
"le_Chol.fit([ 'NORMAL', 'HIGH'])\n",
"X[:,3] = le_Chol.transform(X[:,3]) \n",
"\n",
"X[0:4]"
]
},
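{
"cell_type": "markdown",
"metadata": {},
"source": [
"`LabelEncoder` numbers the classes in sorted order, which is why `HIGH` becomes 0 and `NORMAL` becomes 2 in the output above. The plain-Python sketch below contrasts that with an explicit ordinal mapping; `label_encode` and `ordinal_encode` are hypothetical helpers written for illustration. In scikit-learn, `OrdinalEncoder` with an explicit `categories` list is one way to keep the natural order, should it matter for a model.\n",
"\n",
"```python\n",
"def label_encode(classes):\n",
"    # LabelEncoder-style mapping: classes are sorted, then numbered from 0.\n",
"    return {label: code for code, label in enumerate(sorted(classes))}\n",
"\n",
"def ordinal_encode(ordered_classes):\n",
"    # Explicit mapping that keeps the order given by the caller.\n",
"    return {label: code for code, label in enumerate(ordered_classes)}\n",
"\n",
"bp_levels = ['LOW', 'NORMAL', 'HIGH']\n",
"print(label_encode(bp_levels))\n",
"print(ordinal_encode(bp_levels))\n",
"```"
]
},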
{
"cell_type": "markdown",
"metadata": {
"id": "a9dP6WABX-kb"
},
"source": [
"### Dataset for RF"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Kfos2vlvWebO",
"outputId": "9ae9d114-c2f8-4f35-c6dc-635ecc9250ea"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train set: (160, 5) (160,)\n",
"Test set: (40, 5) (40,)\n"
]
}
],
"source": [
"X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=4)\n",
"print ('Train set:', X_train.shape, y_train.shape)\n",
"print ('Test set:', X_test.shape, y_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8fT0-RmYYCTY"
},
"source": [
"### Dataset for NN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5uk46vkhX8Uc"
},
"outputs": [],
"source": [
"transformer = preprocessing.StandardScaler()\n",
"X_train_ = transformer.fit_transform(X_train)\n",
"X_test_ = transformer.transform(X_test)\n",
"y_train_ = preprocessing.LabelEncoder().fit(y.unique().tolist()).transform(y_train) \n",
"y_test_ = preprocessing.LabelEncoder().fit(y.unique().tolist()).transform(y_test)"
]
},
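{
"cell_type": "markdown",
"metadata": {},
"source": [
"The scaler is fitted on the training data only and then reused for the test data, so no test-set statistics leak into training. A minimal plain-Python sketch of that pattern follows; `fit_scaler` and `transform` are hypothetical helpers, the test ages are made up, and the population standard deviation is used because that is what `StandardScaler` applies.\n",
"\n",
"```python\n",
"from statistics import mean, pstdev\n",
"\n",
"def fit_scaler(column):\n",
"    # Learn mean and population standard deviation from the training column only.\n",
"    return mean(column), pstdev(column)\n",
"\n",
"def transform(column, mu, sigma):\n",
"    # Apply the training statistics to any column (train or test).\n",
"    return [(value - mu) / sigma for value in column]\n",
"\n",
"train_age = [23, 47, 47, 28, 61]   # first five Age values of Dataset 1\n",
"test_age = [35, 50]                # hypothetical unseen ages\n",
"\n",
"mu, sigma = fit_scaler(train_age)\n",
"print(transform(train_age, mu, sigma))\n",
"print(transform(test_age, mu, sigma))\n",
"```\n",
"\n",
"After the transform the training column has mean 0 and standard deviation 1, while the test column generally does not, which is expected."
]
},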
{
"cell_type": "markdown",
"metadata": {
"id": "gDbb6XuItr5O"
},
"source": [
"## Fit DT"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bNw07FQ_tyEO"
},
"outputs": [],
"source": [
"DT_1, tree_1 = fit_DT(X_train, X_test, y_train, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X0U_LJtHsUze"
},
"source": [
"## Fit Bagging"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ydmRn8PFszhU"
},
"outputs": [],
"source": [
"Bagging_1 = fit_bagging(tree_1, X_train, X_test, y_train, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TSk1UTGKWtVV"
},
"source": [
"## Fit RF"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ii2deqsuWvU_"
},
"outputs": [],
"source": [
"RF_1 = fit_RF(X_train, X_test, y_train, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C0x600POX1lI"
},
"source": [
"## Fit NN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L-mYaDrjYYd7"
},
"outputs": [],
"source": [
"NN_1 = fit_nn(X_train_, X_test_, y_train_, y_test_)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o-2-DBz1ZEXr"
},
"source": [
"## Result\n",
"\n",
"On the test set, Bagging is the best-performing tree-based model (0.95 accuracy), matching the NN (0.95) and slightly ahead of RF (0.925). The single decision tree trails at 0.825."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"id": "THE32m-wZFoe",
"outputId": "19ee55d8-c11b-4ae3-9897-ddefdbbfa3df"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div id=\"df-ad1a0553-293f-446e-8da3-5f2db746d531\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>DT</th>\n",
" <th>Bagging</th>\n",
" <th>RF</th>\n",
" <th>NN</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>test Accuracy</th>\n",
" <td>0.825</td>\n",
" <td>0.95</td>\n",
" <td>0.925</td>\n",
" <td>0.95</td>\n",
" </tr>\n",
" <tr>\n",
" <th>train Accuracy</th>\n",
" <td>0.931</td>\n",
" <td>1.00</td>\n",
" <td>1.000</td>\n",
" <td>1.00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-ad1a0553-293f-446e-8da3-5f2db746d531')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-ad1a0553-293f-446e-8da3-5f2db746d531 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-ad1a0553-293f-446e-8da3-5f2db746d531');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
],
"text/plain": [
" DT Bagging RF NN\n",
"test Accuracy 0.825 0.95 0.925 0.95\n",
"train Accuracy 0.931 1.00 1.000 1.00"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"print_plot_result(DT_1, Bagging_1, RF_1, NN_1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fjfcriC3aLY2"
},
"source": [
"# Dataset 2\n",
"\n",
"From the UCI Machine Learning Repository (Asuncion and Newman, 2007)\\[[http://mlearn.ics.uci.edu/MLRepository.html](http://mlearn.ics.uci.edu/MLRepository.html?utm_medium=Exinfluencer&utm_source=Exinfluencer&utm_content=000026UJ&utm_term=10006555&utm_id=NA-SkillsNetwork-Channel-SkillsNetworkCoursesIBMML241ENSkillsNetwork31576874-2022-01-01)]. \n",
"\n",
"The dataset consists of several hundred human cell sample records, each of which contains the values of a set of cell characteristics. \n",
"\n",
"| Field name | Description |\n",
"| ----------- | --------------------------- |\n",
"| ID | Sample identifier |\n",
"| Clump | Clump thickness |\n",
"| UnifSize | Uniformity of cell size |\n",
"| UnifShape | Uniformity of cell shape |\n",
"| MargAdh | Marginal adhesion |\n",
"| SingEpiSize | Single epithelial cell size |\n",
"| BareNuc | Bare nuclei |\n",
"| BlandChrom | Bland chromatin |\n",
"| NormNucl | Normal nucleoli |\n",
"| Mit | Mitoses |\n",
"| Class | Benign or malignant |\n",
"\n",
"<br>\n",
"<br> "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "deOYAZ4lakNi",
"outputId": "b1ae070d-60c7-49ed-bcd6-804c545cadb5"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div id=\"df-f512fbec-adb2-4c78-ae09-94d37b357599\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ID</th>\n",
" <th>Clump</th>\n",
" <th>UnifSize</th>\n",
" <th>UnifShape</th>\n",
" <th>MargAdh</th>\n",
" <th>SingEpiSize</th>\n",
" <th>BareNuc</th>\n",
" <th>BlandChrom</th>\n",
" <th>NormNucl</th>\n",
" <th>Mit</th>\n",
" <th>Class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1000025</td>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1002945</td>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>5</td>\n",
" <td>7</td>\n",
" <td>10</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1015425</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1016277</td>\n",
" <td>6</td>\n",
" <td>8</td>\n",
" <td>8</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>7</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1017023</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-f512fbec-adb2-4c78-ae09-94d37b357599')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-f512fbec-adb2-4c78-ae09-94d37b357599 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-f512fbec-adb2-4c78-ae09-94d37b357599');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
],
"text/plain": [
" ID Clump UnifSize UnifShape MargAdh SingEpiSize BareNuc \\\n",
"0 1000025 5 1 1 1 2 1 \n",
"1 1002945 5 4 4 5 7 10 \n",
"2 1015425 3 1 1 1 2 2 \n",
"3 1016277 6 8 8 1 3 4 \n",
"4 1017023 4 1 1 3 2 1 \n",
"\n",
" BlandChrom NormNucl Mit Class \n",
"0 3 1 1 2 \n",
"1 3 2 1 2 \n",
"2 3 1 1 2 \n",
"3 3 7 1 2 \n",
"4 3 1 1 2 "
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(\"https://dl.dropbox.com/s/s4j1wkgz9xl7f7h/cell_samples.csv\")\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mpis21VYaxkm"
},
"outputs": [],
"source": [
"df = df[pd.to_numeric(df['BareNuc'], errors='coerce').notnull()] #remove rows that have a ? in the BareNuc column\n",
"\n",
"X = df[['Clump', 'UnifSize', 'UnifShape', 'MargAdh', 'SingEpiSize', 'BareNuc', 'BlandChrom', 'NormNucl', 'Mit']]\n",
"y = df['Class']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7E8v8_Fxa7Oj"
},
"source": [
"## Preprocessing and train, test split"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BtgHQzshbAnT"
},
"source": [
"### Dataset for RF"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "W0EOzoGUbCTy",
"outputId": "4691ed2c-1f22-4103-96d0-f621ac4d4bfc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train set: (546, 9) (546,)\n",
"Test set: (137, 9) (137,)\n"
]
}
],
"source": [
"X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=4)\n",
"print ('Train set:', X_train.shape, y_train.shape)\n",
"print ('Test set:', X_test.shape, y_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nzp2jajXbEMP"
},
"source": [
"### Dataset for NN\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1q6IlwhLbRPy"
},
"outputs": [],
"source": [
"transformer = preprocessing.StandardScaler()\n",
"X_train_ = transformer.fit_transform(X_train)\n",
"X_test_ = transformer.transform(X_test)\n",
"y_train_ = preprocessing.LabelEncoder().fit(y.unique().tolist()).transform(y_train) \n",
"y_test_ = preprocessing.LabelEncoder().fit(y.unique().tolist()).transform(y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zx8gWBaSuWM8"
},
"source": [
"## Fit DT"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "odONkVlxuZ1O"
},
"outputs": [],
"source": [
"DT_2, tree_2 = fit_DT(X_train, X_test, y_train, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tQjTBa-YuWBT"
},
"source": [
"## Fit Bagging"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4Mxtp2l0wEI3"
},
"outputs": [],
"source": [
"Bagging_2=fit_bagging(tree_2, X_train, X_test, y_train, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YAgt3mxJbcMy"
},
"source": [
"## Fit RF"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m34ldpWFbhnP"
},
"outputs": [],
"source": [
"RF_2 = fit_RF(X_train, X_test, y_train, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "811JyJUTbdd1"
},
"source": [
"## Fit NN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vq2cfoS5bktL"
},
"outputs": [],
"source": [
"NN_2 = fit_nn(X_train_, X_test_, y_train_, y_test_)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lEw5vhaIcEZ3"
},
"source": [
"## Result\n",
"\n",
"In this experiment (test set), the best performing \"tree\" algorithm is normal decision tree. It performs even better than RF and NN. \n",
"\n",
"Both RF and NN are at a level which is similar."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"id": "W1U0gCpdcHOv",
"outputId": "4ed65e83-2678-4b5a-aaad-3f9f26bd20fb"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div id=\"df-6bda29f7-6497-454d-89a8-0601dea89e27\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>DT</th>\n",
" <th>Bagging</th>\n",
" <th>RF</th>\n",
" <th>NN</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>test Accuracy</th>\n",
" <td>0.978</td>\n",
" <td>0.956</td>\n",
" <td>0.971</td>\n",
" <td>0.971</td>\n",
" </tr>\n",
" <tr>\n",
" <th>train Accuracy</th>\n",
" <td>0.985</td>\n",
" <td>0.993</td>\n",
" <td>0.984</td>\n",
" <td>1.000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-6bda29f7-6497-454d-89a8-0601dea89e27')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-6bda29f7-6497-454d-89a8-0601dea89e27 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-6bda29f7-6497-454d-89a8-0601dea89e27');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
],
"text/plain": [
" DT Bagging RF NN\n",
"test Accuracy 0.978 0.956 0.971 0.971\n",
"train Accuracy 0.985 0.993 0.984 1.000"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"print_plot_result(DT_2, Bagging_2, RF_2, NN_2)"
]
}
],
"metadata": {
"accelerator": "TPU",
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "notebook.ipynb",
"provenance": [],
"include_colab_link": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"vp": {
"vp_config_version": "1.0.0",
"vp_menu_width": 273,
"vp_note_display": false,
"vp_note_width": 0,
"vp_position": {
"width": 278
},
"vp_section_display": false,
"vp_signature": "VisualPython"
}
},
"nbformat": 4,
"nbformat_minor": 0
}