Last active
August 29, 2015 14:21
-
-
Save dhimmel/f0e966bc084a702f591f to your computer and use it in GitHub Desktop.
Variable threshold metrics in R for evaluating classifier performance
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(ROCR) | |
library(caTools) | |
VariableThresholdMetrics <- function(score, status) {
  #' Summarize how well a vector of scores predicts a binary outcome.
  #'
  #' @param score A vector of predictions for which performance should be evaluated.
  #' @param status A vector of the actual outcome: \code{0} (for negatives) and \code{1} (for positives).
  #' @return A list with elements \code{auroc}, \code{auprc}, \code{tjur},
  #'   \code{threshold.df} (threshold, fpr, recall, precision and lift as
  #'   reported by ROCR) and \code{roc.df} (the fpr/recall columns of
  #'   \code{threshold.df} after \code{PruneROC}).

  pred.obj <- ROCR::prediction(score, status)

  # Pull the y-values of one ROCR performance measure for pred.obj.
  MeasureValues <- function(name) {
    perf <- ROCR::performance(pred.obj, measure = name)
    perf@y.values[[1]]
  }

  # Note that recall is the same quantity as TPR.
  threshold.df <- data.frame(
    'threshold' = pred.obj@cutoffs[[1]],
    'fpr' = MeasureValues('fpr'),
    'recall' = MeasureValues('rec'),
    'precision' = MeasureValues('prec'),
    'lift' = MeasureValues('lift')
  )

  auroc <- MeasureValues('auc')
  roc.df <- PruneROC(threshold.df[, c('fpr', 'recall')])

  # Area under the precision-recall curve by trapezoidal integration,
  # using only rows where both recall and precision are available.
  pr.df <- na.omit(threshold.df[, c('recall', 'precision')])
  auprc <- caTools::trapz(pr.df$recall, pr.df$precision)

  # Coefficient of discrimination: mean score of positives minus mean
  # score of negatives.
  is.positive <- as.logical(status)
  tjur <- mean(score[is.positive]) - mean(score[!is.positive])

  list(
    'auroc' = auroc,
    'auprc' = auprc,
    'tjur' = tjur,
    'threshold.df' = threshold.df,
    'roc.df' = roc.df
  )
}
PruneROC <- function(roc.df) {
  #' Filter roc.df so redundant points on the ROC curve are dropped.
  #' Helpful to reduce the file size of vector plots of the ROC curve.
  #'
  #' For \code{fpr} and then \code{recall}, a row is kept when its value is a
  #' first occurrence in that column, or when the following row's value is a
  #' first occurrence (the final row is always kept). When the column is
  #' monotone along the curve this retains only the two endpoints of each
  #' vertical (constant fpr) and horizontal (constant recall) segment.
  #'
  #' @param roc.df A data.frame with columns \code{fpr} and \code{recall};
  #'   rows are assumed to be ordered along the curve.
  #' @return \code{roc.df} restricted to the retained rows.
  stopifnot(all(c('fpr', 'recall') %in% colnames(roc.df)))
  for (measure in c('fpr', 'recall')) {
    # Index by the loop variable so that each axis is pruned in turn;
    # hard-coding one column would leave the other axis unpruned.
    not.dup <- !duplicated(roc.df[[measure]])
    # Also keep the row just before each first occurrence (end of a run).
    not.dup <- not.dup | c(not.dup[-1], TRUE)
    roc.df <- roc.df[not.dup, , drop = FALSE]
  }
  return(roc.df)
}
PrunePRC <- function(prc.df, min.dist=0.0005) {
  #' Thin a precision-recall curve by dropping points that lie too close to
  #' the previously retained point.
  #'
  #' Rows with a missing \code{precision} or \code{recall} are removed first.
  #' The first remaining row is always kept; each later row is kept only when
  #' its Euclidean distance in (precision, recall) space from the most
  #' recently kept row exceeds \code{min.dist}.
  #'
  #' @param prc.df A data.frame with columns \code{precision} and \code{recall}.
  #' @param min.dist Minimum distance from the last kept point for a row to be
  #'   retained.
  #' @return \code{prc.df} restricted to the retained rows, in original order,
  #'   with each row appearing at most once.
  stopifnot(all(c('precision', 'recall') %in% colnames(prc.df)))
  precision <- prc.df[['precision']]
  recall <- prc.df[['recall']]
  complete <- !is.na(precision) & !is.na(recall)
  prc.df <- prc.df[complete, , drop = FALSE]
  precision <- precision[complete]
  recall <- recall[complete]

  n.row <- nrow(prc.df)
  # Nothing to compare against with fewer than two points; also avoids the
  # descending sequence that 2:n.row would produce here.
  if (n.row < 2) {
    return(prc.df)
  }

  # A logical mask guarantees each row is selected at most once, including
  # the first row when its neighbours fall within min.dist of it.
  keep <- logical(n.row)
  keep[1] <- TRUE
  pointer <- 1
  for (i in 2:n.row) {
    distance <- sqrt(
      (precision[i] - precision[pointer])^2 + (recall[i] - recall[pointer])^2
    )
    if (distance > min.dist) {
      keep[i] <- TRUE
      pointer <- i
    }
  }
  return(prc.df[keep, , drop = FALSE])
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment