This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
+ ------------+----------+--------+-----------------------------------------+--------+-----+-------+-------+-----------+---------+-------+----------+ | |
| PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | | |
+ ------------+----------+--------+-----------------------------------------+--------+-----+-------+-------+-----------+---------+-------+----------| | |
| 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22 | 1 | 0 | A/5 21171 | 7.25 | nan | S | | |
| 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence... | female | 38 | 1 | 0 | PC 17599 | 71.2833 | C85 | C | | |
| 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26 | 0 | 0 | STON/O2. | 7.925 | nan | S | | |
| 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily M... | female | 35 | 1 | 0 | 113803 | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_output_shape(dummy_titanic):
    """Check that ``predict`` returns a 1-D result with one entry per row.

    ``dummy_titanic`` is unpacked as ``(X_train, y_train, X_test, y_test)``.
    A default ``DecisionTree`` is fitted on the training split and asked to
    predict on both splits; each prediction must have shape ``(n_rows,)``.
    """
    X_train, y_train, X_test, _ = dummy_titanic

    model = DecisionTree()
    model.fit(X_train, y_train)

    train_shape = model.predict(X_train).shape
    test_shape = model.predict(X_test).shape

    n_train_rows = X_train.shape[0]
    n_test_rows = X_test.shape[0]
    assert train_shape == (n_train_rows,), 'DecisionTree output should be same as training labels.'
    assert test_shape == (n_test_rows,), 'DecisionTree output should be same as testing labels.'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_gini_impurity():
    """Check ``gini_impurity`` on eight binary labels, from pure to evenly split.

    Each case holds ``zeros`` zero-labels after ``8 - zeros`` one-labels; the
    result is compared after rounding to three decimal places.
    """
    # (number of 0 labels among the eight, expected impurity rounded to 3 d.p.)
    cases = [
        (0, 0),
        (1, 0.219),
        (2, 0.375),
        (3, 0.469),
        (4, 0.500),
    ]
    for zeros, expected in cases:
        labels = [1] * (8 - zeros) + [0] * zeros
        assert round(gini_impurity(labels), 3) == expected
def test_gini_gain():
    """Check ``gini_gain`` for a split that puts each class in its own child.

    The parent holds four 1s and four 0s; the children are the all-1 and the
    all-0 halves. The gain, rounded to three decimals, must equal 0.5.
    """
    parent_labels = [1, 1, 1, 1, 0, 0, 0, 0]
    child_labels = [[1, 1, 1, 1], [0, 0, 0, 0]]

    gain = gini_gain(parent_labels, child_labels)

    assert round(gain, 3) == 0.5
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_output_range(dummy_titanic):
    """Check that every prediction lies in the closed interval [0, 1].

    ``dummy_titanic`` is unpacked as ``(X_train, y_train, X_test, y_test)``.
    A default ``DecisionTree`` is fitted on the training split, then the
    predictions for both splits are bounds-checked.
    """
    X_train, y_train, X_test, _ = dummy_titanic

    tree = DecisionTree()
    tree.fit(X_train, y_train)
    train_scores = tree.predict(X_train)
    test_scores = tree.predict(X_test)

    # Both bounds are evaluated up front (no short-circuit) before combining.
    train_below_max = (train_scores <= 1).all()
    train_above_min = (train_scores >= 0).all()
    assert train_below_max and train_above_min, 'Decision tree output should range from 0 to 1 inclusive'

    test_below_max = (test_scores <= 1).all()
    test_above_min = (test_scores >= 0).all()
    assert test_below_max and test_above_min, 'Decision tree output should range from 0 to 1 inclusive'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_data_leak_in_test_data(dummy_titanic_df):
    """Check that no row is duplicated across the stacked train and test frames.

    ``dummy_titanic_df`` is unpacked as ``(train, test)``. The two frames are
    concatenated and duplicate rows dropped; the row count must still equal
    the sum of the two original row counts. Note that a duplicate inside a
    single frame also makes this fail, not only a row shared by both.
    """
    train, test = dummy_titanic_df

    expected_rows = train.shape[0] + test.shape[0]
    unique_rows = pd.concat([train, test]).drop_duplicates()

    assert unique_rows.shape[0] == expected_rows
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@pytest.fixture | |
def dummy_feats_and_labels(): | |
feats = np.array([[0.7057, -5.4981, 8.3368, -2.8715], | |
[2.4391, 6.4417, -0.80743, -0.69139], | |
[-0.2062, 9.2207, -3.7044, -6.8103], | |
[4.2586, 11.2962, -4.0943, -4.3457], | |
[-2.343, 12.9516, 3.3285, -5.9426], | |
[-2.0545, -10.8679, 9.4926, -1.4116], | |
[2.2279, 4.0951, -4.8037, -2.1112], | |
[-6.1632, 8.7096, -0.21621, -3.6345], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_increase_acc(dummy_titanic): | |
X_train, y_train, _, _ = dummy_titanic | |
acc_list = [] | |
auc_list = [] | |
for depth in range(1, 10): | |
dt = DecisionTree(depth_limit=depth) | |
dt.fit(X_train, y_train) | |
pred = dt.predict(X_train) | |
pred_binary = np.round(pred) | |
acc_list.append(accuracy_score(y_train, pred_binary)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@pytest.fixture | |
def dummy_passengers(): | |
# Based on passenger 2 (high passenger class female) | |
passenger2 = {'PassengerId': 2, | |
'Pclass': 1, | |
'Name': ' Mrs. John', | |
'Sex': 'female', | |
'Age': 38.0, | |
'SibSp': 1, | |
'Parch': 0, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_directional_expectation(dummy_titanic_dt, dummy_passengers): | |
model = dummy_titanic_dt | |
_, p2 = dummy_passengers | |
# Get original survival probability of passenger 2 | |
test_df = pd.DataFrame.from_dict([p2], orient='columns') | |
X, y = get_feats_and_labels(prep_df(test_df)) | |
p2_prob = model.predict(X)[0] # 1.0 | |
# Change gender from female to male |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_dt_evaluation(dummy_titanic_dt, dummy_titanic):
    """Check the supplied model against minimum test-split accuracy and AUC.

    ``dummy_titanic_dt`` is used as an already-built model (only ``predict``
    is called on it); ``dummy_titanic`` is unpacked as
    ``(X_train, y_train, X_test, y_test)`` and only the test split is used.
    Accuracy is computed on the rounded predictions, ROC AUC on the raw ones.
    """
    _, _, X_test, y_test = dummy_titanic

    test_scores = dummy_titanic_dt.predict(X_test)
    test_labels = np.round(test_scores)  # raw outputs rounded to hard labels

    accuracy = accuracy_score(y_test, test_labels)
    auc = roc_auc_score(y_test, test_scores)

    assert accuracy > 0.82, 'Accuracy on test should be > 0.82'
    assert auc > 0.84, 'AUC ROC on test should be > 0.84'
OlderNewer