-
-
Save dsleo/a607816521e9e39b974df77528b0e7f1 to your computer and use it in GitHub Desktop.
Histogram Calibration Error
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

from sklearn.utils import assert_all_finite, check_consistent_length
from sklearn.utils.validation import column_or_1d
def calibration_loss(y_true, y_prob, sample_weight=None, norm="l2",
                     n_bins=10, pos_label=None, reduce_bias=True):
    """Compute calibration loss.

    Across all items in a set of N predictions, the calibration loss measures
    the aggregated difference between (1) the average predicted probabilities
    assigned to the positive class, and (2) the frequencies
    of the positive class in the actual outcome.

    The calibration loss is appropriate for binary and categorical outcomes
    that can be structured as true or false.

    Which label is considered to be the positive label is controlled via the
    parameter ``pos_label``, which defaults to the maximum label when None.

    Read more in the :ref:`User Guide <calibration>`.

    Parameters
    ----------
    y_true : array, shape (n_samples,)
        True targets.

    y_prob : array, shape (n_samples,)
        Probabilities of the positive class.

    sample_weight : array-like, shape (n_samples,), optional
        Sample weights.

    norm : 'l1' | 'l2' | 'max'
        Norm method. NOTE: when ``reduce_bias=True`` (the default) this
        parameter is ignored and the 'l2' norm is used, because the
        debiasing term below is only defined for the l2 norm.

    n_bins : int, optional (default=10)
        The number of equal-width probability bins to compute error on.

    pos_label : int or str, optional (default=None)
        Label of the positive class. If None, the maximum label is used as
        positive class.

    reduce_bias : bool, optional (default=True)
        Add debiasing term as in Verified Uncertainty Calibration, A. Kumar
        (NeurIPS 2019). Forces ``norm='l2'``.

    Returns
    -------
    score : float
        Calibration loss.

    Raises
    ------
    ValueError
        If ``y_prob`` has values outside [0, 1], or if ``norm`` is invalid
        (only checked when ``reduce_bias=False``, matching the original
        behavior where the override masks a bad ``norm``).
    """
    y_true = column_or_1d(y_true)
    y_prob = column_or_1d(y_prob)
    assert_all_finite(y_true)
    assert_all_finite(y_prob)
    check_consistent_length(y_true, y_prob, sample_weight)

    if np.any(y_prob < 0) or np.any(y_prob > 1):
        raise ValueError("y_prob has values outside of [0, 1] range")

    if pos_label is None:
        pos_label = y_true.max()
    # Binarize: 1 for the positive class, 0 otherwise.
    y_true = np.array(y_true == pos_label, int)

    # The debiasing term is only defined for the l2 norm, so it takes
    # precedence over the requested norm (documented above).
    if reduce_bias:
        norm = "l2"
    # Fail fast instead of raising mid-loop as the original did; reachable
    # only when reduce_bias=False, same as before.
    if norm not in ("l1", "l2", "max"):
        raise ValueError("norm is neither 'l1', 'l2' nor 'max'")

    # Sort samples by predicted probability so each bin is a contiguous slice.
    order = np.argsort(y_prob)
    y_true = y_true[order]
    y_prob = y_prob[order]
    if sample_weight is not None:
        # BUG FIX: the original fancy-indexed the raw input, which raises
        # TypeError for plain Python lists; convert to a 1d array first.
        sample_weight = column_or_1d(sample_weight)[order]

    # Bin boundaries as indices into the sorted arrays: n_bins equal-width
    # probability bins over [0, 1), plus a final sentinel at n_samples.
    i_thres = np.searchsorted(y_prob, np.arange(0, 1, 1. / n_bins)).tolist()
    i_thres.append(y_true.shape[0])

    loss = 0.
    count = 0.
    debias = 0.
    for i_start, i_end in zip(i_thres[:-1], i_thres[1:]):
        if sample_weight is None:
            delta_count = float(i_end - i_start)
        else:
            delta_count = float(sample_weight[i_start:i_end].sum())
        if delta_count == 0:
            # Empty bin: the original computed 0/0 -> NaN (with a runtime
            # warning) and then discarded the NaN; skip it outright.
            continue
        if sample_weight is None:
            avg_pred_true = y_true[i_start:i_end].sum() / delta_count
            bin_centroid = y_prob[i_start:i_end].sum() / delta_count
        else:
            avg_pred_true = (np.dot(y_true[i_start:i_end],
                                    sample_weight[i_start:i_end])
                             / delta_count)
            bin_centroid = (np.dot(y_prob[i_start:i_end],
                                   sample_weight[i_start:i_end])
                            / delta_count)
        count += delta_count
        if reduce_bias:
            # Debiasing term from "Verified Uncertainty Calibration"
            # (Kumar et al.); the denominator can make this NaN for
            # degenerate bins, in which case it is skipped.
            delta_debias = avg_pred_true * (avg_pred_true - 1) * delta_count
            delta_debias /= y_true.shape[0] * delta_count - 1
            if not np.isnan(delta_debias):
                debias += delta_debias
        if norm == "max":
            loss = max(loss, abs(avg_pred_true - bin_centroid))
        elif norm == "l1":
            loss += abs(avg_pred_true - bin_centroid) * delta_count
        else:  # "l2"
            loss += (avg_pred_true - bin_centroid) ** 2 * delta_count

    if norm in ("l1", "l2"):
        loss /= count
    if norm == "l2":
        if reduce_bias:
            loss += debias
        # Debias can push the mean squared error below zero; clamp before
        # taking the square root.
        loss = np.sqrt(max(loss, 0.))
    return loss
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment