# Gist: tianhuil/8211581 -- Variance-Bias Plots
# Last active: January 1, 2016 22:39
# Save tianhuil/8211581 to your computer and use it in GitHub Desktop.
# Note from the hosting page: this file contains bidirectional Unicode text that
# may be interpreted or compiled differently than what appears below. To review,
# open the file in an editor that reveals hidden Unicode characters.
from collections import namedtuple

import numpy as np

# Create a random dataset: noisy samples of sin(x) for x in [0, 5).
# The seeded RandomState makes the generated data identical on every run.
rng = np.random.RandomState(42)
N_points = 10000
# Column vector of x-values, sorted ascending, shape (N_points, 1).
X = np.sort(5 * rng.rand(N_points, 1), axis=0)
# Noise-free targets, flattened to shape (N_points,).
y = np.sin(X).ravel()
# rng.rand() is uniform on [0, 1), so this adds noise in (-0.2, 0.2].
y += .4 * (0.5 - rng.rand(N_points))
# sklearn imports
from sklearn.tree import DecisionTreeRegressor
try:
    # Current home of train_test_split (scikit-learn >= 0.18).
    from sklearn.model_selection import train_test_split
except ImportError:
    # Older scikit-learn releases only ship the since-removed module.
    from sklearn.cross_validation import train_test_split

# split training / test data: hold out 20% of the samples, with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
# Fractions of N_points expressed as sample counts (float array);
# not referenced by the rest of the code visible in this script.
training_points = np.array([.1, .2, .4, .6, .8]) * N_points
# Tree depths to evaluate: 1 through 14.
max_depths = range(1, 15)
# shuffle training data
# Draw one permutation and index both arrays with it so every x stays paired
# with its y. Shuffling a zip() object in place fails on Python 3, where zip()
# returns an iterator; drawing from the seeded `rng` also keeps the run
# reproducible like the rest of the script.
shuffled_idx = rng.permutation(len(y_train))
X_train = X_train[shuffled_idx]
y_train = y_train[shuffled_idx]
# compute variance-bias errors
Errors = namedtuple('Errors', ['max_depth', 'training_err', 'test_err'])


def get_errors(clf):
    """Fit ``clf`` on the training split and measure its residual spread.

    Uses the module-level ``X_train``/``y_train`` for fitting and
    ``X_test``/``y_test`` for evaluation.

    Args:
        clf: Estimator exposing ``fit(X, y)``, ``predict(X)`` and a
            ``max_depth`` attribute.

    Returns:
        Errors: ``max_depth`` of the estimator, followed by the standard
        deviation of the residuals on the training and on the test split.
    """
    # Take the depth from the estimator itself; the name `n` that used to be
    # passed here was never defined and raised NameError on the first call.
    max_depth = clf.max_depth
    clf = clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    y_train_pred = clf.predict(X_train)
    return Errors(max_depth, (y_train_pred - y_train).std(), (y_pred - y_test).std())
# One Errors record per candidate depth, in the same order as max_depths.
errors = [get_errors(DecisionTreeRegressor(max_depth=max_depth)) for max_depth in max_depths]
# Plot the results
import pylab as pl

pl.figure()
colors = ["r", "g", "b"]
# Both curves share one colour and differ only by line style. `color` was
# referenced below without ever being defined, so bind it to the first entry.
color = colors[0]
training_errs = [e.training_err for e in errors]
test_errs = [e.test_err for e in errors]
# Dashed line: error on the training split; solid line: error on the test split.
pl.plot(max_depths, training_errs, '--' + color, label='training', linewidth=2)
pl.plot(max_depths, test_errs, color, label='test', linewidth=2)
pl.xlabel('max_depths')
pl.ylabel('error')
pl.title('variance-bias')
# The label= arguments above only become visible once a legend is drawn.
pl.legend()
pl.show()
# (Hosting-page footer, not part of the script:)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.