Skip to content

Instantly share code, notes, and snippets.

@ShahStavan
Created September 29, 2023 19:25
Show Gist options
  • Save ShahStavan/f3e8e4214e25b4483be52e6ccef30c08 to your computer and use it in GitHub Desktop.
Save ShahStavan/f3e8e4214e25b4483be52e6ccef30c08 to your computer and use it in GitHub Desktop.
22bce539_Practical_6.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyM9vnS1ihT4APKi3+fbAg13"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# ROLL NO : 22bce539\n",
"# NAME : SHAH STAVAN PURVESHBHAI\n",
"# SUBJECT & COURSE CODE: MACHINE LEARNING (2CS501)\n",
"# Practical 6 (Decision Trees)\n",
"# Date : 29/09/2023"
],
"metadata": {
"id": "byb7PjlMULel"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vr7h0g5XT123"
},
"outputs": [],
"source": [
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import export_text\n",
"from sklearn.model_selection import train_test_split, cross_val_score\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.metrics import precision_recall_fscore_support\n",
"from sklearn.metrics import confusion_matrix\n",
"from sklearn import tree"
]
},
{
"cell_type": "code",
"source": [
"# Fetch the bundled Iris dataset and unpack its feature matrix and label vector\n",
"irisdata = load_iris()\n",
"X, y = irisdata.data, irisdata.target\n"
],
"metadata": {
"id": "oKdwUnjnUXiZ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Hold out 30% of the samples for testing; the fixed seed keeps the split repeatable\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y, test_size=0.3, random_state=42\n",
")"
],
"metadata": {
"id": "Icuh0MkOVEJy"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"\n",
"\n",
"* ID3-style Decision Tree Classifier (scikit-learn's CART implementation using the entropy / information-gain criterion)\n",
"\n",
"\n"
],
"metadata": {
"id": "UIpSBIcWVHgZ"
}
},
{
"cell_type": "code",
"source": [
"# Train an ID3-style tree (entropy criterion) on the random 70/30 split.\n",
"# random_state is fixed because scikit-learn permutes the features at every split,\n",
"# so tied splits would otherwise yield a different tree and different scores per run.\n",
"id3_classifier = DecisionTreeClassifier(criterion=\"entropy\", splitter=\"best\", random_state=42)\n",
"id3_classifier.fit(X_train, y_train)\n",
"\n",
"# Draw the fitted tree as a figure\n",
"tree.plot_tree(id3_classifier)\n",
"\n",
"# Predict the held-out samples\n",
"id3_predictions = id3_classifier.predict(X_test)\n",
"\n",
"# Overall accuracy on the test set\n",
"id3_accuracy = accuracy_score(y_test, id3_predictions)\n",
"print(\"ID3 Accuracy:\", id3_accuracy)\n",
"\n",
"# Confusion matrix (rows: true class, columns: predicted class)\n",
"print(\"\\nConfusion matrix: \")\n",
"print(confusion_matrix(y_test, id3_predictions))\n",
"\n",
"# Per-class precision, recall, F1-score and support\n",
"precision, recall, f1_score, support = precision_recall_fscore_support(y_test, id3_predictions)\n",
"print(\"\\nPrecision:\", precision)\n",
"print(\"Recall:\", recall)\n",
"print(\"F1-Score:\", f1_score)\n",
"print(\"Support:\", support)\n",
"\n",
"# Text rendering of the fitted tree\n",
"draw_tree_text = export_text(id3_classifier, feature_names=irisdata['feature_names'])\n",
"print(draw_tree_text)\n",
"\n",
"# 10-fold cross-validation on the full dataset; cross_val_score clones the estimator,\n",
"# so the tree fitted above is left untouched\n",
"scores = cross_val_score(id3_classifier, X, y, cv=10)\n",
"print(\"\\nCross Validation Scores:\", scores)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "Cz0TcyK9VFJM",
"outputId": "39195769-a37e-468b-8720-b843bc233a46"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"ID3 Accuracy: 0.9777777777777777\n",
"\n",
"Confusion matrix: \n",
"[[19 0 0]\n",
" [ 0 13 0]\n",
" [ 0 1 12]]\n",
"\n",
"Precision: [1. 0.92857143 1. ]\n",
"Recall: [1. 1. 0.92307692]\n",
"F1-Score: [1. 0.96296296 0.96 ]\n",
"Support: [19 13 13]\n",
"|--- petal width (cm) <= 0.80\n",
"| |--- class: 0\n",
"|--- petal width (cm) > 0.80\n",
"| |--- petal length (cm) <= 4.75\n",
"| | |--- petal width (cm) <= 1.60\n",
"| | | |--- class: 1\n",
"| | |--- petal width (cm) > 1.60\n",
"| | | |--- class: 2\n",
"| |--- petal length (cm) > 4.75\n",
"| | |--- petal length (cm) <= 5.15\n",
"| | | |--- petal width (cm) <= 1.75\n",
"| | | | |--- sepal width (cm) <= 2.35\n",
"| | | | | |--- class: 2\n",
"| | | | |--- sepal width (cm) > 2.35\n",
"| | | | | |--- petal length (cm) <= 5.05\n",
"| | | | | | |--- class: 1\n",
"| | | | | |--- petal length (cm) > 5.05\n",
"| | | | | | |--- sepal length (cm) <= 6.15\n",
"| | | | | | | |--- class: 1\n",
"| | | | | | |--- sepal length (cm) > 6.15\n",
"| | | | | | | |--- class: 2\n",
"| | | |--- petal width (cm) > 1.75\n",
"| | | | |--- sepal width (cm) <= 3.10\n",
"| | | | | |--- class: 2\n",
"| | | | |--- sepal width (cm) > 3.10\n",
"| | | | | |--- class: 1\n",
"| | |--- petal length (cm) > 5.15\n",
"| | | |--- class: 2\n",
"\n",
"\n",
"Cross Validation Scores: [1. 0.93333333 1. 0.93333333 0.93333333 0.86666667\n",
" 0.93333333 1. 1. 1. ]\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"\n",
"\n",
"* CART Decision Tree Classifier\n",
"\n"
],
"metadata": {
"id": "lYP1KQ8RVtnw"
}
},
{
"cell_type": "code",
"source": [
"# Train a CART tree (Gini impurity) on the random 70/30 split.\n",
"# random_state is fixed because scikit-learn permutes the features at every split,\n",
"# so tied splits would otherwise yield a different tree and different scores per run.\n",
"cart_classifier = DecisionTreeClassifier(criterion=\"gini\", splitter=\"best\", random_state=42)\n",
"\n",
"# fit() returns the estimator itself, so `classified` is the same object as `cart_classifier`\n",
"classified = cart_classifier.fit(X_train, y_train)\n",
"\n",
"# Predict the held-out samples\n",
"cart_predictions = cart_classifier.predict(X_test)\n",
"\n",
"# Overall accuracy on the test set\n",
"cart_accuracy = accuracy_score(y_test, cart_predictions)\n",
"print(\"CART Accuracy:\", cart_accuracy)\n",
"\n",
"# Confusion matrix (rows: true class, columns: predicted class)\n",
"print(\"\\nConfusion matrix: \")\n",
"print(confusion_matrix(y_test, cart_predictions))\n",
"\n",
"# Per-class precision, recall, F1-score and support\n",
"precision, recall, f1_score, support = precision_recall_fscore_support(y_test, cart_predictions)\n",
"print(\"\\nPrecision:\", precision)\n",
"print(\"Recall:\", recall)\n",
"print(\"F1-Score:\", f1_score)\n",
"print(\"Support:\", support, \"\\n\")\n",
"\n",
"# Text rendering of the fitted tree\n",
"draw_tree_text = export_text(classified, feature_names=irisdata.feature_names)\n",
"print(draw_tree_text)\n",
"\n",
"# 10-fold cross-validation on the full dataset (the estimator is cloned for each fold)\n",
"scores = cross_val_score(classified, X, y, cv=10)\n",
"print(\"Cross Validation Scores:\", scores)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nIVLrJFNVwcw",
"outputId": "829537d5-9c33-4205-9684-74d82f052e89"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"CART Accuracy: 1.0\n",
"\n",
"Confusion matrix: \n",
"[[19 0 0]\n",
" [ 0 13 0]\n",
" [ 0 0 13]]\n",
"\n",
"Precision: [1. 1. 1.]\n",
"Recall: [1. 1. 1.]\n",
"F1-Score: [1. 1. 1.]\n",
"Support: [19 13 13] \n",
"\n",
"|--- petal width (cm) <= 0.80\n",
"| |--- class: 0\n",
"|--- petal width (cm) > 0.80\n",
"| |--- petal length (cm) <= 4.75\n",
"| | |--- petal width (cm) <= 1.60\n",
"| | | |--- class: 1\n",
"| | |--- petal width (cm) > 1.60\n",
"| | | |--- class: 2\n",
"| |--- petal length (cm) > 4.75\n",
"| | |--- petal width (cm) <= 1.75\n",
"| | | |--- petal length (cm) <= 4.95\n",
"| | | | |--- class: 1\n",
"| | | |--- petal length (cm) > 4.95\n",
"| | | | |--- petal width (cm) <= 1.55\n",
"| | | | | |--- class: 2\n",
"| | | | |--- petal width (cm) > 1.55\n",
"| | | | | |--- sepal length (cm) <= 6.95\n",
"| | | | | | |--- class: 1\n",
"| | | | | |--- sepal length (cm) > 6.95\n",
"| | | | | | |--- class: 2\n",
"| | |--- petal width (cm) > 1.75\n",
"| | | |--- petal length (cm) <= 4.85\n",
"| | | | |--- sepal width (cm) <= 3.10\n",
"| | | | | |--- class: 2\n",
"| | | | |--- sepal width (cm) > 3.10\n",
"| | | | | |--- class: 1\n",
"| | | |--- petal length (cm) > 4.85\n",
"| | | | |--- class: 2\n",
"\n",
"Cross Validation Scores: [1. 0.93333333 1. 0.93333333 0.93333333 0.86666667\n",
" 0.93333333 1. 1. 1. ]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Clean Code for the above code\n",
"def id3_decision_tree_classification(X_train, y_train, X_test, y_test):\n",
"    \"\"\"Fit an entropy-criterion decision tree and print its evaluation report.\n",
"\n",
"    Prints the test accuracy, confusion matrix, per-class precision / recall /\n",
"    F1-score / support, a text rendering of the fitted tree, and 10-fold\n",
"    cross-validation scores.\n",
"\n",
"    Note: the feature names come from the notebook-level `irisdata`, and the\n",
"    cross-validation runs on the notebook-level `X` and `y`, not on the arguments.\n",
"\n",
"    Args:\n",
"        X_train: Feature rows used to fit the tree.\n",
"        y_train: Labels matching `X_train`.\n",
"        X_test: Feature rows used for evaluation.\n",
"        y_test: Labels matching `X_test`.\n",
"\n",
"    Returns:\n",
"        None. All results are printed.\n",
"    \"\"\"\n",
"    # Fixed seed: features are permuted at every split, so tied splits would\n",
"    # otherwise produce a different tree on each call.\n",
"    id3_classifier = DecisionTreeClassifier(criterion=\"entropy\", splitter=\"best\", random_state=42)\n",
"    id3_classifier.fit(X_train, y_train)\n",
"    id3_predictions = id3_classifier.predict(X_test)\n",
"    id3_accuracy = accuracy_score(y_test, id3_predictions)\n",
"\n",
"    print(\"ID3 Accuracy:\", id3_accuracy)\n",
"    print(\"\\nConfusion matrix: \")\n",
"    print(confusion_matrix(y_test, id3_predictions))\n",
"\n",
"    precision, recall, f1_score, support = precision_recall_fscore_support(y_test, id3_predictions)\n",
"    print(\"\\nPrecision:\", precision)\n",
"    print(\"Recall:\", recall)\n",
"    print(\"F1-Score:\", f1_score)\n",
"    print(\"Support:\", support)\n",
"\n",
"    draw_tree_text = export_text(id3_classifier, feature_names=irisdata['feature_names'])\n",
"    print(\"\\nDecision Tree (ID3):\\n\", draw_tree_text)\n",
"\n",
"    # cross_val_score clones the estimator, so the fitted tree above is not modified\n",
"    scores = cross_val_score(id3_classifier, X, y, cv=10)\n",
"    print(\"\\nCross Validation Scores:\", scores)\n",
"\n",
"def cart_decision_tree_classification(X_train, y_train, X_test, y_test):\n",
"    \"\"\"Fit a Gini-criterion (CART) decision tree and print its evaluation report.\n",
"\n",
"    Prints the test accuracy, confusion matrix, per-class precision / recall /\n",
"    F1-score / support, a text rendering of the fitted tree, and 10-fold\n",
"    cross-validation scores.\n",
"\n",
"    Note: the feature names come from the notebook-level `irisdata`, and the\n",
"    cross-validation runs on the notebook-level `X` and `y`, not on the arguments.\n",
"\n",
"    Args:\n",
"        X_train: Feature rows used to fit the tree.\n",
"        y_train: Labels matching `X_train`.\n",
"        X_test: Feature rows used for evaluation.\n",
"        y_test: Labels matching `X_test`.\n",
"\n",
"    Returns:\n",
"        None. All results are printed.\n",
"    \"\"\"\n",
"    # Fixed seed: features are permuted at every split, so tied splits would\n",
"    # otherwise produce a different tree on each call.\n",
"    cart_classifier = DecisionTreeClassifier(criterion=\"gini\", splitter=\"best\", random_state=42)\n",
"    cart_classifier.fit(X_train, y_train)\n",
"    cart_predictions = cart_classifier.predict(X_test)\n",
"    cart_accuracy = accuracy_score(y_test, cart_predictions)\n",
"\n",
"    print(\"CART Accuracy:\", cart_accuracy)\n",
"    print(\"\\nConfusion matrix: \")\n",
"    print(confusion_matrix(y_test, cart_predictions))\n",
"\n",
"    precision, recall, f1_score, support = precision_recall_fscore_support(y_test, cart_predictions)\n",
"    print(\"\\nPrecision:\", precision)\n",
"    print(\"Recall:\", recall)\n",
"    print(\"F1-Score:\", f1_score)\n",
"    print(\"Support:\", support)\n",
"\n",
"    draw_tree_text = export_text(cart_classifier, feature_names=irisdata['feature_names'])\n",
"    print(\"\\nDecision Tree (CART):\\n\", draw_tree_text)\n",
"\n",
"    # cross_val_score clones the estimator, so the fitted tree above is not modified\n",
"    scores = cross_val_score(cart_classifier, X, y, cv=10)\n",
"    print(\"\\nCross Validation Scores:\", scores)\n",
"\n",
"def main():\n",
"    \"\"\"Run a text menu that evaluates the ID3-style or CART tree until the user quits.\n",
"\n",
"    Reads the notebook-level `X` and `y`. Odd-indexed samples are used for\n",
"    training and even-indexed samples for testing.\n",
"    \"\"\"\n",
"    # Build the split once; both menu options use the same data.\n",
"    X_train, y_train = X[1::2], y[1::2]\n",
"    # The test labels must come from the same (even) indices as X_test.\n",
"    X_test, y_test = X[0::2], y[0::2]\n",
"\n",
"    while True:\n",
"        print(\"Select an option:\")\n",
"        print(\"1. ID3 Decision Tree Classification\")\n",
"        print(\"2. CART Decision Tree Classification\")\n",
"        print(\"3. Quit\")\n",
"\n",
"        choice = input(\"Enter your choice: \")\n",
"\n",
"        if choice == '1':\n",
"            id3_decision_tree_classification(X_train, y_train, X_test, y_test)\n",
"        elif choice == '2':\n",
"            cart_decision_tree_classification(X_train, y_train, X_test, y_test)\n",
"        elif choice == '3':\n",
"            break\n",
"        else:\n",
"            print(\"Invalid choice. Please select a valid option.\")\n",
"\n",
"if __name__ == \"__main__\":\n",
"    main()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oiA0gsZ7Xjwo",
"outputId": "5f04e2a7-9a56-4d27-8a71-a3de6474b244"
},
"execution_count": null,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Select an option:\n",
"1. ID3 Decision Tree Classification\n",
"2. CART Decision Tree Classification\n",
"3. Quit\n",
"Enter your choice: 1\n",
"ID3 Accuracy: 0.9333333333333333\n",
"\n",
"Confusion matrix: \n",
"[[25 0 0]\n",
" [ 0 21 4]\n",
" [ 0 1 24]]\n",
"\n",
"Precision: [1. 0.95454545 0.85714286]\n",
"Recall: [1. 0.84 0.96]\n",
"F1-Score: [1. 0.89361702 0.90566038]\n",
"Support: [25 25 25]\n",
"\n",
"Decision Tree (ID3):\n",
" |--- petal length (cm) <= 2.50\n",
"| |--- class: 0\n",
"|--- petal length (cm) > 2.50\n",
"| |--- petal length (cm) <= 4.80\n",
"| | |--- class: 1\n",
"| |--- petal length (cm) > 4.80\n",
"| | |--- petal width (cm) <= 1.75\n",
"| | | |--- petal width (cm) <= 1.55\n",
"| | | | |--- class: 2\n",
"| | | |--- petal width (cm) > 1.55\n",
"| | | | |--- sepal length (cm) <= 6.95\n",
"| | | | | |--- class: 1\n",
"| | | | |--- sepal length (cm) > 6.95\n",
"| | | | | |--- class: 2\n",
"| | |--- petal width (cm) > 1.75\n",
"| | | |--- class: 2\n",
"\n",
"\n",
"Cross Validation Scores: [1. 0.93333333 1. 0.93333333 0.93333333 0.86666667\n",
" 0.93333333 0.93333333 1. 1. ]\n",
"Select an option:\n",
"1. ID3 Decision Tree Classification\n",
"2. CART Decision Tree Classification\n",
"3. Quit\n",
"Enter your choice: 2\n",
"CART Accuracy: 0.9333333333333333\n",
"\n",
"Confusion matrix: \n",
"[[25 0 0]\n",
" [ 0 21 4]\n",
" [ 0 1 24]]\n",
"\n",
"Precision: [1. 0.95454545 0.85714286]\n",
"Recall: [1. 0.84 0.96]\n",
"F1-Score: [1. 0.89361702 0.90566038]\n",
"Support: [25 25 25]\n",
"\n",
"Decision Tree (CART):\n",
" |--- petal width (cm) <= 0.80\n",
"| |--- class: 0\n",
"|--- petal width (cm) > 0.80\n",
"| |--- petal length (cm) <= 4.80\n",
"| | |--- class: 1\n",
"| |--- petal length (cm) > 4.80\n",
"| | |--- petal width (cm) <= 1.75\n",
"| | | |--- petal width (cm) <= 1.55\n",
"| | | | |--- class: 2\n",
"| | | |--- petal width (cm) > 1.55\n",
"| | | | |--- petal length (cm) <= 5.45\n",
"| | | | | |--- class: 1\n",
"| | | | |--- petal length (cm) > 5.45\n",
"| | | | | |--- class: 2\n",
"| | |--- petal width (cm) > 1.75\n",
"| | | |--- class: 2\n",
"\n",
"\n",
"Cross Validation Scores: [1. 0.93333333 1. 0.93333333 0.93333333 0.86666667\n",
" 0.93333333 1. 1. 1. ]\n",
"Select an option:\n",
"1. ID3 Decision Tree Classification\n",
"2. CART Decision Tree Classification\n",
"3. Quit\n",
"Enter your choice: 3\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Split the data: odd-indexed samples for training, even-indexed samples for testing.\n",
"# The test labels must come from the same (even) indices as X_test.\n",
"X_train = X[1::2]\n",
"y_train = y[1::2]\n",
"X_test = X[0::2]\n",
"y_test = y[0::2]\n",
"\n",
"# Create an ID3-style tree (entropy criterion). random_state is fixed because features\n",
"# are permuted at every split, so tied splits would otherwise change the tree per run.\n",
"id3_classifier = DecisionTreeClassifier(criterion=\"entropy\", splitter=\"best\", random_state=42)\n",
"\n",
"# fit() returns the estimator itself, so `classification` is the same object as `id3_classifier`\n",
"classification = id3_classifier.fit(X_train, y_train)\n",
"\n",
"# Draw the fitted tree as a figure\n",
"tree.plot_tree(classification)\n",
"\n",
"# Predict the held-out samples\n",
"id3_predictions = id3_classifier.predict(X_test)\n",
"\n",
"# Overall accuracy on the test set\n",
"id3_accuracy = accuracy_score(y_test, id3_predictions)\n",
"print(\"ID3 Accuracy:\", id3_accuracy)\n",
"\n",
"# Confusion matrix (rows: true class, columns: predicted class)\n",
"print(\"\\nConfusion matrix: \")\n",
"print(confusion_matrix(y_test, id3_predictions))\n",
"\n",
"# Per-class precision, recall, F1-score and support, printed as one tuple of arrays\n",
"print(\"\\n Precision, Recall, F1-Score and Support: \", precision_recall_fscore_support(y_test, id3_predictions),\"\\n\")\n",
"\n",
"# Text rendering of the fitted tree\n",
"drawTree = export_text(classification, feature_names=irisdata['feature_names'])\n",
"print(drawTree)\n",
"\n",
"# 10-fold cross-validation on the full dataset (the estimator is cloned for each fold)\n",
"scores = cross_val_score(classification, X, y, cv=10)\n",
"print(\"Cross Validation Score: \", scores)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 892
},
"id": "IRmLhWksWLNK",
"outputId": "e3d68631-a685-4bc9-aabc-e21af5a9a95e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"ID3 Accuracy: 0.9333333333333333\n",
"\n",
"Confusion matrix: \n",
"[[25 0 0]\n",
" [ 0 21 4]\n",
" [ 0 1 24]]\n",
"\n",
" Precision, Recall, F1-Score and Support: (array([1. , 0.95454545, 0.85714286]), array([1. , 0.84, 0.96]), array([1. , 0.89361702, 0.90566038]), array([25, 25, 25])) \n",
"\n",
"|--- petal width (cm) <= 0.80\n",
"| |--- class: 0\n",
"|--- petal width (cm) > 0.80\n",
"| |--- petal length (cm) <= 4.80\n",
"| | |--- class: 1\n",
"| |--- petal length (cm) > 4.80\n",
"| | |--- petal width (cm) <= 1.75\n",
"| | | |--- petal width (cm) <= 1.55\n",
"| | | | |--- class: 2\n",
"| | | |--- petal width (cm) > 1.55\n",
"| | | | |--- petal length (cm) <= 5.45\n",
"| | | | | |--- class: 1\n",
"| | | | |--- petal length (cm) > 5.45\n",
"| | | | | |--- class: 2\n",
"| | |--- petal width (cm) > 1.75\n",
"| | | |--- class: 2\n",
"\n",
"Cross Validation Score: [1. 0.93333333 1. 0.93333333 0.93333333 0.86666667\n",
" 0.93333333 1. 1. 1. ]\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Split the data: odd-indexed samples for training, even-indexed samples for testing.\n",
"# The test labels must come from the same (even) indices as X_test.\n",
"X_train = X[1::2]\n",
"y_train = y[1::2]\n",
"X_test = X[0::2]\n",
"y_test = y[0::2]\n",
"\n",
"# Create a CART tree (Gini impurity). random_state is fixed because features are\n",
"# permuted at every split, so tied splits would otherwise change the tree per run.\n",
"cart_classifier = DecisionTreeClassifier(criterion=\"gini\", splitter=\"best\", random_state=42)\n",
"\n",
"# fit() returns the estimator itself, so `classified` is the same object as `cart_classifier`\n",
"classified = cart_classifier.fit(X_train, y_train)\n",
"\n",
"# Predict the held-out samples\n",
"cart_predictions = cart_classifier.predict(X_test)\n",
"\n",
"# Overall accuracy on the test set\n",
"cart_accuracy = accuracy_score(y_test, cart_predictions)\n",
"print(\"CART Accuracy:\", cart_accuracy)\n",
"\n",
"# Confusion matrix (rows: true class, columns: predicted class)\n",
"print(\"\\nConfusion matrix:\")\n",
"print(confusion_matrix(y_test, cart_predictions))\n",
"\n",
"# Per-class precision, recall, F1-score and support for the CART predictions\n",
"# (must not reuse `id3_predictions` left over from the previous cell)\n",
"precision, recall, f1_score, support = precision_recall_fscore_support(y_test, cart_predictions)\n",
"print(\"\\nPrecision:\", precision)\n",
"print(\"Recall:\", recall)\n",
"print(\"F1-Score:\", f1_score)\n",
"print(\"Support:\", support)\n",
"\n",
"# Text rendering of the fitted tree\n",
"draw_tree = export_text(classified, feature_names=irisdata['feature_names'])\n",
"print(draw_tree)\n",
"\n",
"# 10-fold cross-validation on the full dataset (the estimator is cloned for each fold)\n",
"scores = cross_val_score(classified, X, y, cv=10)\n",
"print(\"\\nCross Validation Score:\", scores)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fBoTZ-SgWY1v",
"outputId": "d117bbef-0ca4-4d60-a87f-d796a2a5eab3"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"CART Accuracy: 0.9333333333333333\n",
"\n",
"Confusion matrix:\n",
"[[25 0 0]\n",
" [ 0 21 4]\n",
" [ 0 1 24]]\n",
"\n",
"Precision: [1. 0.95454545 0.85714286]\n",
"Recall: [1. 0.84 0.96]\n",
"F1-Score: [1. 0.89361702 0.90566038]\n",
"Support: [25 25 25]\n",
"|--- petal width (cm) <= 0.80\n",
"| |--- class: 0\n",
"|--- petal width (cm) > 0.80\n",
"| |--- petal length (cm) <= 4.80\n",
"| | |--- class: 1\n",
"| |--- petal length (cm) > 4.80\n",
"| | |--- petal width (cm) <= 1.75\n",
"| | | |--- petal width (cm) <= 1.55\n",
"| | | | |--- class: 2\n",
"| | | |--- petal width (cm) > 1.55\n",
"| | | | |--- sepal length (cm) <= 6.95\n",
"| | | | | |--- class: 1\n",
"| | | | |--- sepal length (cm) > 6.95\n",
"| | | | | |--- class: 2\n",
"| | |--- petal width (cm) > 1.75\n",
"| | | |--- class: 2\n",
"\n",
"\n",
"Cross Validation Score: [1. 0.93333333 1. 0.93333333 0.93333333 0.86666667\n",
" 0.93333333 0.93333333 1. 1. ]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Clean Code for the above code\n",
"def load_and_split_data(X, y):\n",
"    \"\"\"Split samples into interleaved halves: odd indices train, even indices test.\n",
"\n",
"    Args:\n",
"        X: Indexable sequence of feature rows.\n",
"        y: Indexable sequence of labels, aligned with `X`.\n",
"\n",
"    Returns:\n",
"        Tuple `(X_train, y_train, X_test, y_test)`.\n",
"    \"\"\"\n",
"    X_train, y_train = X[1::2], y[1::2]\n",
"    # The test labels must come from the same (even) indices as X_test.\n",
"    X_test, y_test = X[0::2], y[0::2]\n",
"    return X_train, y_train, X_test, y_test\n",
"\n",
"def train_decision_tree_classifier(X_train, y_train, criterion=\"entropy\", random_state=42):\n",
"    \"\"\"Fit and return a DecisionTreeClassifier.\n",
"\n",
"    Args:\n",
"        X_train: Feature rows used to fit the tree.\n",
"        y_train: Labels matching `X_train`.\n",
"        criterion: Split quality measure passed to the classifier.\n",
"        random_state: Seed passed to the classifier. It is fixed by default because\n",
"            features are permuted at every split, so tied splits would otherwise\n",
"            give a different tree on each call. Pass None for unseeded behaviour.\n",
"\n",
"    Returns:\n",
"        The fitted classifier.\n",
"    \"\"\"\n",
"    classifier = DecisionTreeClassifier(criterion=criterion, splitter=\"best\", random_state=random_state)\n",
"    # fit() returns the estimator itself\n",
"    return classifier.fit(X_train, y_train)\n",
"\n",
"def visualize_decision_tree(classification, feature_names):\n",
"    \"\"\"Print and return the text rendering of a fitted decision tree.\n",
"\n",
"    Args:\n",
"        classification: Fitted decision tree estimator.\n",
"        feature_names: Names used to label the features in the rendering.\n",
"\n",
"    Returns:\n",
"        The text report produced by `export_text`.\n",
"    \"\"\"\n",
"    tree_text = export_text(classification, feature_names=feature_names)\n",
"    # export_text only builds the string; it has to be printed to be seen\n",
"    print(tree_text)\n",
"    return tree_text\n",
"\n",
"def evaluate_decision_tree(classification, X_test, y_test):\n",
"    \"\"\"Print accuracy, confusion matrix and per-class metrics for a fitted tree.\n",
"\n",
"    Args:\n",
"        classification: Fitted estimator exposing `predict`.\n",
"        X_test: Feature rows to predict.\n",
"        y_test: True labels matching `X_test`.\n",
"    \"\"\"\n",
"    y_pred = classification.predict(X_test)\n",
"    print(\"Accuracy:\", accuracy_score(y_test, y_pred))\n",
"\n",
"    print(\"\\nConfusion matrix:\")\n",
"    print(confusion_matrix(y_test, y_pred))\n",
"\n",
"    # The metric function returns (precision, recall, f1, support) in this order\n",
"    per_class_metrics = precision_recall_fscore_support(y_test, y_pred)\n",
"    headings = (\"\\nPrecision:\", \"Recall:\", \"F1-Score:\", \"Support:\")\n",
"    for heading, values in zip(headings, per_class_metrics):\n",
"        print(heading, values)\n",
"\n",
"def perform_cross_validation(classification, X, y):\n",
"    \"\"\"Print the 10-fold cross-validation scores of the estimator on (X, y).\"\"\"\n",
"    print(\"\\nCross Validation Score:\", cross_val_score(classification, X, y, cv=10))\n",
"\n",
"if __name__ == \"__main__\":\n",
"    # Split once up front; relies on the notebook-level X and y\n",
"    X_train, y_train, X_test, y_test = load_and_split_data(X, y)\n",
"\n",
"    while True:\n",
"        print(\"\\nSelect an option:\")\n",
"        print(\"1. Train ID3 Decision Tree\")\n",
"        print(\"2. Train CART Decision Tree\")\n",
"        print(\"3. Exit\")\n",
"\n",
"        choice = input(\"Enter your choice: \")\n",
"\n",
"        if choice == \"3\":\n",
"            break\n",
"\n",
"        if choice not in (\"1\", \"2\"):\n",
"            print(\"Invalid choice. Please select a valid option.\")\n",
"            continue\n",
"\n",
"        # Option 1 trains with entropy (ID3-style), option 2 with Gini (CART);\n",
"        # the rest of the pipeline is identical for both\n",
"        classification = train_decision_tree_classifier(\n",
"            X_train, y_train, criterion=\"entropy\" if choice == \"1\" else \"gini\"\n",
"        )\n",
"        visualize_decision_tree(classification, feature_names=irisdata['feature_names'])\n",
"        evaluate_decision_tree(classification, X_test, y_test)\n",
"        perform_cross_validation(classification, X, y)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DfvFp9TzYBZF",
"outputId": "c87e1417-08f0-4e96-d003-040797f9a038"
},
"execution_count": null,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Select an option:\n",
"1. Train ID3 Decision Tree\n",
"2. Train CART Decision Tree\n",
"3. Exit\n",
"Enter your choice: 1\n",
"Accuracy: 0.9333333333333333\n",
"\n",
"Confusion matrix:\n",
"[[25 0 0]\n",
" [ 0 21 4]\n",
" [ 0 1 24]]\n",
"\n",
"Precision: [1. 0.95454545 0.85714286]\n",
"Recall: [1. 0.84 0.96]\n",
"F1-Score: [1. 0.89361702 0.90566038]\n",
"Support: [25 25 25]\n",
"\n",
"Cross Validation Score: [1. 0.93333333 1. 0.93333333 0.93333333 0.86666667\n",
" 0.93333333 1. 1. 1. ]\n",
"\n",
"Select an option:\n",
"1. Train ID3 Decision Tree\n",
"2. Train CART Decision Tree\n",
"3. Exit\n",
"Enter your choice: 2\n",
"Accuracy: 0.9333333333333333\n",
"\n",
"Confusion matrix:\n",
"[[25 0 0]\n",
" [ 0 21 4]\n",
" [ 0 1 24]]\n",
"\n",
"Precision: [1. 0.95454545 0.85714286]\n",
"Recall: [1. 0.84 0.96]\n",
"F1-Score: [1. 0.89361702 0.90566038]\n",
"Support: [25 25 25]\n",
"\n",
"Cross Validation Score: [1. 0.93333333 1. 0.93333333 0.93333333 0.86666667\n",
" 0.93333333 0.93333333 1. 1. ]\n",
"\n",
"Select an option:\n",
"1. Train ID3 Decision Tree\n",
"2. Train CART Decision Tree\n",
"3. Exit\n",
"Enter your choice: 3\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "Q38qx7otYCrk"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment