Last active
July 23, 2020 16:39
-
-
Save richardtomsett/8b814f30e1d665fae2b4085d3e4156f5 to your computer and use it in GitHub Desktop.
Reproducing an issue with predict_proba() returning NaN values for linear classifiers with log loss in scikit-learn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Reproduce predict_proba() returning NaN for SGDClassifier with log loss.

After a single partial_fit() on a small minibatch, predict_proba() can return
NaN probabilities on held-out data, which then breaks a downstream
sklearn.metrics.log_loss() call.
"""
import sklearn
import sklearn.metrics
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
import numpy as np

# --- Create some training and testing data --------------------------------
data_size = 10000
train_size = 9000
test_size = data_size - train_size  # size of the held-out split

X, y = make_classification(
    n_samples=data_size,
    n_features=500,
    n_informative=500,
    n_redundant=0,
    n_repeated=0,
    n_classes=10,
    n_clusters_per_class=1,
    weights=None,
    flip_y=0,
    class_sep=1,
    hypercube=True,
    shift=0.0,
    scale=1.0,
    shuffle=True,
    random_state=42,  # fixed seed so the NaN repro is deterministic
)

# Hold out the last `test_size` samples for testing.
X_test = X[train_size:, :]
y_test = y[train_size:]
X = X[:train_size]
y = y[:train_size]

# Standardize both splits using the *training* statistics only
# (avoids test-set leakage; same behavior as the original script).
X_mean = np.mean(X, axis=0)
X_std = np.std(X, axis=0)
X = (X - X_mean) / X_std
X_test = (X_test - X_mean) / X_std

# Create an SGDClassifier using log loss, default parameters.
# NOTE(review): loss="log" was renamed to loss="log_loss" in scikit-learn 1.1
# and removed in 1.3; "log" is kept here because this gist reproduces an
# issue on the 2020-era version it targets — confirm before running on a
# modern scikit-learn.
model = SGDClassifier(loss="log")

# Do a partial fit on a minibatch of 50 samples.
# NB: if you try with the data generated above, but train on e.g. the first
# 42 samples, then no NaNs are returned...
model.partial_fit(X[0:50, :], y[0:50], classes=np.arange(10))

# Predict the class probabilities for all of the test data and print the
# indices of any NaN entries (empty arrays mean no NaNs).
y_prob = model.predict_proba(X_test)
print(np.where(np.isnan(y_prob)))

# The NaNs create a downstream error when calculating the log loss:
loss = sklearn.metrics.log_loss(y_test, y_prob)
print(loss)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment