Created September 23, 2017 10:51
-
-
Save sumartoyo/c78fa3d4f7e09e20102278523aaa9f34 to your computer and use it in GitHub Desktop.
Python/NumPy implementation of Kaggle's multi-class log loss metric (https://www.kaggle.com/wiki/MultiClassLogLoss).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def logloss(y_true, prob_pred, eps=1e-15):
    """Return the multi-class log loss of predicted class probabilities.

    Implements https://www.kaggle.com/wiki/MultiClassLogLoss:

        score = -1/N * sum_i sum_j y_ij * log(p_ij)

    Each row of ``prob_pred`` is first rescaled to sum to 1. The result is then
    clipped to ``[eps, 1 - eps]`` so that ``log`` never sees 0 or 1.

    Args:
        y_true: array-like of shape (N, K) holding the one-hot true labels.
            Values are cast to ``uint8``.
        prob_pred: array-like of shape (N, K) holding the predicted scores per
            class. The caller's data is copied and never modified.
        eps: clipping bound for the rescaled probabilities. The default,
            1e-15, matches the Kaggle definition.

    Returns:
        The mean log loss over the N rows, as a NumPy float scalar.

    Raises:
        ValueError: if ``prob_pred`` is not 2-D, if the shapes of the two
            inputs differ, or if there are no rows.
    """
    y_true = np.asarray(y_true, dtype=np.uint8)
    # np.float was removed in NumPy 1.24; it was only an alias for float64.
    # np.array (not asarray) copies, so the in-place ops below are safe.
    prob_pred = np.array(prob_pred, dtype=np.float64)
    # Without this check a 1-D label vector would be broadcast against the
    # (N, K) matrix and silently give a meaningless score when N == K.
    if prob_pred.ndim != 2 or y_true.shape != prob_pred.shape:
        raise ValueError(
            "y_true and prob_pred must both have shape (n_samples, n_classes); "
            f"got {y_true.shape} and {prob_pred.shape}"
        )
    n_data = len(y_true)
    if n_data == 0:
        raise ValueError("logloss requires at least one sample")
    # Rescale each row to sum to 1. The row sum is floored so an all-zero
    # row does not cause a division by zero.
    row_sum = prob_pred.sum(axis=1, keepdims=True)
    np.clip(row_sum, 1e-15, None, out=row_sum)
    prob_pred /= row_sum
    # Clip extremes so log() stays finite.
    np.clip(prob_pred, eps, 1 - eps, out=prob_pred)
    return -(np.sum(np.log(prob_pred) * y_true) / n_data)
def test_logloss():
    """Check ``logloss`` against a hand-computed reference value.

    The true class receives probabilities 0.5, 0.1, 0.01, 0.1, 0.25 and 0.999
    across the six rows. The expected score is the negated mean of their
    natural logs.
    """
    y_true = [
        [1, 0],
        [1, 0],
        [1, 0],
        [0, 1],
        [0, 1],
        [0, 1],
    ]
    prob_pred = [
        [0.5, 0.5],
        [0.1, 0.9],
        [0.01, 0.99],
        [0.9, 0.1],
        [0.75, 0.25],
        [0.001, 0.999],
    ]
    # Use a tolerance: exact equality on a sum of logs is brittle across
    # NumPy versions and platforms.
    assert np.isclose(logloss(y_true, prob_pred), 1.881797068998267, rtol=0.0, atol=1e-12)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment