This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_increase_acc(dummy_titanic): | |
X_train, y_train, X_test, y_test = dummy_titanic | |
acc_list = [] | |
auc_list = [] | |
for num_trees in [1, 3, 7, 15]: | |
rf = RandomForest(num_trees=num_trees, depth_limit=7, col_subsampling=0.7, row_subsampling=0.7) | |
rf.fit(X_train, y_train) | |
pred = rf.predict(X_test) | |
pred_binary = np.round(pred) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_training_time(dummy_titanic):
    """Check that the 95th-percentile DecisionTree training time is under 1 second.

    Args:
        dummy_titanic: fixture unpacked as (X_train, y_train, X_test, y_test);
            only the training split is used here.
    """
    X_train, y_train, _, _ = dummy_titanic

    # Standardize to use depth = 10
    dt = DecisionTree(depth_limit=10)

    # Time 100 training runs on the same tree instance. Element [1] of the
    # train_with_time result is used as the latency (presumably elapsed
    # seconds -- the helper is defined elsewhere; confirm its return layout).
    latency_array = np.array(
        [train_with_time(dt, X_train, y_train)[1] for _ in range(100)]
    )
    time_p95 = np.quantile(latency_array, 0.95)

    assert time_p95 < 1.0, 'Training time at 95th percentile should be < 1.0 sec'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_evaluation(dummy_titanic_dt, dummy_titanic):
    """Check that the fitted model clears minimum accuracy and ROC AUC on the test split.

    Args:
        dummy_titanic_dt: fixture providing the model under evaluation; only
            its ``predict`` method is used.
        dummy_titanic: fixture unpacked as (X_train, y_train, X_test, y_test);
            only the test split is used here.
    """
    model = dummy_titanic_dt
    _, _, X_test, y_test = dummy_titanic

    pred_test = model.predict(X_test)
    # Accuracy is computed on rounded predictions; AUC uses the raw
    # predictions as scores.
    pred_test_binary = np.round(pred_test)

    acc_test = accuracy_score(y_test, pred_test_binary)
    auc_test = roc_auc_score(y_test, pred_test)

    assert acc_test > 0.82, 'Accuracy on test should be > 0.82'
    assert auc_test > 0.84, 'AUC ROC on test should be > 0.84'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_directional_expectation(dummy_titanic_dt, dummy_passengers): | |
model = dummy_titanic_dt | |
_, p2 = dummy_passengers | |
# Get original survival probability of passenger 2 | |
test_df = pd.DataFrame.from_dict([p2], orient='columns') | |
X, y = get_feats_and_labels(prep_df(test_df)) | |
p2_prob = model.predict(X)[0] # 1.0 | |
# Change gender from female to male |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@pytest.fixture | |
def dummy_passengers(): | |
# Based on passenger 2 (high passenger class female) | |
passenger2 = {'PassengerId': 2, | |
'Pclass': 1, | |
'Name': ' Mrs. John', | |
'Sex': 'female', | |
'Age': 38.0, | |
'SibSp': 1, | |
'Parch': 0, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_increase_acc(dummy_titanic): | |
X_train, y_train, _, _ = dummy_titanic | |
acc_list = [] | |
auc_list = [] | |
for depth in range(1, 10): | |
dt = DecisionTree(depth_limit=depth) | |
dt.fit(X_train, y_train) | |
pred = dt.predict(X_train) | |
pred_binary = np.round(pred) | |
acc_list.append(accuracy_score(y_train, pred_binary)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@pytest.fixture | |
def dummy_feats_and_labels(): | |
feats = np.array([[0.7057, -5.4981, 8.3368, -2.8715], | |
[2.4391, 6.4417, -0.80743, -0.69139], | |
[-0.2062, 9.2207, -3.7044, -6.8103], | |
[4.2586, 11.2962, -4.0943, -4.3457], | |
[-2.343, 12.9516, 3.3285, -5.9426], | |
[-2.0545, -10.8679, 9.4926, -1.4116], | |
[2.2279, 4.0951, -4.8037, -2.1112], | |
[-6.1632, 8.7096, -0.21621, -3.6345], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_data_leak_in_test_data(dummy_titanic_df):
    """Check that stacking the train and test frames yields no duplicate rows.

    A row that appears in both frames (or twice within one frame) is removed
    by ``drop_duplicates``, so the de-duplicated row count falls below the sum
    of the two original row counts and the assertion fails.

    Args:
        dummy_titanic_df: fixture unpacked as (train, test) DataFrames.
    """
    train, test = dummy_titanic_df

    # drop_duplicates compares column values only; the index is ignored.
    unique_rows = pd.concat([train, test]).drop_duplicates()

    assert unique_rows.shape[0] == train.shape[0] + test.shape[0], \
        'Combined train and test data should not contain duplicate rows'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_output_range(dummy_titanic):
    """Check that DecisionTree predictions on both splits lie within [0, 1].

    Args:
        dummy_titanic: fixture unpacked as (X_train, y_train, X_test, y_test).
    """
    X_train, y_train, X_test, _ = dummy_titanic

    dt = DecisionTree()
    dt.fit(X_train, y_train)

    pred_train = dt.predict(X_train)
    pred_test = dt.predict(X_test)

    # Each .all() already reduces to a single boolean, so a logical `and`
    # is the right combinator (the element-wise `&` is not needed here).
    assert (pred_train <= 1).all() and (pred_train >= 0).all(), 'Decision tree output should range from 0 to 1 inclusive'
    assert (pred_test <= 1).all() and (pred_test >= 0).all(), 'Decision tree output should range from 0 to 1 inclusive'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_output_shape(dummy_titanic):
    """Check that DecisionTree returns one prediction per input row, as a 1-D result.

    Args:
        dummy_titanic: fixture unpacked as (X_train, y_train, X_test, y_test).
    """
    X_train, y_train, X_test, _ = dummy_titanic

    dt = DecisionTree()
    dt.fit(X_train, y_train)

    pred_train = dt.predict(X_train)
    pred_test = dt.predict(X_test)

    # Expected shape is (n_rows,): a flat vector, not a column matrix.
    assert pred_train.shape == (X_train.shape[0],), 'DecisionTree output should be same as training labels.'
    assert pred_test.shape == (X_test.shape[0],), 'DecisionTree output should be same as testing labels.'