@zachmayer
Created July 6, 2012 16:46
Multi Class Error Metrics
#Multi-Class Summary Function
#Based on caret:::twoClassSummary
require(compiler)
multiClassSummary <- cmpfun(function (data, lev = NULL, model = NULL){
  #Load Libraries
  require(Metrics)
  require(caret)
  #Check data
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"])))
    stop("levels of observed and predicted data do not match")
  #Calculate custom one-vs-all stats for each class
  prob_stats <- lapply(levels(data[, "pred"]), function(class){
    #Grab one-vs-all data for the class
    pred <- ifelse(data[, "pred"] == class, 1, 0)
    obs <- ifelse(data[, "obs"] == class, 1, 0)
    prob <- data[,class]
    #Calculate one-vs-all AUC and logLoss and return
    cap_prob <- pmin(pmax(prob, .000001), .999999)
    prob_stats <- c(auc(obs, prob), logLoss(obs, cap_prob))
    names(prob_stats) <- c('ROC', 'logLoss')
    return(prob_stats)
  })
  prob_stats <- do.call(rbind, prob_stats)
  rownames(prob_stats) <- paste('Class:', levels(data[, "pred"]))
  #Calculate confusion matrix-based statistics
  CM <- confusionMatrix(data[, "pred"], data[, "obs"])
  #Aggregate and average class-wise stats
  #Todo: add weights
  class_stats <- cbind(CM$byClass, prob_stats)
  class_stats <- colMeans(class_stats)
  #Aggregate overall stats
  overall_stats <- c(CM$overall)
  #Combine overall with class-wise stats and remove some stats we don't want
  stats <- c(overall_stats, class_stats)
  stats <- stats[! names(stats) %in% c('AccuracyNull',
    'Prevalence', 'Detection Prevalence')]
  #Clean names and return
  names(stats) <- gsub('[[:blank:]]+', '_', names(stats))
  return(stats)
})
#CLEAR WORKSPACE (keep multiClassSummary, which is defined above)
rm(list = setdiff(ls(all.names = TRUE), 'multiClassSummary'))
gc(reset=TRUE)
#Setup parallel cluster
#If running on the command line of linux, use type='FORK'
library(doParallel)
cl <- makeCluster(detectCores(), type='PSOCK')
registerDoParallel(cl)
#Fit model
library(caret)
set.seed(19556)
model <- train(
  Species~.,
  data=iris,
  method='knn',
  tuneGrid=expand.grid(.k=1:30),
  metric='Accuracy',
  trControl=trainControl(
    method='repeatedcv',
    number=10,
    repeats=15,
    classProbs=TRUE,
    summaryFunction=multiClassSummary))
#Stop parallel cluster
stopCluster(cl)
#Save pdf of plots
#Close a graphics device only if one is open; a bare dev.off() errors otherwise
if (!is.null(dev.list())) dev.off()
pdf('plots.pdf')
for(stat in c('Accuracy', 'Kappa', 'AccuracyLower', 'AccuracyUpper', 'AccuracyPValue',
              'Sensitivity', 'Specificity', 'Pos_Pred_Value',
              'Neg_Pred_Value', 'Detection_Rate', 'ROC', 'logLoss')) {
  print(plot(model, metric=stat))
}
dev.off()
@alvesvm

alvesvm commented Jun 25, 2014

Hi Zach,

Could you please clarify how this aggregation of stats works for the AUC?

Thank you.

@zachmayer
Author

@alvesvm Better late than never, I suppose. I just average the pairwise AUCs, as advocated by Hand and Till (2001).

@charly06

charly06 commented May 22, 2017

@zachmayer, I needed exactly the kind of algorithm you published. However, after reading Hand and Till (2001) and several other papers, I concluded that the way you calculate the overall AUC is not really the one proposed by Hand and Till (2001).
With colMeans(class_stats) you do something different from what they propose. You might take a look at Huang and Ling (2005), page 305. They provide a nice example of how to calculate the overall AUC for multiclass datasets. I forked your gist and adapted it accordingly.
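
To make the distinction concrete, here is a minimal base-R sketch of the two aggregations. It is not the code from the fork mentioned above. The helper names `rank_auc`, `one_vs_all_auc` and `hand_till_auc` are hypothetical and do not come from caret or Metrics, and the toy data is made up for illustration. `one_vs_all_auc` mirrors what averaging the per-class ROC column with colMeans amounts to, while `hand_till_auc` follows the pairwise definition as I understand it from Hand and Till (2001): for every pair of classes, keep only the observations from those two classes, compute the AUC in both directions, average the two, and then average over all pairs.

```r
# AUC of `score` for separating positives (is_pos == TRUE) from negatives,
# computed from the rank-sum (Mann-Whitney) statistic; ties get midranks.
rank_auc <- function(is_pos, score) {
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  r <- rank(score)
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Macro-average of one-vs-all AUCs: each class against all other classes.
# `probs` is assumed to have one column per class, named after levels(obs).
one_vs_all_auc <- function(obs, probs) {
  mean(vapply(levels(obs), function(cl) {
    rank_auc(obs == cl, probs[, cl])
  }, numeric(1)))
}

# Pairwise average in the style of Hand and Till (2001): for each pair (i, j),
# restrict to observations from classes i and j, then average A(i|j) and A(j|i).
hand_till_auc <- function(obs, probs) {
  pairs <- combn(levels(obs), 2, simplify = FALSE)
  pair_auc <- vapply(pairs, function(p) {
    keep <- obs %in% p
    a_ij <- rank_auc(obs[keep] == p[1], probs[keep, p[1]])
    a_ji <- rank_auc(obs[keep] == p[2], probs[keep, p[2]])
    (a_ij + a_ji) / 2
  }, numeric(1))
  mean(pair_auc)
}

# Made-up toy data: 7 observations, 3 classes, one probability column per class.
obs <- factor(c('a', 'a', 'b', 'b', 'c', 'c', 'c'))
probs <- rbind(
  c(0.80, 0.10, 0.10),
  c(0.40, 0.50, 0.10),
  c(0.30, 0.60, 0.10),
  c(0.50, 0.30, 0.20),
  c(0.10, 0.20, 0.70),
  c(0.20, 0.20, 0.60),
  c(0.45, 0.15, 0.40)
)
colnames(probs) <- levels(obs)

print(one_vs_all_auc(obs, probs))
print(hand_till_auc(obs, probs))
```

On this toy data the two numbers differ slightly, because the one-vs-all version ranks each class against every other observation while the pairwise version only ever compares two classes at a time.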
