Skip to content

Instantly share code, notes, and snippets.

@malleor
Last active August 29, 2015 14:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save malleor/52b9f1b71fc02037060c to your computer and use it in GitHub Desktop.
Save malleor/52b9f1b71fc02037060c to your computer and use it in GitHub Desktop.
Playing with decision trees from scikit-learn
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from matplotlib import pyplot as plt
# Fetch the Iris dataset and group its sample rows by class label.
iris = load_iris()
# classes: {label (int) -> array of the iris.data rows carrying that label}.
# Labels are read from iris.target rather than hard-coded as (0, 1, 2), and a
# boolean mask selects each class's rows in one vectorized step.
classes = {
    int(label): iris.data[iris.target == label]
    for label in np.unique(iris.target)
}
f1, f2 = 0, 3  # column indices of the two selected features


def form_feature(v):
    """Return the pair ``(v[f1], v[f2])`` of selected features from sample row *v*."""
    return (v[f1], v[f2])
def randomize(seq):
    """Shuffle *seq* in place along its first axis and return that same object.

    Relies on NumPy's global random state (``np.random.shuffle``), so the
    resulting order can be made reproducible with ``np.random.seed``.
    """
    shuffle = np.random.shuffle
    shuffle(seq)
    return seq
def form_features(label, values):
    """Build a shuffled ``(features, labels)`` training pair for one class.

    Each row of *values* is reduced to its selected features with
    ``form_feature``, and every row is tagged with the same *label*. The rows
    are shuffled with ``randomize`` while features and labels stay aligned.

    Returns:
        ``(pX, py)``: ``pX`` is a float array of shape ``(len(values), 2)``
        and ``py`` is a float array of shape ``(len(values),)``.
    """
    feats = np.array([form_feature(v) for v in values], dtype=float).reshape(-1, 2)
    # Keep each label in the same row as its features so that a single row
    # shuffle permutes samples instead of mixing features with labels.
    samples = randomize(np.column_stack((feats, np.full(len(feats), label, dtype=float))))
    return samples[:, :2], samples[:, 2]


# Assemble the training set: X holds the selected features, y the class labels.
X = np.zeros((0, 2), dtype=float)
y = np.zeros((0,), dtype=float)
for label, values in classes.items():  # .items() works on Python 2 and 3
    pX, py = form_features(label, values)
    X = np.append(X, pX, axis=0)
    y = np.append(y, py, axis=0)
# Scatter-plot every class in the two selected feature dimensions.
plt.figure('iris')
colors = ['r', 'b', 'g']
for i, points in enumerate(classes.values()):  # .values() works on Python 2 and 3
    # Cycle through the palette so more than three classes do not raise IndexError.
    plt.scatter(points[:, f1], points[:, f2], c=colors[i % len(colors)])
plt.show(block=False)
# TODO: learn classes -- fit the imported DecisionTreeClassifier on X, y
# TODO: test (evaluate) the trained classifier
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment