Skip to content

Instantly share code, notes, and snippets.

@rishisidhu
Created February 21, 2020 07:44
Show Gist options
  • Save rishisidhu/ceed9bca9cabaac267e2272b5485c8cd to your computer and use it in GitHub Desktop.
Save rishisidhu/ceed9bca9cabaac267e2272b5485c8cd to your computer and use it in GitHub Desktop.
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
from graphviz import Source
# Train a decision tree on the Iris dataset, compare its predictions with the
# true labels on a small hand-picked hold-out set, and render the fitted tree.
iris = load_iris()
print("Feature Names - ", iris.feature_names, "\n")
print(iris.target)

# Print rows 0, 50 and 100, i.e. one example for each flower type.
print("\nSetosa flower 1 - ", iris.data[0])
print("Versicolor flower 1 - ", iris.data[50])
print("Virginica flower 1 - ", iris.data[100], "\n")

# Rows held out for testing: two rows starting at each of the offsets
# printed above (0, 50, 100).
test_indices = [0, 1, 50, 51, 100, 101]

# Training data: everything except the held-out rows.
# axis=0 removes whole rows from the 2-D feature matrix.
train_target = np.delete(iris.target, test_indices)
train_data = np.delete(iris.data, test_indices, axis=0)

# Testing data: only the held-out rows.
test_target = iris.target[test_indices]
test_data = iris.data[test_indices]

# Build the classifier.
# NOTE(review): no random_state is passed, so the fitted tree may differ
# between runs -- confirm against the scikit-learn docs if reproducibility
# matters.
dtClassifier = tree.DecisionTreeClassifier()

# Train the classifier.
dtClassifier.fit(train_data, train_target)

# Print the actual label of each test point.
print("\n********** Actual **************")
for row_index, actual_class in zip(test_indices, test_target):
    print("Test Row ", row_index, " belongs to the class ", actual_class)

predicted_target = dtClassifier.predict(test_data)

# Print the predicted label of each test point.
print("\n********** Predicted **************")
for row_index, predicted_class in zip(test_indices, predicted_target):
    print(
        "Test Row ",
        row_index,
        " is predicted to be of the class ",
        predicted_class,
    )

# Visualize the decision tree: export_graphviz is called with out_file=None
# and its result is handed straight to graphviz.Source.
graph = Source(
    tree.export_graphviz(
        dtClassifier,
        out_file=None,
        feature_names=iris.feature_names,
        class_names=iris.target_names,
        filled=True,
        rounded=True,
        node_ids=True,
        special_characters=True,
    )
)
graph.format = 'png'
# Render under the base name 'dtree_render'; view=True asks graphviz to open
# the result in a viewer.
graph.render('dtree_render', view=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment