Skip to content

Instantly share code, notes, and snippets.

@soodoku
Last active April 13, 2025 21:09
Show Gist options
  • Save soodoku/7280fc2f6ddc6a1da7cf01d2d25565e3 to your computer and use it in GitHub Desktop.
Save soodoku/7280fc2f6ddc6a1da7cf01d2d25565e3 to your computer and use it in GitHub Desktop.
Vanilla XGBoost miscalibration
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.calibration import calibration_curve
import xgboost as xgb
# Build a synthetic binary-classification dataset whose labels are skewed:
# the class weights request 90% class 0 and 10% class 1.
dataset_options = dict(
    n_samples=100000,
    n_features=10,
    n_informative=5,
    n_redundant=2,
    weights=[0.9, 0.1],  # [share of class 0, share of class 1]
    random_state=42,  # fixed seed so the generated data is reproducible
)
X, y = make_classification(**dataset_options)
# Hold out 30% of the samples for evaluation; the fixed seed keeps the
# partition reproducible between runs.
split_parts = train_test_split(X, y, test_size=0.3, random_state=42)
X_train, X_test, y_train, y_test = split_parts
# Hyperparameters for the XGBoost classifier, kept together so the model
# configuration can be read at a glance.
booster_settings = dict(
    n_estimators=100,
    learning_rate=0.1,
    max_depth=5,
    objective='binary:logistic',
    random_state=42,
)

# Instantiate the classifier and fit it on the training portion only.
xgb_model = xgb.XGBClassifier(**booster_settings)
xgb_model.fit(X_train, y_train)
# Score the held-out samples and keep column index 1 of the probability
# matrix (presumably the class-1 probability; verify against the model's
# class ordering).
class_probabilities = xgb_model.predict_proba(X_test)
y_pred_prob = class_probabilities[:, 1]

# Reliability-diagram data computed over 10 bins: the first array is later
# plotted as the fraction of positives, the second as the mean predicted
# probability.
prob_true, prob_pred = calibration_curve(y_test, y_pred_prob, n_bins=10)
# Draw the reliability diagram using explicit Figure/Axes handles.
fig, ax = plt.subplots(figsize=(10, 8))

# Model curve first, then the diagonal reference line; the order determines
# both the colour cycle and the legend ordering.
ax.plot(prob_pred, prob_true, marker='o', linewidth=2, label='XGBoost')
ax.plot([0, 1], [0, 1], linestyle='--', label='Perfectly calibrated')

# Axes decoration.
ax.grid(True)
ax.set_title('Calibration Curve on Imbalanced Data (10% positive class)', fontsize=14)
ax.set_xlabel('Mean Predicted Probability', fontsize=12)
ax.set_ylabel('Fraction of Positives (Actual Probability)', fontsize=12)
ax.legend(fontsize=12)

fig.tight_layout()
plt.show()
# Report how the two classes are distributed across the train/test split.
# The lines are assembled first and then written with one print call.
summary_lines = (
    f"Class distribution in training data: {np.bincount(y_train)}",
    f"Class distribution in test data: {np.bincount(y_test)}",
    f"Percentage of class 1 in training: {np.mean(y_train)*100:.1f}%",
)
print(*summary_lines, sep="\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment