Skip to content

Instantly share code, notes, and snippets.

@deaktator
Last active February 26, 2022 22:18
Show Gist options
  • Save deaktator/47bf3835fa0652f429fe7bf4f3cec606 to your computer and use it in GitHub Desktop.
Save deaktator/47bf3835fa0652f429fe7bf4f3cec606 to your computer and use it in GitHub Desktop.
Classification Metrics for Score Histograms
from typing import List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import scipy
def binned_cm_stats(histogram: Union[np.ndarray, List[int]],
                    bins: Union[np.ndarray, List[float]]) -> pd.DataFrame:
    """Produce confusion matrix statistics for a histogram of probability estimates.

    Each bucket's count is split into expected positives and negatives by
    treating the bucket center as a probability: a bucket with center ``c``
    and count ``k`` contributes ``c * k`` expected positives and
    ``(1 - c) * k`` expected negatives. For the threshold ``bins[i]``, buckets
    ``i .. N-1`` are predicted positive and buckets ``0 .. i-1`` are predicted
    negative, so every row satisfies ``tp + fp + fn + tn == sum(histogram)``.

    Parameters:
        histogram (Union[numpy.ndarray, List[int]]): Array of counts. Length ``N``.
        bins (Union[numpy.ndarray, List[float]]): Array of histogram bucket endpoints.
            Length ``N + 1``.

    Returns:
        pandas.DataFrame: Contains 'thresh' column equal to ``bins[:-1]`` and associated
                          confusion matrix based stats assuming the 'thresh' value is
                          the decision boundary. The returned dataframe has ``N`` rows.

    Raises:
        ValueError: If ``histogram`` / ``bins`` are not 1-D or
            ``len(bins) != len(histogram) + 1``.
    """
    histogram = np.asarray(histogram, dtype=float)
    bins = np.asarray(bins, dtype=float)
    if histogram.ndim != 1 or bins.ndim != 1 or len(bins) != len(histogram) + 1:
        # Without this check a length-1 histogram would silently broadcast
        # against every bucket.
        raise ValueError(
            f"bins must be 1-D with len(histogram) + 1 entries; got "
            f"{len(bins)} bins for {len(histogram)} histogram counts")

    thresh = bins[:-1]
    # Expected positives per bucket: bucket center (x) times bucket count (y).
    centers = thresh + np.diff(bins) / 2
    expected_pos = centers * histogram

    # Predicted positive at thresh[i]: buckets i..N-1 (inclusive suffix sums).
    tp = np.cumsum(expected_pos[::-1])[::-1]
    p = np.cumsum(histogram[::-1])[::-1]

    # Predicted negative at thresh[i]: buckets 0..i-1 (exclusive prefix sums).
    # An inclusive prefix sum would count bucket i on both sides of the boundary.
    fn = np.zeros_like(expected_pos)
    fn[1:] = np.cumsum(expected_pos[:-1])
    n = np.zeros_like(histogram)
    n[1:] = np.cumsum(histogram[:-1])

    tn = n - fn
    fp = p - tp

    # Confusion matrix cells for each threshold / histogram bin.
    df = pd.DataFrame({'thresh': thresh, 'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn})
    df['tpr'] = tpr(df)
    df['fpr'] = fpr(df)
    df['tnr'] = tnr(df)
    df['fnr'] = fnr(df)
    df['recall'] = recall(df)
    df['precision'] = precision(df)
    df['accuracy'] = accuracy(df)
    df['f1'] = f_beta(df)
    return df
def cm_stats_by_threshold_binned(yp: Union[np.ndarray, List[float]],
                                 digits_precision: int = 3) -> pd.DataFrame:
    """Bucket raw model scores on ``[0, 1]`` and compute confusion matrix statistics.

    The unit interval is split into ``10 ** digits_precision`` equal-width
    buckets. The scores are histogrammed into them with ``numpy.histogram``,
    and the histogram is passed to ``binned_cm_stats``.

    Parameters:
        yp (Union[numpy.ndarray, List[float]]): Array of model scores. Scores
            outside ``[0, 1]`` fall outside every bucket and are ignored by
            ``numpy.histogram``.
        digits_precision (int): Number of decimal digits in the bucket endpoints
            (default 3). This gives ``10 ** digits_precision`` buckets.

    Returns:
        pandas.DataFrame: Output of ``binned_cm_stats``. It has one row per
                          bucket (``10 ** digits_precision`` rows). The
                          'thresh' column holds the lower bucket endpoints
                          ``0, 10**-d, ..., 1 - 10**-d``.

    Raises:
        ValueError: If ``digits_precision`` is negative.
    """
    if digits_precision < 0:
        raise ValueError(
            f"digits_precision must be non-negative, got {digits_precision}")
    n_buckets = 10 ** digits_precision
    # Rounding strips linspace float noise (e.g. 0.30000000000000004) so the
    # endpoints -- and hence the 'thresh' column -- print cleanly.
    bins = np.round(np.linspace(0, 1, n_buckets + 1), digits_precision)
    hist, _ = np.histogram(yp, bins=bins)
    return binned_cm_stats(hist, bins)
class CmStatsBinned:
    """Area-under-curve summaries of a frame produced by ``binned_cm_stats``."""

    @staticmethod
    def _integrate(x: Union[np.ndarray, List[float]],
                   y: Union[np.ndarray, List[float]],
                   extrapolate_to: Optional[Tuple[float, float]] = None) -> float:
        """Numerically integrate the points ``{ (x_i,y_i) }`` with the trapezoidal rule.

        Points are sorted by ``x`` before integrating. Points that share the
        same ``x`` are ordered by ascending ``y``, so the result does not
        depend on input order.

        If ``extrapolate_to`` is set, horizontal lines will be added from:
          * ``extrapolate_to[0]`` to ``min(x)`` with the y value associated with
            the minimum x value.
          * ``max(x)`` to ``extrapolate_to[-1]`` with the y value associated with
            the maximum x value.

        Parameters:
            x (Union[numpy.ndarray, List[float]]): 1-D x values. Doesn't need to be sorted.
            y (Union[numpy.ndarray, List[float]]): y values (same length as ``x``).
            extrapolate_to (Optional[Tuple[float, float]], default: ``None``):
                ``(lo, hi)`` bounds to extend the curve to horizontally.

        Returns:
            float: Trapezoidal-rule area under the piecewise-linear curve
                   (``0.0`` for empty input).

        Raises:
            ValueError: If ``x`` and ``y`` have different shapes.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        if x.shape != y.shape:
            raise ValueError(
                f"x and y must have the same shape, got {x.shape} and {y.shape}")
        # Sort by x, breaking ties by y. numpy's default argsort is not stable,
        # so tied x values could otherwise be visited in arbitrary order.
        order = np.lexsort((y, x))
        x = x[order]
        y = y[order]
        if extrapolate_to is not None and x.size > 0:
            lo, hi = extrapolate_to
            if lo < x[0]:
                x = np.concatenate(([lo], x))
                y = np.concatenate(([y[0]], y))
            if x[-1] < hi:
                x = np.concatenate((x, [hi]))
                y = np.concatenate((y, [y[-1]]))
        # Trapezoidal rule written out directly. scipy.integrate.trapz was
        # removed in SciPy 1.14, and a bare `import scipy` need not load the
        # `integrate` subpackage.
        return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2.0))

    @staticmethod
    def auc(df: pd.DataFrame) -> float:
        """Area under the ROC curve (``tpr`` vs. ``fpr``), extended to ``fpr`` in ``[0, 1]``.

        ``df`` should be output from ``binned_cm_stats``.
        """
        return CmStatsBinned._integrate(df.fpr, df.tpr, extrapolate_to=(0, 1))

    @staticmethod
    def ap(df: pd.DataFrame) -> float:
        """Area under the precision-recall curve, extended to ``recall`` in ``[0, 1]``.

        ``df`` should be output from ``binned_cm_stats``.
        """
        return CmStatsBinned._integrate(df.recall, df.precision, extrapolate_to=(0, 1))
# ============================================================================
# confusion matrix statistics
# ============================================================================
def tpr(cm):
    """True-positive rate (sensitivity): ``tp / (tp + fn)``.

    ``cm`` is anything indexable by ``'tp'`` and ``'fn'`` (a dict, a row, or a
    DataFrame for column-wise results).
    """
    true_pos = cm['tp']
    return true_pos / (true_pos + cm['fn'])
def fpr(cm):
    """False-positive rate (fall-out): ``fp / (fp + tn)``.

    ``cm`` is anything indexable by ``'fp'`` and ``'tn'``.
    """
    false_pos = cm['fp']
    return false_pos / (false_pos + cm['tn'])
def tnr(cm):
    """True-negative rate (specificity): ``tn / (tn + fp)``.

    ``cm`` is anything indexable by ``'tn'`` and ``'fp'``.
    """
    true_neg = cm['tn']
    return true_neg / (true_neg + cm['fp'])
def fnr(cm):
    """False-negative rate (miss rate): ``fn / (fn + tp)``.

    ``cm`` is anything indexable by ``'fn'`` and ``'tp'``.
    """
    false_neg = cm['fn']
    return false_neg / (false_neg + cm['tp'])
def recall(cm):
    """Recall: ``tp / (tp + fn)`` (same quantity as ``tpr``).

    ``cm`` is anything indexable by ``'tp'`` and ``'fn'``.
    """
    hits = cm['tp']
    return hits / (hits + cm['fn'])
def precision(cm):
    """Precision: ``tp / (tp + fp + 1e-9)``.

    The tiny epsilon keeps the result finite (0.0 rather than NaN) when no
    items are predicted positive.
    ``cm`` is anything indexable by ``'tp'`` and ``'fp'``.
    """
    hits = cm['tp']
    predicted_pos = hits + cm['fp']
    return hits / (predicted_pos + 1e-9)
def accuracy(cm):
    """Accuracy: fraction of correct decisions, ``(tp + tn) / (tp + fp + tn + fn)``.

    ``cm`` is anything indexable by ``'tp'``, ``'fp'``, ``'tn'`` and ``'fn'``.
    """
    tp, fp, tn, fn = cm['tp'], cm['fp'], cm['tn'], cm['fn']
    correct = tp + tn
    total = tp + fp + tn + fn
    return correct / total
def f_beta(cm, beta=1):
    """F-beta score: ``(1 + b^2) * tp / ((1 + b^2) * tp + b^2 * fn + fp)``.

    ``beta`` scales the weight of false negatives relative to false
    positives. The default ``beta=1`` gives F1.
    ``cm`` is anything indexable by ``'tp'``, ``'fn'`` and ``'fp'``.
    """
    beta_sq = beta ** 2
    weighted_tp = (1 + beta_sq) * cm['tp']
    return weighted_tp / (weighted_tp + beta_sq * cm['fn'] + cm['fp'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment