Learning Curves for trees of different depths
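This gist fits DecisionTreeRegressor models of depth 2, 4, and 8 to a noisy sine curve and, for each depth, plots the fitted curve alongside its learning curve (training and test error as the number of training examples grows).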
import numpy as np
from collections import namedtuple
# Create a random dataset
rng = np.random.RandomState(42)
N_points = 100
X = np.sort(5 * rng.rand(N_points, 1), axis=0)
y = np.sin(X).ravel()
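# add uniform noise in (-0.2, 0.2] to the sine targets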
y += .4 * (0.5 - rng.rand(N_points))
# sklearn imports
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split  # sklearn.cross_validation was removed in 0.20
# plot models
max_depths = [2, 4, 8]
X_plot = np.arange(0.0, 5.0, 0.005)[:, np.newaxis]
y_plots = [DecisionTreeRegressor(max_depth=max_depth).fit(X, y).predict(X_plot) for max_depth in max_depths]
# plot training data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
training_points = (np.array([.1, .2, .4, .6, .8]) * N_points).astype(int)  # slice indices must be ints
# shuffle training data
data = list(zip(X_train.T[0], y_train))  # zip is lazy in Python 3; materialize before shuffling
np.random.shuffle(data)
X_train = np.array([[x[0] for x in data]]).T
y_train = np.array([x[1] for x in data])
# compute learning curves
LearningCurveData = namedtuple('LearningCurveData', ['n', 'training_err', 'test_err'])
def get_learning_curve_data(clf, n):
    """Fit clf on the first n shuffled training points and return the
    standard deviation of the residuals on that subset and on the test set."""
    clf = clf.fit(X_train[:n], y_train[:n])
    y_pred = clf.predict(X_test)
    y_train_pred = clf.predict(X_train[:n])
    return LearningCurveData(n, (y_train_pred - y_train[:n]).std(), (y_pred - y_test).std())
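# Note: the std of the residuals is the error metric here; for residuals
# with mean near zero it is close to the RMSE.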
learning_curve_data = [
    [get_learning_curve_data(DecisionTreeRegressor(max_depth=max_depth), n) for n in training_points]
    for max_depth in max_depths]
# Plot the results
import matplotlib.pyplot as pl  # pylab is deprecated; pyplot offers the same interface here
pl.figure()
colors = ["r", "g", "b"]
for y_plot, learning_curve_datum, max_depth, k, color in zip(y_plots, learning_curve_data, max_depths, range(3), colors):
    # left column: fitted curve over the training data
    pl.subplot(3, 2, 1 + 2 * k)
    pl.scatter(X, y, c="k", label="data")
    pl.plot(X_plot, y_plot, color=color, linewidth=2)
    pl.title('max_depth = %s' % max_depth)
    if k == 2:
        pl.xlabel('X')
        pl.ylabel('y')
    # right column: learning curve (dashed = training error, solid = test error)
    pl.subplot(3, 2, 2 + 2 * k)
    ns = [d.n for d in learning_curve_datum]
    training_errs = [d.training_err for d in learning_curve_datum]
    test_errs = [d.test_err for d in learning_curve_datum]
    pl.plot(ns, training_errs, '--' + color, linewidth=2)
    pl.plot(ns, test_errs, color, linewidth=2)
    pl.title('max_depth = %s' % max_depth)
    if k == 2:
        pl.xlabel('training examples')
        pl.ylabel('error')
pl.show()
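For comparison, newer scikit-learn versions ship a built-in helper that computes cross-validated learning curves directly. The following is a minimal sketch, not part of the original gist: it assumes sklearn >= 0.18 (where learning_curve lives in sklearn.model_selection) and uses mean-squared-error scoring rather than the residual-std metric above, shown here for a single depth.

from sklearn.model_selection import learning_curve

sizes, train_scores, test_scores = learning_curve(
    DecisionTreeRegressor(max_depth=4), X, y,
    train_sizes=[0.1, 0.2, 0.4, 0.6, 0.8], cv=5,
    scoring='neg_mean_squared_error')
pl.figure()
# scores are negative MSE; negate to plot an error that decreases as fit improves
pl.plot(sizes, -train_scores.mean(axis=1), '--g', linewidth=2, label='training error')
pl.plot(sizes, -test_scores.mean(axis=1), 'g', linewidth=2, label='test error')
pl.title('max_depth = 4 (cross-validated)')
pl.xlabel('training examples')
pl.ylabel('MSE')
pl.legend()
pl.show()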