@viraj-lakshitha
Created January 21, 2021 09:17
decision-tree-algorithm.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "decision-tree-algorithm.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyM03jZeatM2qfX/S/t5Xgs7",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/viraj-lakshitha/a403b862f788b807dea7e4072de964a2/decision-tree-algorithm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SiKvv3hOpPhP"
},
"source": [
"# Decision Tree Algorithm\r\n",
"From Scratch using Python"
]
},
{
"cell_type": "code",
"metadata": {
"id": "bcaPnL6d-Rfl"
},
"source": [
"# For Python 2 / 3 compatibility\r\n",
"from __future__ import print_function"
],
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "6H6u3G-6pIWB"
},
"source": [
"# The first two columns are features; the last column is the label\r\n",
"training_data = [\r\n",
"    ['Green', 3, 'Apple'],\r\n",
"    ['Yellow', 3, 'Apple'],\r\n",
"    ['Red', 1, 'Grape'],\r\n",
"    ['Blue', 2, 'Blueberry'],\r\n",
"    ['Red', 1, 'Grape'],\r\n",
"    ['Yellow', 3, 'Lemon'],\r\n",
"]\r\n",
"\r\n",
"# Column Labels (Used to print the tree)\r\n",
"headers = [\"colors\",'diameter','label']"
],
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AeQ5ku_H4WKg",
"outputId": "00293419-0205-4eda-c4d4-aabe4942daf5"
},
"source": [
"# Return the set of unique values in a given column of the dataset\r\n",
"def unique_vals(rows, col):\r\n",
"    return set([row[col] for row in rows])\r\n",
"\r\n",
"# Test the function\r\n",
"print(unique_vals(training_data, 0))  # Unique values in the first column"
],
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"text": [
"{'Green', 'Yellow', 'Red', 'Blue'}\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8u1_w8RM4i1K",
"outputId": "62a50c0a-1957-401c-c4fa-b8ac6b28f6b1"
},
"source": [
"# Count how many rows of each label (class) appear in a dataset\r\n",
"def class_counts(rows):\r\n",
"    counts = {}  # a dictionary of label -> count.\r\n",
"    for row in rows:\r\n",
"        # in our dataset format, the label is always the last column\r\n",
"        label = row[-1]\r\n",
"        if label not in counts:\r\n",
"            counts[label] = 0\r\n",
"        counts[label] += 1\r\n",
"    return counts\r\n",
"\r\n",
"\r\n",
"# Test the function\r\n",
"class_counts(training_data)"
],
"execution_count": 24,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'Apple': 2, 'Blueberry': 1, 'Grape': 2, 'Lemon': 1}"
]
},
"metadata": {
"tags": []
},
"execution_count": 24
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Re8KOtsh6TEP",
"outputId": "4f51d7f5-b7e2-4311-ea96-5efbc5485e8d"
},
"source": [
"# Check whether a value is numeric (an integer or a float)\r\n",
"def is_numeric(value):\r\n",
"    return isinstance(value, int) or isinstance(value, float)\r\n",
"\r\n",
"#Test the function\r\n",
"is_numeric(10)"
],
"execution_count": 25,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {
"tags": []
},
"execution_count": 25
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sm3f1uC-9dTW"
},
"source": [
"**The Question class is used to partition a dataset.**\r\n",
"This class records a 'column number' (e.g., 0 for Color) and a\r\n",
"'column value' (e.g., Green). The 'match' method compares\r\n",
"the feature value in an example to the feature value stored in the\r\n",
"question: numeric values are tested with >=, all others with ==. See the demo below."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Il0ovtvd9HeG",
"outputId": "5fd29d96-3e4e-46c0-a2e6-ce17c49800dd"
},
"source": [
"class Question:\r\n",
"\r\n",
"    def __init__(self, column, value):\r\n",
"        self.column = column\r\n",
"        self.value = value\r\n",
"\r\n",
"    # Compare the feature value in an example to the\r\n",
"    # feature value in this question.\r\n",
"    def match(self, example):\r\n",
"        val = example[self.column]\r\n",
"        if is_numeric(val):\r\n",
"            return val >= self.value\r\n",
"        else:\r\n",
"            return val == self.value\r\n",
"\r\n",
"    # This is just a helper method to print\r\n",
"    # the question in a readable format.\r\n",
"    def __repr__(self):\r\n",
"        condition = \"==\"\r\n",
"        if is_numeric(self.value):\r\n",
"            condition = \">=\"\r\n",
"        return \"Is %s %s %s ?\" % (\r\n",
"            headers[self.column], condition, str(self.value))\r\n",
"\r\n",
"\r\n",
"# Test the class\r\n",
"print(Question(1, 3))\r\n",
"print(Question(2, \"Apple\"))\r\n",
"print(Question(0, \"Blue\"))"
],
"execution_count": 32,
"outputs": [
{
"output_type": "stream",
"text": [
"Is diameter >= 3 ?\n",
"Is label == Apple ?\n",
"Is colors == Blue ?\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pSVucmj4-74q",
"outputId": "9b85d4dd-a33f-4300-b83f-29fbebda850b"
},
"source": [
"# Check that the Question class works with the dataset\r\n",
"# Create two questions from the Question class\r\n",
"question1 = Question(0, \"Blue\")\r\n",
"question2 = Question(0, \"Red\")\r\n",
"\r\n",
"# Take a sample row from training_data\r\n",
"sample_data = training_data[3]  # ['Blue', 2, 'Blueberry']"
],
"execution_count": 36,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {
"tags": []
},
"execution_count": 36
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "crc7Z8crAQJ-",
"outputId": "0c583588-69fc-456a-d39f-2385dd2a5961"
},
"source": [
"# Check\r\n",
"question1.match(sample_data) # should be TRUE"
],
"execution_count": 37,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {
"tags": []
},
"execution_count": 37
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YljlGpwZATRM",
"outputId": "c910b55a-27af-4bc1-f3cb-3835eabde064"
},
"source": [
"# Check\r\n",
"question2.match(sample_data) # should be FALSE, because the sample's colour is Blue, not Red"
],
"execution_count": 38,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"False"
]
},
"metadata": {
"tags": []
},
"execution_count": 38
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-2m87nEqCAY7",
"outputId": "bdff1f08-13f5-400d-ee27-c87968e62951"
},
"source": [
"# Partition a dataset into the rows that match a question (TRUE) and the rows that do not (FALSE)\r\n",
"def partition(rows, question):\r\n",
"    true_rows, false_rows = [], []  # rows that match / rows that do not\r\n",
"    for row in rows:\r\n",
"        if question.match(row):\r\n",
"            true_rows.append(row)\r\n",
"        else:\r\n",
"            false_rows.append(row)\r\n",
"    return true_rows, false_rows\r\n",
"\r\n",
"# Test partition() with a Question\r\n",
"true_rows, false_rows = partition(training_data, Question(0, 'Blue'))\r\n",
"print('TRUE Dataset : ',true_rows)\r\n",
"print('FALSE Dataset : ',false_rows)"
],
"execution_count": 42,
"outputs": [
{
"output_type": "stream",
"text": [
"TRUE Dataset : [['Blue', 2, 'Blueberry']]\n",
"FALSE Dataset : [['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Red', 1, 'Grape'], ['Red', 1, 'Grape'], ['Yellow', 3, 'Lemon']]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JlEKy2KgERtP"
},
"source": [
"Gini impurity measures how often a randomly chosen element from the set would be incorrectly labeled if it were labeled at random according to the distribution of labels in the set.\r\n",
"For class probabilities $p_i$, the impurity is $1 - \\sum_i p_i^2$.\r\n",
"\r\n",
"Difference between Gini and Entropy: [Click](https://quantdare.com/decision-trees-gini-vs-entropy/)"
]
},
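{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal worked sketch of the Gini definition above, applied to the six labels of the training data. The helper name `gini_from_labels` is hypothetical and introduced only for this illustration; it takes a plain list of labels rather than full rows."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from collections import Counter\r\n",
"\r\n",
"# Hypothetical helper for illustration: Gini impurity from a plain list of labels\r\n",
"def gini_from_labels(labels):\r\n",
"    total = float(len(labels))\r\n",
"    # 1 minus the sum of squared class probabilities\r\n",
"    return 1 - sum((n / total) ** 2 for n in Counter(labels).values())\r\n",
"\r\n",
"# Labels of the six training rows: 2 Apple, 2 Grape, 1 Blueberry, 1 Lemon\r\n",
"labels = ['Apple', 'Apple', 'Grape', 'Blueberry', 'Grape', 'Lemon']\r\n",
"# 1 - (2/6)^2 - (2/6)^2 - (1/6)^2 - (1/6)^2 = 1 - 10/36, roughly 0.722\r\n",
"print(gini_from_labels(labels))"
],
"execution_count": null,
"outputs": []
},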
{
"cell_type": "code",
"metadata": {
"id": "xi4VLvQrE4jV"
},
"source": [
"# There are several ways to calculate the Gini impurity for a list of rows\r\n",
"# Define a function to calculate the Gini impurity\r\n",
"def gini(rows):\r\n",
"    counts = class_counts(rows)\r\n",
"    impurity = 1\r\n",
"    for lbl in counts:\r\n",
"        prob_of_lbl = counts[lbl] / float(len(rows))\r\n",
"        impurity -= prob_of_lbl**2\r\n",
"    return impurity"
],
"execution_count": 43,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6bXExSkmFX0K",
"outputId": "7c927e94-dc7c-4a38-a582-70b941a3d168"
},
"source": [
"# Test the gini() function with an unmixed list of data\r\n",
"test_gini_unmixed = [['Apple'],['Apple']]\r\n",
"gini(test_gini_unmixed) # Should be 0, because every row has the same label"
],
"execution_count": 44,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.0"
]
},
"metadata": {
"tags": []
},
"execution_count": 44
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WTSUYqlCGU4q",
"outputId": "215d4500-2c38-4dd9-bff6-d3f707aa0ae3"
},
"source": [
"# Test the gini() function with a mixed list of data\r\n",
"test_gini_mixed = [['Lemon'],['Apple']]\r\n",
"gini(test_gini_mixed) # Should be 0.5, because there are two equally likely labels"
],
"execution_count": 48,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.5"
]
},
"metadata": {
"tags": []
},
"execution_count": 48
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1VUvfNY2HJvy"
},
"source": [
"Information gain is the reduction in uncertainty (entropy or, as in this notebook, Gini impurity) achieved by splitting a dataset, and it is often used when training decision trees. It is calculated by comparing the uncertainty of the dataset before and after the split.\r\n",
"\r\n",
"Read more: Gini Index and Information Gain | Entropy and Information Gain [Click](http://www.learnbymarketing.com/481/decision-tree-flavors-gini-info-gain/#:~:text=Summary%3A%20The%20Gini%20Index%20is,of%20each%20class%20from%20one.&text=Information%20Gain%20multiplies%20the%20probability,2)\r\n",
"\r\n",
"Here, information gain is the uncertainty of the starting node minus the weighted impurity of the two child nodes."
]
},
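{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of the calculation described above, the cell below works out the information gain of the split 'Is colors == Red ?' by hand. The variable names are illustrative only, and the counts are read directly from the six training rows."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Gini impurity of the full training set: 2 Apple, 2 Grape, 1 Blueberry, 1 Lemon\r\n",
"parent_impurity = 1 - (2/6.0)**2 - (2/6.0)**2 - (1/6.0)**2 - (1/6.0)**2  # about 0.722\r\n",
"\r\n",
"# True branch: 2 Grape rows, so the branch is pure\r\n",
"true_impurity = 0.0\r\n",
"# False branch: 2 Apple, 1 Blueberry, 1 Lemon\r\n",
"false_impurity = 1 - (2/4.0)**2 - (1/4.0)**2 - (1/4.0)**2  # 0.625\r\n",
"\r\n",
"p_true = 2 / 6.0  # fraction of rows that fall in the true branch\r\n",
"gain = parent_impurity - p_true * true_impurity - (1 - p_true) * false_impurity\r\n",
"print(gain)  # about 0.306"
],
"execution_count": null,
"outputs": []
},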
{
"cell_type": "code",
"metadata": {
"id": "8eShByF9HJcK"
},
"source": [
"# Information gain: the uncertainty of the starting node minus the weighted impurity of the two child nodes\r\n",
"def info_gain(left, right, current_uncertainty):\r\n",
"    p = float(len(left)) / (len(left) + len(right))\r\n",
"    return current_uncertainty - p * gini(left) - (1 - p) * gini(right)"
],
"execution_count": 51,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1pcHOWuaIoqS",
"outputId": "d91ce0e3-c875-47fb-96ec-8fc96760d7df"
},
"source": [
"# Gini impurity of the whole training set\r\n",
"current_uncertainty = gini(training_data)\r\n",
"print('Gini Impurities : ',current_uncertainty)"
],
"execution_count": 52,
"outputs": [
{
"output_type": "stream",
"text": [
"Gini Impurities : 0.7222222222222221\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "J5SV2tvuJ_Dy",
"outputId": "89926112-b41d-40be-d1e9-7fdfc282eb45"
},
"source": [
"# Information gain when partitioning on 'Green'\r\n",
"true_rows, false_rows = partition(training_data, Question(0, 'Green'))\r\n",
"print('Information Gain of GREEN : ',info_gain(true_rows, false_rows, current_uncertainty))"
],
"execution_count": 53,
"outputs": [
{
"output_type": "stream",
"text": [
"Information Gain of GREEN : 0.12222222222222223\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Bf03cs94KBzg",
"outputId": "235fb6b3-5961-4a76-e2a8-1a35ec329fa7"
},
"source": [
"# Information gain when partitioning on 'Red'\r\n",
"true_rows, false_rows = partition(training_data, Question(0, 'Red'))\r\n",
"print('Information Gain of RED : ',info_gain(true_rows, false_rows, current_uncertainty))"
],
"execution_count": 54,
"outputs": [
{
"output_type": "stream",
"text": [
"Information Gain of RED : 0.30555555555555536\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OMeCB_vqKE4h",
"outputId": "85d2f4c5-29f5-4589-cf15-2f48dcd96f9d"
},
"source": [
"# It looks like we learned more using 'Red' (0.31), than 'Green' (0.12).\r\n",
"# Why? Look at the different splits that result, and see which one\r\n",
"# looks more 'unmixed' to you.\r\n",
"true_rows, false_rows = partition(training_data, Question(0,'Red'))\r\n",
"\r\n",
"print('TRUE Dataset : ',true_rows)\r\n",
"print('FALSE Dataset : ',false_rows)"
],
"execution_count": 55,
"outputs": [
{
"output_type": "stream",
"text": [
"TRUE Dataset : [['Red', 1, 'Grape'], ['Red', 1, 'Grape']]\n",
"FALSE Dataset : [['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Blue', 2, 'Blueberry'], ['Yellow', 3, 'Lemon']]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "R095r-iEK0I3"
},
"source": [
"# Find the best question to ask by iterating over every feature / value and calculating the information gain\r\n",
"def find_best_split(rows):\r\n",
"    best_gain = 0  # keep track of the best information gain\r\n",
"    best_question = None  # keep track of the feature / value that produced it\r\n",
"    current_uncertainty = gini(rows)\r\n",
"    n_features = len(rows[0]) - 1  # number of feature columns (the last column is the label)\r\n",
"\r\n",
"    for col in range(n_features):  # for each feature\r\n",
"\r\n",
"        values = set([row[col] for row in rows])  # unique values in the column\r\n",
"\r\n",
"        for val in values:  # for each value\r\n",
"\r\n",
"            question = Question(col, val)\r\n",
"\r\n",
"            # try splitting the dataset\r\n",
"            true_rows, false_rows = partition(rows, question)\r\n",
"\r\n",
"            # Skip this split if it doesn't divide the\r\n",
"            # dataset.\r\n",
"            if len(true_rows) == 0 or len(false_rows) == 0:\r\n",
"                continue\r\n",
"\r\n",
"            # Calculate the information gain from this split\r\n",
"            gain = info_gain(true_rows, false_rows, current_uncertainty)\r\n",
"\r\n",
"            # You actually can use '>' instead of '>=' here\r\n",
"            # but I wanted the tree to look a certain way\r\n",
"            if gain >= best_gain:\r\n",
"                best_gain, best_question = gain, question\r\n",
"\r\n",
"    return best_gain, best_question"
],
"execution_count": 57,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "57Zz1xciLbxK",
"outputId": "001628a8-e1fa-46be-defb-d9801d10e94a"
},
"source": [
"#Test the Function\r\n",
"best_gain, best_question = find_best_split(training_data)\r\n",
"best_question"
],
"execution_count": 58,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Is diameter >= 2 ?"
]
},
"metadata": {
"tags": []
},
"execution_count": 58
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "KzCJKJLWLo-g"
},
"source": [
"# Create two classes, Leaf and Decision_Node. A Leaf classifies data; a Decision_Node asks a question.\r\n",
"\r\n",
"# Leaf class\r\n",
"class Leaf:\r\n",
"    def __init__(self, rows):\r\n",
"        self.predictions = class_counts(rows)\r\n",
"\r\n",
"# Decision_Node class\r\n",
"class Decision_Node:\r\n",
"    def __init__(self,\r\n",
"                 question,\r\n",
"                 true_branch,\r\n",
"                 false_branch):\r\n",
"        self.question = question\r\n",
"        self.true_branch = true_branch\r\n",
"        self.false_branch = false_branch"
],
"execution_count": 59,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "wm0Yn3bJMIro"
},
"source": [
"# Define a function to build the tree\r\n",
"def build_tree(rows):\r\n",
"\r\n",
"    # Try partitioning the dataset on each unique attribute value,\r\n",
"    # calculate the information gain,\r\n",
"    # and return the question that produces the highest gain.\r\n",
"    gain, question = find_best_split(rows)\r\n",
"\r\n",
"    # Base case: no further info gain\r\n",
"    # Since we can ask no further questions,\r\n",
"    # we'll return a leaf.\r\n",
"    if gain == 0:\r\n",
"        return Leaf(rows)\r\n",
"\r\n",
"    # If we reach here, we have found a useful feature / value\r\n",
"    # to partition on.\r\n",
"    true_rows, false_rows = partition(rows, question)\r\n",
"\r\n",
"    # Recursively build the true branch.\r\n",
"    true_branch = build_tree(true_rows)\r\n",
"\r\n",
"    # Recursively build the false branch.\r\n",
"    false_branch = build_tree(false_rows)\r\n",
"\r\n",
"    # Return a Decision_Node.\r\n",
"    # This records the best feature / value to ask at this point,\r\n",
"    # as well as the branches to follow\r\n",
"    # depending on the answer.\r\n",
"    return Decision_Node(question, true_branch, false_branch)"
],
"execution_count": 60,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1uCt6HnFMfZF"
},
"source": [
"# Define a function to print the tree\r\n",
"def print_tree(node, spacing=\"\"):\r\n",
"\r\n",
"    # Base case: we've reached a leaf\r\n",
"    if isinstance(node, Leaf):\r\n",
"        print (spacing + \"Predict\", node.predictions)\r\n",
"        return\r\n",
"\r\n",
"    # Print the question at this node\r\n",
"    print (spacing + str(node.question))\r\n",
"\r\n",
"    # Call this function recursively on the true branch\r\n",
"    print (spacing + '--> True:')\r\n",
"    print_tree(node.true_branch, spacing + \" \")\r\n",
"\r\n",
"    # Call this function recursively on the false branch\r\n",
"    print (spacing + '--> False:')\r\n",
"    print_tree(node.false_branch, spacing + \" \")"
],
"execution_count": 65,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "rD-AxWKCMq-Z"
},
"source": [
"my_tree = build_tree(training_data)"
],
"execution_count": 62,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nBBZ21UuM0JS",
"outputId": "a8f07c9e-427f-4439-f5c8-4d1bf5f2d531"
},
"source": [
"print_tree(my_tree)"
],
"execution_count": 66,
"outputs": [
{
"output_type": "stream",
"text": [
"Is diameter >= 2 ?\n",
"--> True:\n",
" Is diameter >= 3 ?\n",
" --> True:\n",
"  Is colors == Yellow ?\n",
"  --> True:\n",
"   Predict {'Apple': 1, 'Lemon': 1}\n",
"  --> False:\n",
"   Predict {'Apple': 1}\n",
" --> False:\n",
"  Predict {'Blueberry': 1}\n",
"--> False:\n",
" Predict {'Grape': 2}\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gXMALkx-NSdp",
"outputId": "29a5feef-0f52-4966-d16a-912d37244814"
},
"source": [
"# Classify a row by recursively following the tree until a leaf is reached\r\n",
"def classify(row, node):\r\n",
"\r\n",
"    # Base case: we've reached a leaf\r\n",
"    if isinstance(node, Leaf):\r\n",
"        return node.predictions\r\n",
"\r\n",
"    # Decide whether to follow the true-branch or the false-branch.\r\n",
"    # Compare the feature / value stored in the node,\r\n",
"    # to the example we're considering.\r\n",
"    if node.question.match(row):\r\n",
"        return classify(row, node.true_branch)\r\n",
"    else:\r\n",
"        return classify(row, node.false_branch)\r\n",
"\r\n",
"# Test\r\n",
"classify(training_data[0], my_tree)"
],
"execution_count": 67,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'Apple': 1}"
]
},
"metadata": {
"tags": []
},
"execution_count": 67
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "igWCuRQJNm8L",
"outputId": "fd8d4f61-9fc3-4231-fe16-b4265dce0925"
},
"source": [
"# Convert a leaf's label counts into percentage strings\r\n",
"def print_leaf(counts):\r\n",
"    total = sum(counts.values()) * 1.0\r\n",
"    probs = {}\r\n",
"    for lbl in counts.keys():\r\n",
"        probs[lbl] = str(int(counts[lbl] / total * 100)) + \"%\"\r\n",
"    return probs\r\n",
"\r\n",
"# Test 1\r\n",
"print_leaf(classify(training_data[0], my_tree))"
],
"execution_count": 69,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'Apple': '100%'}"
]
},
"metadata": {
"tags": []
},
"execution_count": 69
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UVFh9XtjN-Uy",
"outputId": "d0d0f5d7-9a0b-4418-809b-26067b320895"
},
"source": [
"# Test 2\r\n",
"print_leaf(classify(training_data[1], my_tree))"
],
"execution_count": 70,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'Apple': '50%', 'Lemon': '50%'}"
]
},
"metadata": {
"tags": []
},
"execution_count": 70
}
]
},
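{
"cell_type": "markdown",
"metadata": {},
"source": [
"The leaves above return label counts rather than a single answer. A minimal sketch of turning such counts into one predicted label follows; `most_likely_label` is a hypothetical helper that is not part of the original notebook, and it assumes the input is a dictionary of label -> count like the one `classify` returns."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical helper for illustration: choose the most frequent label in a leaf\r\n",
"def most_likely_label(counts):\r\n",
"    # counts is a dict of label -> count\r\n",
"    return max(counts, key=counts.get)\r\n",
"\r\n",
"print(most_likely_label({'Grape': 2}))  # Grape\r\n",
"# Ties (e.g. {'Apple': 1, 'Lemon': 1}) are resolved by dict iteration order"
],
"execution_count": null,
"outputs": []
},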
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Tpq-btm8OJfB",
"outputId": "7dd804ea-a953-4f88-dccd-f890a378ce0a"
},
"source": [
"# Testing the model\r\n",
"testing_data = [\r\n",
"    ['Green', 3, 'Apple'],\r\n",
"    ['Yellow', 4, 'Apple'],\r\n",
"    ['Red', 2, 'Grape'],\r\n",
"    ['Red', 1, 'Grape'],\r\n",
"    ['Yellow', 3, 'Lemon'],\r\n",
"]\r\n",
"\r\n",
"for row in testing_data:\r\n",
"    print (\"Actual: %s. Predicted: %s\" %\r\n",
"           (row[-1], print_leaf(classify(row, my_tree))))"
],
"execution_count": 71,
"outputs": [
{
"output_type": "stream",
"text": [
"Actual: Apple. Predicted: {'Apple': '100%'}\n",
"Actual: Apple. Predicted: {'Apple': '50%', 'Lemon': '50%'}\n",
"Actual: Grape. Predicted: {'Blueberry': '100%'}\n",
"Actual: Grape. Predicted: {'Grape': '100%'}\n",
"Actual: Lemon. Predicted: {'Apple': '50%', 'Lemon': '50%'}\n"
],
"name": "stdout"
}
]
}
]
}