Skip to content

Instantly share code, notes, and snippets.

@dtrizna
Last active September 21, 2022 06:28
Show Gist options
  • Save dtrizna/eabf6d9be2862afcf40e92c76ea3d6f0 to your computer and use it in GitHub Desktop.
Save dtrizna/eabf6d9be2862afcf40e92c76ea3d6f0 to your computer and use it in GitHub Desktop.
from sklearn.model_selection import cross_validate
from sklearn.model_selection import StratifiedKFold
def print_scores(cv):
    """Print the average of every test metric over all folds.

    Args:
        cv: Mapping from result name to a sequence of per-fold values,
            such as the dict produced by ``cross_validate``. Only entries
            whose key starts with ``"test_"`` are reported; all other
            entries are ignored and never averaged.

    Prints one tab-indented line per test metric (metric name left-aligned
    in a 10-character column, mean with two decimals), then a blank line.
    """
    prefix = "test_"
    for name, fold_values in cv.items():
        if not name.startswith(prefix):
            continue
        # Slice the prefix off rather than calling str.strip("test_"):
        # strip() treats its argument as a *set of characters* and would
        # also remove leading/trailing t, e, s and _ from the metric name
        # itself (e.g. "test_top_k" would be shown as "op_k").
        metric = name[len(prefix):]
        print(f"\tAverage {metric:<10} over all folds: {np.mean(fold_values):.2f}")
    print()
# Scorers evaluated on every fold; cross_validate reports each one under
# the key "test_<name>".
metrics = ["accuracy", "precision", "recall", "f1", "roc_auc"]

# Cross-validation results, keyed by the name of the vectorizer that
# produced the feature matrix.
cv = {}

for key in ["HashingVectorizer", "TfidfVectorizer"]:
    # Same seed for every vectorizer, so the shuffled 5-fold splitter is
    # configured identically for each feature set.
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    xgb_model = XGBClassifier(
        n_estimators=100,
        eval_metric="logloss",
        use_label_encoder=False,
    )
    cv[key] = cross_validate(
        estimator=xgb_model,
        X=X[key],
        y=y,
        scoring=metrics,
        cv=skf,
    )
    print(f"{key}:")
    print_scores(cv[key])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment