@allisontharp
Last active April 22, 2016 00:22
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 21 19:11:44 2016
@author: Allison
"""
import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
# Inspect the data set (print() calls so the script also runs under Python 3)
print("feature names: " + str(iris.feature_names))
print("target names: " + str(iris.target_names))
print("first data measurements: " + str(iris.data[0]))
print("first data label: " + str(iris.target[0]))
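# Hold out one flower of each species as a test set. load_iris() is sorted by
# species (50 rows each), so rows 0, 50 and 100 are one of each type.
test_idx = [0, 50, 100]
train_target = np.delete(iris.target, test_idx)      # training labels (without the test rows)
train_data = np.delete(iris.data, test_idx, axis=0)  # training measurements (without the test rows)
test_target = iris.target[test_idx]
test_data = iris.data[test_idx]
# Train a decision tree on the remaining 147 samples
clf = tree.DecisionTreeClassifier()
clf.fit(train_data, train_target)
# Predict the held-out samples and compare with their true labels
print("test labels: " + str(test_target))
print("predicted labels: " + str(clf.predict(test_data)))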
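# Visualize the tree: export it in Graphviz DOT format, then render a PDF.
# sklearn.externals.six is gone from recent scikit-learn releases; with
# out_file=None, export_graphviz returns the DOT source as a string instead.
# Rendering with pydotplus also requires the Graphviz binaries to be installed.
import pydotplus
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("iris.pdf")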
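In the script, `test_idx = [0, 50, 100]` yields a one-per-species test set only because `load_iris()` stores its 150 rows sorted by species, 50 each. `np.delete` builds the training set without those rows, while fancy indexing (`iris.data[test_idx]`) pulls them out as the test set. A minimal plain-Python sketch of the same bookkeeping follows; `holdout_split` is a hypothetical helper, not part of NumPy or scikit-learn, and the rows are toy values rather than the real data set.

```python
def holdout_split(data, labels, test_idx):
    """Hypothetical helper: return (train_data, train_labels, test_data, test_labels).

    The training part mirrors np.delete(..., axis=0); the test part mirrors
    fancy indexing such as iris.data[test_idx].
    """
    held_out = set(test_idx)
    train_data = [row for i, row in enumerate(data) if i not in held_out]
    train_labels = [y for i, y in enumerate(labels) if i not in held_out]
    # Test rows come back in the order given by test_idx.
    test_data = [data[i] for i in test_idx]
    test_labels = [labels[i] for i in test_idx]
    return train_data, train_labels, test_data, test_labels


if __name__ == "__main__":
    # Toy stand-in for Iris: two rows per class, sorted by class like load_iris().
    data = [[5.1, 0.2], [4.9, 0.2], [7.0, 1.4], [6.4, 1.5], [6.3, 2.5], [5.8, 1.9]]
    labels = [0, 0, 1, 1, 2, 2]
    train_x, train_y, test_x, test_y = holdout_split(data, labels, [0, 2, 4])
    print(train_y, test_y)  # [0, 1, 2] [0, 1, 2]
```

Holding out only three samples keeps the example small; a larger random split would give a more meaningful estimate of accuracy.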
@allisontharp

This code is an introduction to machine learning. In it, we learn how to explore the Iris data set, train a decision tree classifier, predict the labels of held-out samples, and visualize the resulting tree. For more information, visit my blog: www.techtrek.io
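To give a feel for what "train a classifier" and "predict an outcome" mean for a decision tree: `DecisionTreeClassifier` picks, at each node, the feature/threshold test that most reduces impurity (Gini impurity is its default `criterion`), keeps splitting recursively (by default until the leaves are pure), and `predict` routes each new sample down those tests to a leaf. The sketch below illustrates the idea with a single split (a "stump") in plain Python. The helper names and toy measurements are hypothetical, candidate thresholds are assumed to be midpoints between consecutive values, and scikit-learn's real implementation is considerably more involved.

```python
from collections import Counter


def gini(labels):
    """Gini impurity 1 - sum(p_k ** 2), where p_k is the share of class k."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def majority(labels):
    """Most common label (ties go to the label seen first)."""
    return Counter(labels).most_common(1)[0][0]


def partition(rows, labels, feature, threshold):
    """Split labels by the test row[feature] <= threshold."""
    left = [y for row, y in zip(rows, labels) if row[feature] <= threshold]
    right = [y for row, y in zip(rows, labels) if row[feature] > threshold]
    return left, right


def best_split(rows, labels):
    """Return (feature, threshold, weighted_gini) for the lowest-impurity test.

    Candidate thresholds are midpoints between consecutive distinct values of
    each feature. Returns None if no feature takes more than one value.
    """
    best = None
    for feature in range(len(rows[0])):
        values = sorted({row[feature] for row in rows})
        for lo, hi in zip(values, values[1:]):
            threshold = (lo + hi) / 2
            left, right = partition(rows, labels, feature, threshold)
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
            if best is None or score < best[2]:
                best = (feature, threshold, score)
    return best


def fit_stump(rows, labels):
    """Fit a one-split 'tree' whose two leaves predict their majority class."""
    split = best_split(rows, labels)
    if split is None:
        raise ValueError("no feature varies, so there is nothing to split on")
    feature, threshold, _ = split
    left, right = partition(rows, labels, feature, threshold)
    return {"feature": feature, "threshold": threshold,
            "left": majority(left), "right": majority(right)}


def predict(stump, row):
    """Route one sample to a leaf by applying the stored threshold test."""
    if row[stump["feature"]] <= stump["threshold"]:
        return stump["left"]
    return stump["right"]


if __name__ == "__main__":
    # Toy measurements [petal length, petal width]; 0 = setosa, 1 = versicolor.
    rows = [[1.4, 0.2], [1.3, 0.2], [4.7, 1.4], [4.5, 1.5]]
    labels = [0, 0, 1, 1]
    stump = fit_stump(rows, labels)
    print(predict(stump, [1.5, 0.3]), predict(stump, [4.0, 1.3]))  # 0 1
```

A full tree simply repeats `best_split` on each side until the leaves are pure, which is what the PDF produced by the script visualizes: each box shows one threshold test and the Gini impurity at that node.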
