Skip to content

Instantly share code, notes, and snippets.

@glemaitre
Created May 27, 2020 12:33
Show Gist options
  • Save glemaitre/4f56118e42018c1906fff7744a2d0fac to your computer and use it in GitHub Desktop.
Save glemaitre/4f56118e42018c1906fff7744a2d0fac to your computer and use it in GitHub Desktop.
@pytest.mark.parametrize("name, Tree", REG_TREES.items())
@pytest.mark.parametrize("criterion", REG_CRITERIONS)
def test_diabetes_overfit(name, Tree, criterion):
    """Unconstrained regression trees should reproduce the diabetes targets.

    The tree is built without any depth or feature limit and scored on its
    own training data, so the expected training MSE is 0.
    """
    X, y = diabetes.data, diabetes.target
    regressor = Tree(criterion=criterion, random_state=0)
    regressor.fit(X, y)
    y_pred = regressor.predict(X)
    score = mean_squared_error(y, y_pred)
    assert score == pytest.approx(0), (
        f"Failed with {name}, criterion = {criterion} and score = {score}"
    )
@pytest.mark.parametrize("name, Tree", REG_TREES.items())
@pytest.mark.parametrize(
    "criterion, max_depth",
    [
        ("mse", 15),
        # NOTE(review): "mae" gets a deeper limit than the other criteria,
        # presumably so it can still get under the MSE bound below; confirm.
        ("mae", 20),
        ("friedman_mse", 15),
    ],
)
def test_diabetes_underfit(name, Tree, criterion, max_depth):
    """Depth- and feature-limited trees should only partially fit diabetes.

    The tree is capped by ``max_depth`` (per criterion) and ``max_features=6``
    and scored on its own training data. The training MSE must be strictly
    positive, meaning the data was not memorized, and also below 60.
    """
    X, y = diabetes.data, diabetes.target
    regressor = Tree(
        criterion=criterion,
        max_depth=max_depth,
        max_features=6,
        random_state=0,
    )
    regressor.fit(X, y)
    y_pred = regressor.predict(X)
    score = mean_squared_error(y, y_pred)
    assert 0 < score < 60, (
        f"Failed with {name}, criterion = {criterion} and score = {score}"
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment