@bbengfort · Created March 8, 2018
Class balance of y_true vs y_pred
import numpy as np
import matplotlib.pyplot as plt
import yellowbrick as yb # For the styles
from sklearn.base import clone
from sklearn.model_selection import KFold
from sklearn.metrics.classification import _check_targets
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_consistent_length
from sklearn.externals.joblib import Parallel, delayed

def plot_target(y_true, y_pred, labels=None, ax=None, width=0.35, **kwargs):
    # Validate the input
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
    if y_type not in ("binary", "multiclass"):
        raise ValueError("%s is not supported" % y_type)

    # This is probably not necessary
    check_consistent_length(y_true, y_pred)

    # Manage the labels passed in (yb might use classes for this arg)
    if labels is None:
        labels = unique_labels(y_true, y_pred)
    else:
        labels = np.asarray(labels)
        if np.all([l not in y_true for l in labels]):
            raise ValueError("At least one label specified must be in y_true")

    # Count the values of y_true and y_pred for each class
    indices = np.arange(0, labels.shape[0])

    # This expects labels to be numerically encoded, not strings
    # YB needs to handle either case better, though _check_targets
    # may deal with this, I'm not sure - need to review the code.
    # Needless to say this is a HACK that needs to be addressed.
    t_counts = np.array([(y_true == label).sum() for label in indices])
    p_counts = np.array([(y_pred == label).sum() for label in indices])

    # Begin the figure
    if ax is None:
        _, ax = plt.subplots()

    b1 = ax.bar(indices, t_counts, width, color='b', label="actual")
    b2 = ax.bar(indices + width, p_counts, width, color='g', label="predicted")

    ax.set_xticks(indices + width / 2)
    ax.set_xticklabels(labels)
    ax.set_xlabel("class")
    ax.set_ylabel("number of instances")
    ax.legend(loc='best', frameon=True)
    ax.grid(False, axis='x')

    return ax

def _cross_validate(model, X, y, cv=6, n_jobs=6):
    """
    Returns y_true and y_pred for all instances using cross-validation
    """
    # HACK: should use sklearn method to get split indices
    folds = KFold(cv)

    splits = Parallel(n_jobs=n_jobs)(
        delayed(_split_validate)(model, X, y, train, test)
        for train, test in folds.split(X, y)
    )

    y_true = []
    y_pred = []

    for y_true_split, y_pred_split in splits:
        y_true.append(y_true_split)
        y_pred.append(y_pred_split)

    return np.concatenate(y_true), np.concatenate(y_pred)

def _split_validate(model, X, y, train, test):
    X_train, X_test = X[train], X[test]
    y_train, y_true = y[train], y[test]

    est = clone(model)
    est.fit(X_train, y_train)

    return y_true, est.predict(X_test)

if __name__ == "__main__":
    from sklearn.linear_model import LogisticRegression
    from sklearn.datasets import make_classification

    X, y = make_classification(n_samples=200, n_features=100,
                               n_informative=20, n_redundant=10,
                               n_classes=6, random_state=42)

    y_true, y_pred = _cross_validate(LogisticRegression(), X, y)
    plot_target(y_true, y_pred)
    plt.show()
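
As a side note, the custom _cross_validate and _split_validate helpers can be replaced with scikit-learn's built-in cross_val_predict, which returns one out-of-fold prediction per sample in the original order of X (so y itself serves as y_true). A minimal sketch, assuming plot_target from above is defined or importable in the same session:

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

X, y = make_classification(n_samples=200, n_features=100, n_informative=20,
                           n_redundant=10, n_classes=6, random_state=42)

# One out-of-fold prediction per sample, aligned with y
y_pred = cross_val_predict(LogisticRegression(), X, y, cv=6, n_jobs=6)

plot_target(y, y_pred)
plt.show()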