Create a gist now

Instantly share code, notes, and snippets.

milk_tree_learner
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
# -*- coding: utf-8 -*-
# <nbformat>3.0</nbformat>
# <codecell>
import numpy as np
import milk
# <markdowncell>
# Simulate some labeled data in a two-dimensional feature space.
# <codecell>
# Simulate 100 labeled examples in a 2-D feature space.
# The labelling is a noisy, XOR-like rule on the two features: each region
# gets an (almost) deterministic label, flipped with probability ~0.001.
# NOTE(review): the scraped source had its indentation stripped; the loop
# structure below is reconstructed from the if/elif nesting of the original.
features = np.random.randn(100, 2)  # 100 examples of 2 features each
labels = np.empty(100)
for i in range(100):
    if features[i, 0] < 0:
        if features[i, 1] < -1:
            labels[i] = np.random.rand() < .001   # almost surely 0
        else:
            labels[i] = np.random.rand() < .999   # almost surely 1
    else:
        if features[i, 1] < 1:
            labels[i] = np.random.rand() < .001   # almost surely 0
        else:
            labels[i] = np.random.rand() < .999   # almost surely 1
# <markdowncell>
# What is the decision tree for this data?
# <markdowncell>
# Since the data is two-dimensional, we can take a look at it easily.
# <codecell>
# Scatter-plot the two classes: black crosses for label==True, white
# circles (black edge) for label==False.
# NOTE(review): `plot` and `grid` are not imported in this file -- they are
# presumably injected by the notebook's %pylab mode; confirm before running
# this as a plain script.
plot(features[labels==True,0], features[labels==True,1], 'kx', mec='k', ms=6, mew=3)
plot(features[labels==False,0], features[labels==False,1], 'wo', mec='k', ms=8, mew=1)
grid()
# <markdowncell>
# Fitting the model is easy:
# <codecell>
# Fit a decision-tree classifier (from the `milk` library) to the
# simulated features/labels built above.
learner = milk.supervised.tree_learner()
model = learner.train(features, labels)
# <markdowncell>
# Using it is easy, too:
# <codecell>
model.apply([-1,1])
# <markdowncell>
# Visualizing the decision boundary is a bit of a pain...
# <codecell>
# Visualize the decision boundary: evaluate the model on a 100x100 grid
# over [-3, 3]^2 and render the predictions as a grayscale image, with the
# training points overlaid.
# NOTE(review): the scraped source had its indentation stripped; the nested
# grid loop below is reconstructed. `imshow`, `plot`, `grid` and `cm` are
# presumably pylab names injected by the notebook environment -- confirm
# before running this as a plain script.
x_range = np.linspace(-3, 3, 100)
y_range = np.linspace(-3, 3, 100)
val = np.zeros((len(x_range), len(y_range)))
for i, x_i in enumerate(x_range):
    for j, y_j in enumerate(y_range):
        val[i, j] = model.apply([x_i, y_j])
imshow(val[::1,::-1].T, extent=[x_range[0],x_range[-1],y_range[0],y_range[-1]], cmap=cm.Greys)
plot(features[labels==True,0], features[labels==True,1], 'kx', mec='w', ms=7, mew=5)
plot(features[labels==True,0], features[labels==True,1], 'kx', mec='k', ms=5, mew=3)
plot(features[labels==False,0], features[labels==False,1], 'wo', mec='k', ms=8, mew=1)
grid()
# <markdowncell>
# And can we have a picture of the decision tree itself? It is hidden in the model instance somewhere...
# <codecell>
# The fitted tree structure is stored on the model object.
model.tree
# <codecell>
# Root-node internals: split feature index, split threshold, and the
# left/right children (each either a subtree node or a leaf label).
model.tree.featid, model.tree.featval, model.tree.left, model.tree.right
# <codecell>
def describe_tree(node, prefix=''):
    """Print a decision tree as nested if/else pseudo-code.

    Parameters
    ----------
    node : tree node
        An object exposing ``featid``, ``featval``, ``left`` and ``right``.
        Children that also expose ``featid`` are recursed into; anything
        else is treated as a leaf and printed as-is.
    prefix : str
        Indentation accumulated along the recursion (two spaces per level).
    """
    # print() with a single parenthesized argument is valid on both
    # Python 2 and 3; the original used Python-2-only print statements.
    print(prefix + 'if x[%d] < %.2f:' % (node.featid, node.featval))
    # Duck-type the children instead of isinstance() against the
    # library-internal milk.supervised.tree.Node class: a subtree has a
    # 'featid' attribute, a plain leaf label does not.
    if hasattr(node.left, 'featid'):
        describe_tree(node.left, prefix + '  ')
    else:
        print(prefix + '  ' + str(node.left))
    print(prefix + 'else:')
    if hasattr(node.right, 'featid'):
        describe_tree(node.right, prefix + '  ')
    else:
        print(prefix + '  ' + str(node.right))
# <codecell>
describe_tree(model.tree)
# <markdowncell>
# Not as simple as it seemed in the lecture, huh?
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment