Skip to content

Instantly share code, notes, and snippets.

# aflaxman/milk_tree_learner.ipynb Created Nov 28, 2012

milk_tree_learner
 # to come Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
 # -*- coding: utf-8 -*-
# Notebook export (nbformat 3.0): fit and inspect a decision tree with `milk`.
#
# NOTE(review): `plot`, `grid`, `imshow` and `cm` are never imported here; they
# are presumably injected by IPython's pylab mode, under which this notebook
# appears to have been run -- confirm before running it as a plain script.

import numpy as np
import milk

#
# Simulate some labeled data in a two-dimensional feature space.
#
n_samples = 100
features = np.random.randn(n_samples, 2)  # 2d array: n_samples examples of 2 features each

# A point gets label 1 with probability .999 in the region
#     (x0 < 0 and x1 >= -1)  or  (x0 >= 0 and x1 >= 1)
# and with probability .001 everywhere else (a little label noise).
x0, x1 = features[:, 0], features[:, 1]
likely_true = np.where(x0 < 0, x1 >= -1, x1 >= 1)
p_true = np.where(likely_true, .999, .001)
# Stored as floats 0./1., as the original np.empty-based labels were.
labels = (np.random.rand(n_samples) < p_true).astype(float)

is_pos = labels == 1
is_neg = labels == 0

#
# What is the decision tree for this data?
#
# Since the data is two-dimensional, we can take a look at it easily.
#
plot(features[is_pos, 0], features[is_pos, 1], 'kx', mec='k', ms=6, mew=3)
plot(features[is_neg, 0], features[is_neg, 1], 'wo', mec='k', ms=8, mew=1)
grid()

#
# Fitting the model is easy:
#
learner = milk.supervised.tree_learner()
model = learner.train(features, labels)

#
# Using it is easy, too:
#
model.apply([-1, 1])

#
# Visualizing the decision boundary is a bit of a pain...
#
x_range = np.linspace(-3, 3, 100)
y_range = np.linspace(-3, 3, 100)
# val[i, j] holds the model's prediction at grid point (x_range[i], y_range[j]).
val = np.zeros((len(x_range), len(y_range)))
for i, x_i in enumerate(x_range):
    for j, y_j in enumerate(y_range):
        val[i, j] = model.apply([x_i, y_j])

# Transpose so y runs along the rows, and reverse the y axis so the largest y
# lands in the top row (imshow's default origin is 'upper').
# `extent` is (left, right, bottom, top) in data coordinates, so it takes the
# first and last grid values, not the whole grid arrays.
imshow(val[:, ::-1].T,
       extent=[x_range[0], x_range[-1], y_range[0], y_range[-1]],
       cmap=cm.Greys)
# A wider white-edged cross is drawn under each black cross, presumably so the
# positives stay visible over dark regions of the decision surface.
plot(features[is_pos, 0], features[is_pos, 1], 'kx', mec='w', ms=7, mew=5)
plot(features[is_pos, 0], features[is_pos, 1], 'kx', mec='k', ms=5, mew=3)
plot(features[is_neg, 0], features[is_neg, 1], 'wo', mec='k', ms=8, mew=1)
grid()

#
# And can we have a picture of the decision tree itself? It is hidden in the model instance somewhere...
#
# Peek at the fitted tree's root node and its split attributes.
#
model.tree
model.tree.featid, model.tree.featval, model.tree.left, model.tree.right


def describe_tree(node, prefix=''):
    """Print a milk decision tree as nested if/else pseudo-code.

    Args:
        node: an internal tree node exposing ``featid``, ``featval``, ``left``
            and ``right`` (e.g. ``model.tree`` from a trained milk tree model).
        prefix: string prepended to every printed line; grows by one space
            per level of recursion.

    The left branch is printed under ``if x[featid] < featval:`` and the right
    branch under ``else:``. Children that are ``milk.supervised.tree.Node``
    instances are expanded recursively; any other child is printed as a leaf.
    """
    # Single-argument print() calls give identical output on Python 2 and 3.
    print('%sif x[%d] < %.2f:' % (prefix, node.featid, node.featval))
    _describe_branch(node.left, prefix)
    print(prefix + 'else:')
    _describe_branch(node.right, prefix)


def _describe_branch(child, prefix):
    """Print one side of a split: recurse into a subtree, or print a leaf."""
    # NOTE(review): the indent strings look whitespace-collapsed in this copy;
    # widen them if the printout is hard to read.
    if isinstance(child, milk.supervised.tree.Node):
        describe_tree(child, prefix + ' ')
    else:
        # Same output as Python 2's `print prefix + ' ', child`: both items
        # rendered with str() and joined by one space.
        print('%s %s' % (prefix + ' ', child))


describe_tree(model.tree)

#
# Not as simple as it seemed in the lecture, huh?
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.