Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
from sklearn.cross_validation import train_test_split
from sklearn import datasets
import matplotlib.pyplot as plt
import myplot as plt2
def main():
    """Train a decision tree on two iris features, plot it and export it.

    Loads the iris data set and keeps feature columns 2 and 3 of
    ``iris.data``. It holds out 30 % of the samples as a test set and fits an
    entropy-based decision tree of depth 3 on the rest. It then plots the
    tree's decision regions with the held-out samples highlighted, and writes
    the fitted tree to ``tree.dot`` in Graphviz format.

    NOTE(review): the file-level import ``sklearn.cross_validation`` was
    removed in scikit-learn 0.20; on current versions ``train_test_split``
    must be imported from ``sklearn.model_selection`` instead.
    """
    feature_columns = [2, 3]

    iris = datasets.load_iris()
    X = iris.data[:, feature_columns]
    y = iris.target
    # Names of the columns actually used, in the same order as X's columns.
    feature_names = [iris.feature_names[i] for i in feature_columns]

    # Split a list of row indices alongside the data so we know which rows of
    # X landed in the test set; adding this array does not change the split.
    (x_train, x_test, y_train, y_test,
     idx_train, idx_test) = train_test_split(
        X, y, list(range(len(y))), test_size=0.3, random_state=0)

    tree_clf = DecisionTreeClassifier(
        criterion="entropy", max_depth=3, random_state=0)
    tree_clf.fit(x_train, y_train)

    # Highlight exactly the held-out rows of X.
    # NOTE(review): assumes myplot.plot_decision_regions marks X[test_idx]
    # as test samples - confirm against the myplot module.
    plt2.plot_decision_regions(X, y, classifier=tree_clf, test_idx=idx_test)
    plt.xlabel(feature_names[0])
    plt.ylabel(feature_names[1])
    plt.legend(loc="upper left")

    # The tree was trained on the selected columns only, so it must be given
    # their names; passing all of iris.feature_names would mislabel splits.
    export_graphviz(
        tree_clf,
        out_file="tree.dot",
        feature_names=feature_names,
        class_names=iris.target_names,
        filled=True,
        rounded=True,
    )
    # Render the exported tree with Graphviz: dot -Tpng tree.dot -o tree.png

    # Show the figure last so tree.dot is already written while the window is open.
    plt.show()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment