milk_tree_learner
# to come
# -*- coding: utf-8 -*-
# <nbformat>3.0</nbformat>
# <codecell>
import numpy as np
import milk
from pylab import plot, grid, imshow, cm  # plotting names; the original notebook relied on %pylab inline
# <markdowncell>
# Simulate some labeled data in a two-dimensional feature space.
# <codecell>
features = np.random.randn(100,2)  # 2d array of features: 100 examples of 2 features each
labels = np.empty(100)
for i in range(100):
    # the label is (almost) a deterministic function of the features:
    # a two-split decision rule, with a 0.1% chance of flipping the label
    if features[i,0] < 0:
        if features[i,1] < -1:
            labels[i] = np.random.rand() < .001
        else:
            labels[i] = np.random.rand() < .999
    else:
        if features[i,1] < 1:
            labels[i] = np.random.rand() < .001
        else:
            labels[i] = np.random.rand() < .999
# <markdowncell>
# What is the decision tree for this data?
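# <markdowncell>
# For reference, here is the noise-free generating rule written out as the two-split
# tree we hope the learner will recover (a small helper added here for illustration,
# not part of the original gist):
# <codecell>
def true_label(x):
    # label is 1 when x[0] < 0 and x[1] >= -1, or when x[0] >= 0 and x[1] >= 1
    if x[0] < 0:
        return x[1] >= -1
    else:
        return x[1] >= 1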
# <markdowncell>
# Since the data is two-dimensional, we can take a look at it easily.
# <codecell>
plot(features[labels==True,0], features[labels==True,1], 'kx', mec='k', ms=6, mew=3)  # positive examples as black crosses
plot(features[labels==False,0], features[labels==False,1], 'wo', mec='k', ms=8, mew=1)  # negative examples as white circles
grid()
# <markdowncell>
# Fitting the model is easy:
# <codecell>
learner = milk.supervised.tree_learner()
model = learner.train(features, labels)
# <markdowncell>
# Using it is easy, too:
# <codecell>
model.apply([-1,1])
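# <markdowncell>
# The model can also be applied to every training example, e.g. to check the training
# accuracy (a quick sanity check added here, not part of the original gist; it assumes
# `apply` returns the predicted 0/1 label, as the boundary plot below also does):
# <codecell>
predictions = np.array([model.apply(f) for f in features])
print 'training accuracy:', np.mean(predictions == labels)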
# <markdowncell>
# Visualizing the decision boundary is a bit of a pain...
# <codecell>
x_range = np.linspace(-3,3,100)
y_range = np.linspace(-3,3,100)
val = np.zeros((len(x_range), len(y_range)))
for i, x_i in enumerate(x_range):
    for j, y_j in enumerate(y_range):
        val[i,j] = model.apply([x_i,y_j])  # evaluate the model on a grid of points
imshow(val[::1,::-1].T, extent=[x_range[0],x_range[-1],y_range[0],y_range[-1]], cmap=cm.Greys)  # flip and transpose so the image orientation matches the axes
plot(features[labels==True,0], features[labels==True,1], 'kx', mec='w', ms=7, mew=5)  # white outline behind the crosses
plot(features[labels==True,0], features[labels==True,1], 'kx', mec='k', ms=5, mew=3)
plot(features[labels==False,0], features[labels==False,1], 'wo', mec='k', ms=8, mew=1)
grid()
# <markdowncell>
# And can we have a picture of the decision tree itself? It is hidden in the model instance somewhere...
# <codecell>
model.tree
# <codecell>
model.tree.featid, model.tree.featval, model.tree.left, model.tree.right
# <codecell>
def describe_tree(node, prefix=''):
    """Print the decision tree as nested if/else statements."""
    print prefix + 'if x[%d] < %.2f:' % (node.featid, node.featval)
    if isinstance(node.left, milk.supervised.tree.Node):
        describe_tree(node.left, prefix+'    ')
    else:
        print prefix+'    ', node.left
    print prefix + 'else:'
    if isinstance(node.right, milk.supervised.tree.Node):
        describe_tree(node.right, prefix+'    ')
    else:
        print prefix+'    ', node.right
# <codecell>
describe_tree(model.tree)
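# <markdowncell>
# As a rough measure of how complicated the learned tree is, we can count its leaves
# (a small sketch added here, not part of the original gist; like `describe_tree`, it
# treats anything that is not a `milk.supervised.tree.Node` as a leaf):
# <codecell>
def count_leaves(node):
    # recurse until we hit something that is not an internal Node
    if not isinstance(node, milk.supervised.tree.Node):
        return 1
    return count_leaves(node.left) + count_leaves(node.right)
count_leaves(model.tree)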
# <markdowncell>
# Not as simple as it seemed in the lecture, huh?