@eugeneyan
Created February 21, 2021 19:16
Test that RandomForest performs better than DecisionTree at the same depth limit
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

# DecisionTree and RandomForest are the project's own classes under test (their
# import path is not shown here). dummy_titanic is a pytest fixture that returns
# (X_train, y_train, X_test, y_test).


def test_rf_better_than_dt(dummy_titanic):
    X_train, y_train, X_test, y_test = dummy_titanic

    # A single tree and a 7-tree forest share the same depth limit
    dt = DecisionTree(depth_limit=10)
    dt.fit(X_train, y_train)
    rf = RandomForest(depth_limit=10, num_trees=7, col_subsampling=0.8, row_subsampling=0.8)
    rf.fit(X_train, y_train)

    # predict() returns scores: round them to 0/1 for accuracy, keep them raw for AUC
    pred_test_dt = dt.predict(X_test)
    acc_test_dt = accuracy_score(y_test, np.round(pred_test_dt))
    auc_test_dt = roc_auc_score(y_test, pred_test_dt)

    pred_test_rf = rf.predict(X_test)
    acc_test_rf = accuracy_score(y_test, np.round(pred_test_rf))
    auc_test_rf = roc_auc_score(y_test, pred_test_rf)

    assert acc_test_rf > acc_test_dt, 'RandomForest should have higher accuracy than DecisionTree on the test set.'
    assert auc_test_rf > auc_test_dt, 'RandomForest should have higher ROC AUC than DecisionTree on the test set.'
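The test scores each model twice: accuracy on predictions rounded to 0/1, and ROC AUC on the raw scores. As a rough illustration of how the two metrics differ, here is a minimal stdlib-only sketch. The helper names `accuracy` and `roc_auc` are hypothetical stand-ins for sklearn's `accuracy_score` and `roc_auc_score`, not the library's implementation; AUC is computed here as the fraction of (positive, negative) pairs that the model ranks correctly, with ties counting as half.

```python
def accuracy(y_true, y_prob):
    """Fraction of correct labels after rounding scores to 0/1.

    Hypothetical helper mirroring np.round + accuracy_score. Python's round()
    rounds half to even, like np.round, so a score of exactly 0.5 becomes 0.
    """
    preds = [round(p) for p in y_prob]
    return sum(int(p == t) for p, t in zip(preds, y_true)) / len(y_true)


def roc_auc(y_true, y_score):
    """ROC AUC as the probability that a random positive outscores a random negative.

    Hypothetical helper: O(n_pos * n_neg) pairwise comparison, ties count as 0.5.
    """
    pos = [s for s, t in zip(y_score, y_true) if t == 1]
    neg = [s for s, t in zip(y_score, y_true) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


y = [0, 0, 1, 1]
model_a = [0.1, 0.4, 0.35, 0.8]  # one positive ranked below a negative
model_b = [0.1, 0.2, 0.3, 0.9]   # every positive ranked above every negative

print(accuracy(y, model_a), roc_auc(y, model_a))  # 0.75 0.75
print(accuracy(y, model_b), roc_auc(y, model_b))  # 0.75 1.0
```

Because rounding discards ranking information, two models can tie on accuracy while one has a strictly better AUC. With strict `>` comparisons, the accuracy assertion can therefore fail even when the forest ranks the test set better, which is one reason to check both metrics.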