Last active
April 22, 2016 00:22
-
-
Save allisontharp/2377cd94fd6ea660cb04b1b2b1204ad5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 21 19:11:44 2016
@author: Allison

Introductory decision-tree example on the Iris data set.

Steps:
1. Print the feature names, the class names and the first sample.
2. Hold out the samples at indices 0, 50 and 100 (one of each type).
3. Train a DecisionTreeClassifier on the remaining samples.
4. Predict the held-out samples and print them next to the true labels.
5. Render the fitted tree to ``iris.pdf`` via Graphviz/pydotplus.
"""
import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree
import pydotplus

iris = load_iris()
print("feature names: " + str(iris.feature_names))
print("target names: " + str(iris.target_names))
print("first data measurements: " + str(iris.data[0]))
print("first data label: " + str(iris.target[0]))

test_idx = [0, 50, 100]  # index of measurements to remove (one of each type)

# Training set = everything except the held-out indices.
train_target = np.delete(iris.target, test_idx)  # training labels (w/o testing)
train_data = np.delete(iris.data, test_idx, axis=0)  # training data (w/o testing)

# Test set = exactly the held-out indices (numpy fancy indexing with a list).
test_target = iris.target[test_idx]
test_data = iris.data[test_idx]

# random_state pins the tie-breaking between equally good splits, so the
# trained tree (and the rendered PDF) is the same on every run.
clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(train_data, train_target)

# Predict the held-out samples. Without this step the test split would go unused.
print("test labels: " + str(test_target))
print("predicted labels: " + str(clf.predict(test_data)))

# visualize tree
# With out_file=None, export_graphviz returns the DOT source as a string.
# This replaces the old sklearn.externals.six.StringIO buffer; that module
# was removed in scikit-learn 0.23.
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    special_characters=True,
)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("iris.pdf")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This code is an introduction to machine learning. In it, we learn how to navigate the Iris data set, train a classifier, predict an outcome, and visualize the resulting decision tree. For more information, visit my blog: www.techtrek.io