@SwadicalRag · Last active March 1, 2024
AUROC / AUPRC / binary classifier statistics in TypeScript

import * as fs from "fs";

/**
 * Class representing evaluation metrics for a binary classifier system.
 * This class is designed to calculate and analyze the performance of a binary classifier
 * as its discrimination threshold is varied, including ROC curve analysis and Precision-Recall curve analysis.
 */
export class BinaryClassifierStatistics {
    /** Sorted array of unique thresholds from the scores, in descending order */
    thresholds: number[] = [];

    /**
     * Creates an instance of the BinaryClassifierStatistics class.
     * @param id - An identifier for this classifier instance (e.g. a model name).
     * @param trueLabels - The array of true binary labels of the instances (1 for positive and 0 for negative).
     * @param scores - The array of scores or probabilities as estimated by the model, corresponding to the true labels.
     */
    constructor(public id: string, public trueLabels: (1 | 0)[] = [], public scores: number[] = []) {
        this.validate();
        this.recalculateThresholds();
    }

    /**
     * Loads a `BinaryClassifierStatistics` instance from a specified file.
     * This static method creates a new instance of the class, reads data from the given file path,
     * deserializes the JSON content into the class properties, and returns the populated instance.
     *
     * @param path - The file path from which to load the serialized class instance data.
     * @return A new instance of `BinaryClassifierStatistics` populated with the data from the file.
     */
    static load(path: string) {
        const res = new BinaryClassifierStatistics("");
        res.load(path);
        return res;
    }

    /**
     * Loads data into the current instance from a specified file.
     * This method reads data from the given file path, deserializes the JSON content,
     * and updates the current instance's properties with the deserialized data.
     *
     * @param path - The file path from which to load the serialized data.
     */
    load(path: string) {
        this.deserialize(fs.readFileSync(path).toString());
    }

    /**
     * Saves the current instance's data to a specified file.
     * This method serializes the instance's properties (id, trueLabels, scores, thresholds)
     * into JSON format and writes this data to the given file path, overwriting any existing content.
     *
     * @param path - The file path to which the serialized data will be saved.
     */
    save(path: string) {
        fs.writeFileSync(path, this.serialize());
    }

    /**
     * Serializes the current instance's data into a JSON string.
     * This method converts the instance's properties (id, trueLabels, scores, thresholds)
     * into a JSON string format for easy storage or transmission.
     *
     * @return A JSON string representation of the instance's data.
     */
    serialize() {
        return JSON.stringify({
            id: this.id,
            trueLabels: this.trueLabels,
            scores: this.scores,
            thresholds: this.thresholds,
        });
    }

    /**
     * Deserializes data from a JSON string into the instance's properties.
     * This method parses a JSON string to update the instance's properties (id, trueLabels, scores, thresholds)
     * with the data from the JSON string, effectively loading the state from a serialized format.
     *
     * @param data - The JSON string from which to deserialize the data.
     */
    deserialize(data: string) {
        const deserialized = JSON.parse(data);
        this.id = deserialized.id;
        this.trueLabels = deserialized.trueLabels;
        this.scores = deserialized.scores;
        this.thresholds = deserialized.thresholds;
    }
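
    // For reference, the serialized form looks like this (values illustrative):
    // {"id":"demo","trueLabels":[1,0],"scores":[0.9,0.2],"thresholds":[0.9,0.2]}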

    /**
     * Adds a data point to the internal buffer.
     * @param trueLabel - The known, true binary label of the data instance.
     * @param labelProbability - The inferred probability of the data instance.
     */
    addData(trueLabel: boolean | 1 | 0, labelProbability: number) {
        this.trueLabels.push(trueLabel ? 1 : 0);
        this.scores.push(labelProbability);
        this.recalculateThresholds();
    }

    /**
     * Replaces the data buffers of this instance.
     * @param trueLabels - The array of true binary labels of the instances (1 for positive and 0 for negative).
     * @param scores - The array of scores or probabilities as estimated by the model, corresponding to the true labels.
     */
    setData(trueLabels: (1 | 0)[], scores: number[]) {
        this.trueLabels = trueLabels;
        this.scores = scores;
        this.validate();
        this.recalculateThresholds();
    }

    /** Ensures the label and score buffers are the same length. */
    validate() {
        if (this.trueLabels.length !== this.scores.length) {
            throw new Error("true label / prediction array length mismatch");
        }
    }

    /** Rebuilds the threshold list as the unique scores, sorted in descending order. */
    recalculateThresholds() {
        this.thresholds = Array.from(new Set(this.scores)).sort((a, b) => b - a);
    }
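
    // Worked example: scores [0.2, 0.9, 0.5, 0.2] deduplicate to {0.2, 0.9, 0.5}
    // and sort to thresholds [0.9, 0.5, 0.2].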

    /**
     * Calculates detailed statistics for each threshold.
     *
     * This method iterates over each unique threshold to determine these statistics, which are
     * crucial for evaluating the performance of a binary classification model.
     *
     * @returns An array of objects, each representing statistics at a specific threshold.
     */
    calculateStatistics() {
        const results = this.thresholds.map(threshold => {
            return {
                /** threshold used to generate statistics */
                threshold: threshold,
                ...this.calculateStatisticsAtThreshold(threshold),
            };
        });
        return results;
    }

    /**
     * Calculates detailed statistics for a specified threshold.
     */
    calculateStatisticsAtThreshold(threshold: number) {
        let TP = 0, FP = 0, TN = 0, FN = 0;
        for (let i = 0; i < this.scores.length; i++) {
            if (this.scores[i] >= threshold) {
                if (this.trueLabels[i] === 1) {
                    TP++;
                } else {
                    FP++;
                }
            } else {
                if (this.trueLabels[i] === 1) {
                    FN++;
                } else {
                    TN++;
                }
            }
        }
        const TPR = TP + FN === 0 ? 0 : TP / (TP + FN);
        const FPR = FP + TN === 0 ? 0 : FP / (FP + TN);
        const FNR = TP + FN === 0 ? 0 : FN / (TP + FN);
        const TNR = TN + FP === 0 ? 0 : TN / (TN + FP);
        const PPV = TP + FP === 0 ? 0 : TP / (TP + FP);
        const NPV = TN + FN === 0 ? 0 : TN / (FN + TN);
        const FOR = FN + TN === 0 ? 0 : FN / (FN + TN);
        const FDR = TP + FP === 0 ? 0 : FP / (TP + FP);
        const Accuracy = TP + FN + FP + TN === 0 ? 0 : (TP + TN) / (TP + FN + FP + TN);
        const BalancedAccuracy = (TPR + TNR) / 2;
        const Informedness = TPR + TNR - 1;
        const Markedness = PPV + NPV - 1;
        const FM = PPV + TPR === 0 ? 0 : Math.sqrt(PPV * TPR);
        const MCC = (TP + FN) * (TP + FP) * (TN + FP) * (TN + FN) === 0 ? 0 : (TP * TN - FP * FN) / Math.sqrt((TP + FN) * (TP + FP) * (TN + FP) * (TN + FN));
        const PT = TPR === FPR ? 0 : (Math.sqrt(TPR * FPR) - FPR) / (TPR - FPR);
        const PLR = FPR === 0 ? Infinity : TPR / FPR;
        const NLR = TNR === 0 ? 0 : FNR / TNR;
        const DOR = NLR === 0 ? Infinity : PLR / NLR;
        const CSI = TP + FN + FP === 0 ? 0 : TP / (TP + FN + FP);

        return {
            /** True positives: The number of instances correctly identified as positive */
            TP,
            /** False positives: The number of instances incorrectly identified as positive */
            FP,
            /** True negatives: The number of instances correctly identified as negative */
            TN,
            /** False negatives: The number of instances incorrectly identified as negative */
            FN,
            /** The proportion of true results (both true positives and true negatives) among the
             * total number of cases examined. It measures the overall correctness of the model. */
            Accuracy,
            /** The average of the proportion of true results in each class (sensitivity and
             * specificity). It is particularly useful in situations where the classes are imbalanced. */
            BalancedAccuracy,
            /** True Positive Rate: Also known as sensitivity or recall, it measures the proportion
             * of actual positives that are correctly identified. A higher TPR indicates a model's
             * better performance in identifying positive cases. */
            TPR,
            /** False Positive Rate: It measures the proportion of actual negatives incorrectly identified
             * as positives. */
            FPR,
            /** False Negative Rate: It measures the proportion of actual positives incorrectly identified
             * as negatives. */
            FNR,
            /** True Negative Rate: Also known as specificity or selectivity, it measures the proportion
             * of actual negatives that are correctly identified. A higher TNR indicates a model's
             * better performance in identifying negative cases. */
            TNR,
            /** Positive Predictive Value: Also known as precision, it measures the proportion
             * of positive identifications that were actually correct. A higher PPV indicates a model's
             * better performance in predicting positive cases accurately. */
            PPV,
            /** Negative Predictive Value: It measures the proportion of negative identifications that were
             * actually correct. */
            NPV,
            /** False Omission Rate: Measures the proportion of negative predictions that were incorrect. */
            FOR,
            /** False Discovery Rate: The proportion of positive predictions that were incorrect. */
            FDR,
            /** Fowlkes-Mallows Index: A measure that combines precision and recall into
             * a single metric. It is the geometric mean of precision (PPV) and recall (TPR).
             *
             * An FM score ranges from 0 to 1, where 1 indicates perfect precision and recall. A higher FM score
             * suggests that the model effectively identifies positive instances and that the positive predictions
             * it makes are reliable. */
            FM,
            /** Matthews Correlation Coefficient: Also known as the phi coefficient. A correlation coefficient between
             * the observed and predicted classifications. It takes into account true and false positives and negatives
             * and is considered a balanced measure that can be used even if the classes are of very different sizes.
             *
             * The MCC value ranges from -1 to 1. A coefficient of +1 represents a perfect prediction, 0 an average
             * random prediction, and -1 an inverse prediction. This metric is particularly useful because it
             * remains informative even when the dataset is imbalanced. */
            MCC,
            /** Prevalence Threshold: Refers to the point at which the prevalence of the condition being tested for
             * makes the model's positive predictive value (PPV) equal to its sensitivity (TPR).
             *
             * PT provides insight into the effectiveness of a test or model across different prevalence rates,
             * highlighting the importance of considering disease prevalence when evaluating test performance. */
            PT,
            /** Positive Likelihood Ratio: Indicates how much the odds of the disease increase when a test is positive. */
            PLR,
            /** Negative Likelihood Ratio: Indicates how much the odds of the disease decrease when a test is negative. */
            NLR,
            /** Diagnostic Odds Ratio: The ratio of the odds of the test being positive if the subject has the condition
             * versus the odds of the test being positive if the subject does not have the condition. */
            DOR,
            /** Critical Success Index / Threat Score: Measures the proportion of correct positive predictions out of
             * all instances that were predicted positive or were actually positive. It is similar to the F1 score
             * but does not consider true negatives in its calculation.
             *
             * CSI values range from 0 to 1, where 1 indicates perfect performance in predicting positive instances.
             * It is particularly useful where the focus is on correctly predicting rare events. */
            CSI,
            /** Also known as Youden's index. Measures the probability that a prediction is informed in relation to the
             * actual class.
             *
             * It ranges from -1 to 1, where 1 indicates perfect knowledge (all predictions are correct), 0 indicates
             * no better than random guessing, and -1 indicates total disagreement between prediction and actual class. */
            Informedness,
            /** Measures the probability that the actual class is correctly informed by the prediction.
             *
             * Like Informedness, it ranges from -1 to 1. A value of 1 indicates perfect marking (all
             * actual classes are predicted correctly), 0 indicates no better than random marking, and
             * -1 indicates complete misclassification. */
            Markedness,
            /** F1 Score: The harmonic mean of precision and recall, providing a single metric to assess
             * the balance between them. A higher F1 score indicates a model's better overall performance
             * in terms of both precision and recall. */
            F1: (PPV + TPR) === 0 ? 0 : (2 * PPV * TPR) / (PPV + TPR),
        };
    }
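
    // Worked example (made-up data): with trueLabels [1, 0, 1, 0] and
    // scores [0.9, 0.8, 0.3, 0.2], at threshold 0.5 the instances split into
    // TP = 1 (0.9), FP = 1 (0.8), FN = 1 (0.3), TN = 1 (0.2), giving
    // TPR = FPR = PPV = Accuracy = 0.5, F1 = 0.5, and MCC = 0.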

    /**
     * Calculates detailed classifier statistics at threshold 0.5, alongside epidemiological statistics.
     */
    calculateSnapshotStatistics() {
        return {
            TotalPopulation: this.trueLabels.length,
            Prevalence: this.trueLabels.filter(value => value === 1).length / this.trueLabels.length,
            ...this.calculateStatisticsAtThreshold(0.5),
        };
    }

    /**
     * Calculates the Area Under the Receiver Operating Characteristic Curve (AUROC).
     * This method uses the trapezoidal rule to approximate the area under the ROC curve,
     * which is a measure of the model's ability to discriminate between the positive and negative classes.
     * @returns The calculated AUROC value.
     */
    calculateAUROC() {
        // Points are ordered by descending threshold, i.e. ascending FPR/TPR.
        // Prepend the (FPR = 0, TPR = 0) anchor (equivalent to a threshold above
        // every score) so the first trapezoid of the curve is included in the area.
        const points = [{ FPR: 0, TPR: 0 }, ...this.calculateStatistics()];
        let auc = 0; // Initialize AUROC
        for (let i = 0; i < points.length - 1; i++) {
            const xDiff = points[i + 1].FPR - points[i].FPR; // Difference in FPR between consecutive points
            const yAvg = (points[i].TPR + points[i + 1].TPR) / 2; // Average TPR of the two points
            auc += xDiff * yAvg; // Increment AUROC using the trapezoidal rule
        }
        return auc; // Return the computed AUROC value
    }
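
    // Worked example (illustrative): for trueLabels [1, 0] and scores [0.9, 0.1],
    // the anchored ROC points are (0, 0) -> (0, 1) -> (1, 1); the trapezoids sum
    // to an AUROC of 1, as expected for perfect separation.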

    /**
     * Calculates the Area Under the Precision-Recall Curve (AUPRC).
     *
     * The AUPRC is a valuable metric for evaluating the performance of a binary classifier, especially in datasets
     * where the positive class is rare. This method approximates the AUPRC using the trapezoidal rule, based on
     * the precision (positive predictive value) and recall (true positive rate) at various thresholds. It provides
     * an aggregate measure of the model's ability to identify positive instances accurately across different
     * threshold settings, emphasizing the balance between precision and recall in the presence of class imbalance.
     *
     * @returns The calculated AUPRC value, representing the model's average precision across all levels
     *          of recall. Higher AUPRC values indicate better model performance, particularly in its ability to prioritize
     *          the correct identification of positive instances while minimizing false positives.
     */
    calculateAUPRC() {
        const stats = this.calculateStatistics().sort((a, b) => a.TPR - b.TPR); // Ensure stats are sorted by recall
        let auprc = 0; // Initialize AUPRC
        for (let i = 0; i < stats.length - 1; i++) {
            // Calculate the difference in recall between consecutive points
            const recallDiff = stats[i + 1].TPR - stats[i].TPR;
            // Calculate the average precision of the two points
            const precisionAvg = (stats[i].PPV + stats[i + 1].PPV) / 2;
            // Increment AUPRC using the trapezoidal rule
            auprc += recallDiff * precisionAvg;
        }
        return auprc; // Return the computed AUPRC value
    }
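
    // Interpretation note: a random classifier scores an AUPRC approximately equal
    // to the positive-class prevalence, so AUPRC should be judged relative to
    // prevalence rather than against a fixed 0.5 baseline.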

    /**
     * Calculates the optimal threshold based on maximizing the F1 score.
     * The F1 score is the harmonic mean of precision and recall, providing a balance between the two.
     * This method iterates over all possible thresholds to find the one with the highest F1 score.
     *
     * @returns An object containing the optimal threshold and the corresponding F1 score.
     */
    calculateOptimalThresholdUsingF1Score() {
        let optimalThreshold = 0;
        let maxF1 = 0;
        let thresholdStats: ReturnType<BinaryClassifierStatistics["calculateStatisticsAtThreshold"]> | undefined;
        this.thresholds.forEach(threshold => {
            const stats = this.calculateStatisticsAtThreshold(threshold);
            const { F1 } = stats;
            if (F1 > maxF1) {
                maxF1 = F1;
                optimalThreshold = threshold;
                thresholdStats = stats;
            }
        });
        return {
            optimalThreshold: optimalThreshold,
            maxF1: maxF1,
            thresholdStats,
        };
    }

    /**
     * Calculates the optimal threshold based on maximizing Youden's index.
     * Youden's index is defined as J = sensitivity + specificity - 1, which maximizes
     * the classifier's performance by considering both true positive and true negative rates.
     *
     * @returns An object containing the optimal threshold and the corresponding Youden's index.
     */
    calculateOptimalThresholdUsingYoudensIndex() {
        let optimalThreshold = 0;
        let maxYoudenIndex = -1; // Initialize with -1, the minimum possible value for J
        let thresholdStats: ReturnType<BinaryClassifierStatistics["calculateStatisticsAtThreshold"]> | undefined;
        this.thresholds.forEach(threshold => {
            const stats = this.calculateStatisticsAtThreshold(threshold);
            const { Informedness } = stats;
            if (Informedness > maxYoudenIndex) {
                maxYoudenIndex = Informedness;
                optimalThreshold = threshold;
                thresholdStats = stats;
            }
        });
        return {
            optimalThreshold: optimalThreshold,
            maxYoudenIndex: maxYoudenIndex,
            thresholdStats,
        };
    }

    /**
     * Calculates the optimal threshold based on maximizing the Matthews Correlation Coefficient (MCC).
     * MCC is considered a balanced measure which can be used even if the classes are of very different sizes,
     * ranging from -1 (total disagreement) to +1 (perfect prediction), with 0 indicating no better than random prediction.
     *
     * @returns An object containing the optimal threshold and the corresponding MCC.
     */
    calculateOptimalThresholdUsingMCC() {
        let optimalThreshold = 0;
        let maxMCC = -1; // Initialize with -1, the minimum possible value for MCC
        let thresholdStats: ReturnType<BinaryClassifierStatistics["calculateStatisticsAtThreshold"]> | undefined;
        this.thresholds.forEach(threshold => {
            const stats = this.calculateStatisticsAtThreshold(threshold);
            const { MCC } = stats;
            if (MCC > maxMCC) {
                maxMCC = MCC;
                optimalThreshold = threshold;
                thresholdStats = stats;
            }
        });
        return {
            optimalThreshold: optimalThreshold,
            maxMCC: maxMCC,
            thresholdStats,
        };
    }

    /**
     * Calculates the optimal threshold for a specified metric and optimization goal.
     * This method allows for flexible optimization based on a variety of metrics,
     * such as accuracy, precision, recall, F1 score, Matthews Correlation Coefficient (MCC), etc.
     * It supports finding either the maximum or minimum value of the chosen metric across all possible thresholds,
     * which can be useful for tailoring the performance of the binary classifier to specific operational requirements.
     *
     * The method iterates over all possible thresholds, evaluates the classifier's performance at each threshold using
     * the specified metric, and identifies the threshold that optimizes (maximizes or minimizes) the metric's value.
     * This approach enables the fine-tuning of the classifier's decision boundary for optimal performance on
     * the given metric, which is particularly valuable in scenarios where trade-offs between different types
     * of classification errors must be carefully managed.
     *
     * @param metric The name of the metric to optimize for. This should be a key from the object returned by
     *               calculateStatisticsAtThreshold, representing a specific performance metric of the classifier.
     * @param optimisation Specifies the optimization goal: "maximum" to find the threshold that maximizes the metric,
     *                     or "minimum" to find the threshold that minimizes the metric.
     * @param initialisation (Optional) An initial value to start the optimization process. For maximum optimization,
     *                       this could be the lowest possible value (e.g., -Infinity) to ensure any real value is higher.
     *                       For minimum optimization, it could be the highest possible value (e.g., Infinity) to ensure
     *                       any real value is lower. If not provided, defaults to -Infinity for maximum optimization
     *                       and Infinity for minimum optimization.
     *
     * @returns An object containing the optimal threshold, the optimized value of the metric, and the full
     *          statistics at that threshold: { optimalThreshold: number, value: number, thresholdStats }.
     */
    calculateOptimalThreshold(metric: keyof ReturnType<BinaryClassifierStatistics["calculateStatisticsAtThreshold"]>, optimisation: "maximum" | "minimum", initialisation?: number) {
        let optimalThreshold = 0;
        let thresholdStats: ReturnType<BinaryClassifierStatistics["calculateStatisticsAtThreshold"]> | undefined;
        if (optimisation === "maximum") {
            let maxValue = initialisation ?? -Infinity;
            this.thresholds.forEach(threshold => {
                const stats = this.calculateStatisticsAtThreshold(threshold);
                const entry = stats[metric];
                if (entry > maxValue) {
                    maxValue = entry;
                    optimalThreshold = threshold;
                    thresholdStats = stats;
                }
            });
            return {
                optimalThreshold: optimalThreshold,
                value: maxValue,
                thresholdStats,
            };
        }
        else {
            let minValue = initialisation ?? Infinity;
            this.thresholds.forEach(threshold => {
                const stats = this.calculateStatisticsAtThreshold(threshold);
                const entry = stats[metric];
                if (entry < minValue) {
                    minValue = entry;
                    optimalThreshold = threshold;
                    thresholdStats = stats;
                }
            });
            return {
                optimalThreshold: optimalThreshold,
                value: minValue,
                thresholdStats,
            };
        }
    }
}
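
// ---------------------------------------------------------------------------
// Usage sketch (illustrative): the labels, scores, and the "demo.json" path
// below are made up for demonstration, but every call matches the class defined
// above. The `require.main === module` guard assumes a CommonJS build so the
// demo only runs when this file is executed directly.
// ---------------------------------------------------------------------------
if (require.main === module) {
    const stats = new BinaryClassifierStatistics(
        "demo-model",
        [1, 0, 1, 0, 1, 0], // ground-truth labels
        [0.9, 0.8, 0.7, 0.4, 0.35, 0.1], // model-estimated probabilities
    );
    stats.addData(true, 0.95); // streaming-style ingestion is also supported

    console.log(stats.calculateAUROC()); // threshold-free discrimination
    console.log(stats.calculateAUPRC()); // precision/recall trade-off summary
    console.log(stats.calculateSnapshotStatistics()); // fixed 0.5 threshold + prevalence

    // Pick an operating point in three different ways:
    console.log(stats.calculateOptimalThresholdUsingF1Score());
    console.log(stats.calculateOptimalThresholdUsingYoudensIndex());
    console.log(stats.calculateOptimalThreshold("Accuracy", "maximum"));

    // Persist and restore:
    stats.save("demo.json");
    const restored = BinaryClassifierStatistics.load("demo.json");
    console.log(restored.id); // "demo-model"
}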