-
-
Save charly06/52b4c9300bba93fca3722b0a76cd358a to your computer and use it in GitHub Desktop.
Multi-class performance metrics with AUC based on Hand and Till (2001), following the description in Huang and Ling (2005). Use stratified folds; otherwise you may get errors when a class is missing from a fold.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Multi-class summary function based on Hand and Till (2001).
# Based on caret:::twoClassSummary.
library(compiler)
# caret summary function: multi-class AUC (the Hand & Till M measure, as
# described by Huang and Ling 2005, p. 305) plus caret's confusion-matrix
# statistics.
#
# Arguments (caret summaryFunction contract):
#   data  - data frame with factor columns "obs" and "pred" (same levels) and
#           one class-probability column per level, named after the level
#           (as produced by trainControl(classProbs = TRUE)).
#   lev   - supplied by caret; not used here.
#   model - supplied by caret; not used here.
# Returns a named numeric vector: caret's overall statistics, the class-wise
# statistics (averaged over classes when there are more than two classes),
# and "ROC". "AccuracyNull", "Prevalence" and "Detection Prevalence" are
# dropped, and runs of blanks in names are replaced by "_".
#
# Only classes that actually occur in `data$obs` enter the AUC average, so a
# resample missing a class yields the M measure over the remaining pairs
# instead of NaN. With fewer than two observed classes, ROC is NA.
multiClassSummaryAUC <- cmpfun(function(data, lev = NULL, model = NULL) {
  # library() stops if a package is missing; require() would only return
  # FALSE and let the function fail later with a less helpful error.
  library(Metrics)
  library(caret)

  allObs <- data[, "obs"]
  obsLevels <- levels(allObs)
  # identical() also rejects a non-factor "pred" (levels() is NULL), which an
  # elementwise `==` comparison would let through as all(logical(0)).
  if (!identical(levels(data[, "pred"]), obsLevels)) {
    stop("levels of observed and predicted data do not match")
  }

  # A(i|j) of Hand and Till (2001): AUC for separating class `positive` from
  # class `negative`, using only observations of these two classes and the
  # predicted probability of `positive` as the score.
  directedAUC <- function(positive, negative) {
    inPair <- allObs == positive | allObs == negative
    Metrics::auc(ifelse(allObs[inPair] == positive, 1, 0),
                 data[inPair, positive])
  }

  # A class without observations would make Metrics::auc() divide by zero
  # (NaN), so only classes present in this resample are paired.
  presentLevels <- obsLevels[obsLevels %in% allObs]
  if (length(presentLevels) >= 2) {
    # Each column is one unordered class pair (i, j).
    pairs <- utils::combn(presentLevels, 2)
    # A(i, j) = (A(i|j) + A(j|i)) / 2
    aucs <- vapply(seq_len(ncol(pairs)), function(k) {
      (directedAUC(pairs[1, k], pairs[2, k]) +
         directedAUC(pairs[2, k], pairs[1, k])) / 2
    }, numeric(1))
    # M = mean of A(i, j) over all c * (c - 1) / 2 pairs
    finalAUC <- sum(aucs) / length(aucs)
  } else {
    finalAUC <- NA_real_
  }
  names(finalAUC) <- "ROC"

  # Confusion-matrix based statistics
  CM <- caret::confusionMatrix(data[, "pred"], data[, "obs"])
  byClass <- CM$byClass
  # For two classes caret returns byClass as a named vector (positive class
  # only), on which colMeans() would fail; otherwise it is a class x stat
  # matrix that is averaged over classes.
  classStats <- if (is.matrix(byClass)) colMeans(byClass) else byClass

  stats <- c(CM$overall, classStats, finalAUC)
  stats <- stats[!names(stats) %in% c("AccuracyNull",
                                      "Prevalence", "Detection Prevalence")]
  names(stats) <- gsub("[[:blank:]]+", "_", names(stats))
  stats
})
# Common implementation which does NOT follow the suggestions of Hand and
# Till (2001): one-vs-all AUC and log loss per class, averaged over classes.
library(compiler)
# caret summary function with one-vs-all probability statistics.
#
# Arguments (caret summaryFunction contract):
#   data  - data frame with factor columns "obs" and "pred" (same levels) and
#           one class-probability column per level, named after the level.
#   lev   - supplied by caret; not used here.
#   model - supplied by caret; not used here.
# Returns a named numeric vector: caret's overall statistics, the class-wise
# confusion statistics plus one-vs-all "ROC" and "logLoss" (averaged over
# classes, ignoring NAs), with "AccuracyNull", "Prevalence" and
# "Detection Prevalence" removed and blanks in names replaced by "_".
multiClassSummaryOld <- cmpfun(function(data, lev = NULL, model = NULL) {
  # library() stops if a package is missing; require() would only return
  # FALSE and let the function fail later.
  library(Metrics)
  library(caret)

  obsLevels <- levels(data[, "obs"])
  # identical() also rejects a non-factor "pred" (levels() is NULL).
  if (!identical(levels(data[, "pred"]), obsLevels)) {
    stop("levels of observed and predicted data do not match")
  }

  # One-vs-all AUC and log loss for each class. Calls are qualified so that
  # another attached package (e.g. pROC::auc) cannot mask them.
  probStats <- lapply(obsLevels, function(class) {
    obs <- ifelse(data[, "obs"] == class, 1, 0)
    prob <- data[, as.character(class)]
    # Clamp probabilities away from 0 and 1 so the log loss stays finite.
    cappedProb <- pmin(pmax(prob, .000001), .999999)
    setNames(c(Metrics::auc(obs, prob), Metrics::logLoss(obs, cappedProb)),
             c("ROC", "logLoss"))
  })
  probStats <- do.call(rbind, probStats)
  rownames(probStats) <- paste("Class:", obsLevels)

  # Confusion-matrix based statistics
  CM <- caret::confusionMatrix(data[, "pred"], data[, "obs"])
  byClass <- CM$byClass
  if (is.matrix(byClass)) {
    # One row per class: append the probability stats and average per column.
    classStats <- colMeans(cbind(byClass, probStats), na.rm = TRUE)
  } else {
    # Two classes: byClass is a named vector for the positive class only, so
    # it cannot be column-bound to the per-class probability matrix.
    classStats <- c(byClass, colMeans(probStats, na.rm = TRUE))
  }

  stats <- c(CM$overall, classStats)
  stats <- stats[!names(stats) %in% c("AccuracyNull",
                                      "Prevalence", "Detection Prevalence")]
  names(stats) <- gsub("[[:blank:]]+", "_", names(stats))
  stats
})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example used in Huang and Ling (2005), p. 305.
# Pairwise AUCs computed by hand for this data: A(c1,c2) = 0.625,
# A(c1,c3) = 0.5, A(c2,c3) = 0.75, so the Hand-Till ROC is 0.625.
exampleData <- data.frame(
  pred = c("c1", "c3", "c2", "c1", "c3", "c1"),
  obs = c("c1", "c1", "c2", "c2", "c3", "c3"),
  c1 = c(0.6, 0.15, 0.3, 0.45, 0.1, 0.8),
  c2 = c(0.15, 0.3, 0.5, 0.25, 0.2, 0.05),
  c3 = c(0.25, 0.55, 0.2, 0.3, 0.7, 0.15),
  # "pred" and "obs" must be factors with identical levels
  stringsAsFactors = TRUE
)
multiClassSummaryAUC(exampleData)
multiClassSummaryOld(exampleData)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create dummy data and tune a kNN model on the Hand-Till ROC.
# caret must be attached before createFolds() is called below.
library(caret)
set.seed(500)
n <- 300
MyData <- data.frame(
  f1 = runif(n, max = 50),
  f2 = runif(n, max = 50),
  f3 = runif(n, min = 50, max = 100),
  f4 = runif(n, max = 10),
  # Classification outcome as a factor, as caret expects
  myClass = factor(sample(c("a", "b", "c", "d", "e"), size = n,
                          replace = TRUE))
)
# Stratified folds: createFolds() samples within each class of the factor,
# keeping class proportions roughly equal across folds.
folds <- 10
cvIndex <- createFolds(MyData$myClass, folds, returnTrain = TRUE)
# Fit model
set.seed(20000)
model <- train(
  myClass ~ .,
  data = MyData,
  method = "knn",
  tuneGrid = expand.grid(k = c(1, 5, 10, 15)),
  metric = "ROC",
  trControl = trainControl(
    index = cvIndex,
    method = "cv",
    number = folds,
    classProbs = TRUE,
    summaryFunction = multiClassSummaryAUC
  )
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment