Vanilla XGBoost miscalibration
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.calibration import calibration_curve
import xgboost as xgb

# Generate imbalanced data (10% positive class)
X, y = make_classification(
    n_samples=100000,
    n_features=10,
    n_informative=5,
    n_redundant=2,
    weights=[0.9, 0.1],  # Create class imbalance: 90% class 0, 10% class 1
    random_state=42
)

# Split data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Define an XGBoost classifier
xgb_model = xgb.XGBClassifier(
    n_estimators=100,
    learning_rate=0.1,
    max_depth=5,
    objective='binary:logistic',
    random_state=42
)

# Train the model
xgb_model.fit(X_train, y_train)

# Get predicted probabilities
y_pred_prob = xgb_model.predict_proba(X_test)[:, 1]

# Calculate calibration curve (reliability diagram)
prob_true, prob_pred = calibration_curve(y_test, y_pred_prob, n_bins=10)

# Plot the calibration curve
plt.figure(figsize=(10, 8))
plt.plot(prob_pred, prob_true, marker='o', linewidth=2, label='XGBoost')
plt.plot([0, 1], [0, 1], linestyle='--', label='Perfectly calibrated')
plt.grid(True)
plt.title('Calibration Curve on Imbalanced Data (10% positive class)', fontsize=14)
plt.xlabel('Mean Predicted Probability', fontsize=12)
plt.ylabel('Fraction of Positives (Actual Probability)', fontsize=12)
plt.legend(fontsize=12)
plt.tight_layout()
plt.show()

# Print class distribution information
print(f"Class distribution in training data: {np.bincount(y_train)}")
print(f"Class distribution in test data: {np.bincount(y_test)}")
print(f"Percentage of class 1 in training: {np.mean(y_train)*100:.1f}%")