Skip to content

Instantly share code, notes, and snippets.

@dhimmel
Last active August 29, 2015 14:21
Show Gist options
  • Save dhimmel/f0e966bc084a702f591f to your computer and use it in GitHub Desktop.
Save dhimmel/f0e966bc084a702f591f to your computer and use it in GitHub Desktop.
Variable threshold metrics in R for evaluating classifier performance
library(ROCR)
library(caTools)
VariableThresholdMetrics <- function(score, status) {
  #' Compute threshold-dependent and summary performance metrics for
  #' predictions of a binary outcome.
  #'
  #' @param score A vector of predictions whose performance should be evaluated.
  #' @param status A vector of the actual outcome: \code{0} (for negatives) and \code{1} (for positives).
  #' @return A list with elements \code{auroc}, \code{auprc}, \code{tjur},
  #'   \code{threshold.df} (one row per ROCR cutoff with the columns
  #'   threshold, fpr, recall, precision and lift) and \code{roc.df} (the
  #'   fpr and recall columns of \code{threshold.df} passed through PruneROC).
  pred.obj <- ROCR::prediction(score, status)

  # Pull the y-values of one ROCR performance measure for pred.obj.
  MeasureValues <- function(measure) {
    ROCR::performance(pred.obj, measure = measure)@y.values[[1]]
  }

  # Recall is the same quantity as the true positive rate (TPR).
  threshold.df <- data.frame(
    "threshold" = pred.obj@cutoffs[[1]],
    "fpr" = MeasureValues("fpr"),
    "recall" = MeasureValues("rec"),
    "precision" = MeasureValues("prec"),
    "lift" = MeasureValues("lift")
  )

  roc.df <- PruneROC(threshold.df[, c("fpr", "recall")])

  # Area under the precision-recall curve: integrate precision over recall
  # using only the rows where both values are available.
  pr.df <- threshold.df[, c("recall", "precision")]
  pr.df <- pr.df[stats::complete.cases(pr.df), ]
  auprc <- caTools::trapz(pr.df$recall, pr.df$precision)

  # Tjur's coefficient of discrimination: mean score of the positives minus
  # mean score of the negatives.
  is.positive <- as.logical(status)
  tjur <- mean(score[is.positive]) - mean(score[!is.positive])

  list(
    "auroc" = MeasureValues("auc"),
    "auprc" = auprc,
    "tjur" = tjur,
    "threshold.df" = threshold.df,
    "roc.df" = roc.df
  )
}
PruneROC <- function(roc.df) {
  #' Filter roc.df so each point on the ROC curve is distinct.
  #' Helpful to reduce the file size of vector plots of the ROC curve.
  #'
  #' The filter is applied once per column, first on fpr and then on recall.
  #' In each pass a row is kept when its value has not occurred in an earlier
  #' row, or when the value in the following row has not occurred earlier;
  #' the last row is always kept. For columns that only ever increase
  #' (presumably the case for ROC output -- confirm against the caller) this
  #' retains the first and last row of every run of repeated values and drops
  #' the rows in between.
  #'
  #' @param roc.df A data frame with columns \code{fpr} and \code{recall}.
  #'   Any additional columns are carried along unchanged.
  #' @return \code{roc.df} restricted to the retained rows.
  stopifnot(all(c('fpr', 'recall') %in% colnames(roc.df)))
  for (measure in c('fpr', 'recall')) {
    # Index by the loop variable so that both columns are actually pruned.
    first.seen <- !duplicated(roc.df[[measure]])
    # Also keep the row just before each newly seen value (the end of a run).
    keep <- first.seen | c(first.seen[-1], TRUE)
    roc.df <- roc.df[keep, , drop = FALSE]
  }
  roc.df
}
PrunePRC <- function(prc.df, min.dist=0.0005) {
  #' Thin out the points of a precision-recall curve.
  #'
  #' Rows with a missing precision or recall are dropped. The first remaining
  #' row is always kept; each later row is kept only when its Euclidean
  #' distance in (precision, recall) space from the most recently kept row
  #' exceeds \code{min.dist}.
  #'
  #' @param prc.df A data frame with columns \code{precision} and
  #'   \code{recall}. Any additional columns are carried along unchanged.
  #' @param min.dist Minimum distance from the previously kept point that a
  #'   point needs in order to be kept.
  #' @return \code{prc.df} restricted to the retained rows, in their original
  #'   order and with every row appearing at most once.
  stopifnot(all(c('precision', 'recall') %in% colnames(prc.df)))
  coords <- as.data.frame(prc.df[, c('precision', 'recall')])
  complete <- rowSums(is.na(coords)) == 0
  prc.df <- prc.df[complete, , drop = FALSE]
  precision <- coords[['precision']][complete]
  recall <- coords[['recall']][complete]

  n.row <- length(precision)
  if (n.row < 2) {
    # Nothing to compare against; avoids iterating over the backwards
    # sequence that 2:n.row would produce.
    return(prc.df)
  }

  keep <- logical(n.row)
  keep[1] <- TRUE
  anchor <- 1  # index of the most recently kept row
  for (i in seq_len(n.row)[-1]) {
    distance <- sqrt(
      (precision[i] - precision[anchor])^2 + (recall[i] - recall[anchor])^2
    )
    if (distance > min.dist) {
      keep[i] <- TRUE
      anchor <- i
    }
  }
  # A logical mask guarantees the first row is not selected twice.
  prc.df[keep, , drop = FALSE]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment