Skip to content

Instantly share code, notes, and snippets.

@charly06
Forked from zachmayer/multiclass.R
Last active January 5, 2019 12:20
Show Gist options
  • Save charly06/52b4c9300bba93fca3722b0a76cd358a to your computer and use it in GitHub Desktop.
Save charly06/52b4c9300bba93fca3722b0a76cd358a to your computer and use it in GitHub Desktop.
Multi-class performance metrics with AUC based on Hand and Till (2001), following the description of Huang and Ling (2005). Use stratified folds; otherwise you may get errors when a class is missing from a fold.
# Multi-class summary function with AUC based on Hand and Till (2001),
# following the pairwise-averaging description of Huang and Ling (2005).
# Modeled on caret:::twoClassSummary.
require(compiler)

# Compute performance metrics for a multi-class caret model.
#
# Args:
#   data:  data.frame supplied by caret with columns "obs" (factor of true
#          labels), "pred" (factor of predicted labels) and one numeric
#          class-probability column per factor level.
#   lev, model: unused; present to satisfy caret's summaryFunction interface.
#
# Returns:
#   Named numeric vector of confusion-matrix statistics plus the
#   Hand-and-Till multi-class AUC under the name "ROC".
multiClassSummaryAUC <- cmpfun(function(data, lev = NULL, model = NULL) {
  if (!requireNamespace("Metrics", quietly = TRUE) ||
      !requireNamespace("caret", quietly = TRUE)) {
    stop("packages 'Metrics' and 'caret' are required", call. = FALSE)
  }

  allObs <- data[, "obs"]
  obsLevels <- levels(allObs)
  nLevels <- length(obsLevels)

  # caret should guarantee matching factor levels, but verify defensively.
  if (!all(levels(data[, "pred"]) == obsLevels)) {
    stop("levels of observed and predicted data do not match", call. = FALSE)
  }

  # Pair-wise AUCs A(i, j) as described in Huang and Ling (2005), p. 305,
  # based on Hand and Till (2001): for every unordered pair of classes,
  # average the two one-vs-one AUCs obtained by scoring with each class's
  # predicted probability. Preallocate one slot per pair.
  aucs <- numeric(nLevels * (nLevels - 1) / 2)
  iPair <- 0
  for (iClass in seq_len(nLevels - 1)) {
    outerClass <- obsLevels[iClass]
    # Only pairs not yet compared are relevant, so start at iClass + 1.
    for (iInnerClass in seq.int(iClass + 1, nLevels)) {
      innerClass <- obsLevels[iInnerClass]
      # Restrict to observations labelled with either class of the pair.
      classFilter <- allObs == outerClass | allObs == innerClass
      obsOuter <- ifelse(allObs[classFilter] == outerClass, 1, 0)
      probOuter <- data[classFilter, as.character(outerClass)]
      aucOuter <- Metrics::auc(obsOuter, probOuter)
      obsInner <- ifelse(allObs[classFilter] == innerClass, 1, 0)
      probInner <- data[classFilter, as.character(innerClass)]
      aucInner <- Metrics::auc(obsInner, probInner)
      iPair <- iPair + 1
      aucs[iPair] <- (aucOuter + aucInner) / 2
    }
  }

  # Average over all nLevels*(nLevels-1)/2 pair-wise AUC values.
  finalAUC <- sum(aucs) * 2 / (nLevels * (nLevels - 1))
  names(finalAUC) <- "ROC"

  # Confusion-matrix-based statistics, averaged across classes.
  CM <- caret::confusionMatrix(data[, "pred"], allObs)
  class_stats <- colMeans(CM$byClass)
  overall_stats <- c(CM$overall)

  # Combine overall stats with the AUC and drop statistics we don't want.
  stats <- c(overall_stats, class_stats, finalAUC)
  stats <- stats[!names(stats) %in% c("AccuracyNull", "Prevalence",
                                      "Detection Prevalence")]

  # Clean names (blanks -> underscores) and return.
  names(stats) <- gsub("[[:blank:]]+", "_", names(stats))
  stats
})
# Common implementation which does NOT follow the suggestions of
# Hand and Till (2001): one-vs-all AUCs averaged over the classes.
require(compiler)

# Compute performance metrics for a multi-class caret model using
# per-class one-vs-all AUC and logLoss.
#
# Args:
#   data:  data.frame supplied by caret with columns "obs", "pred" and one
#          numeric class-probability column per factor level.
#   lev, model: unused; present to satisfy caret's summaryFunction interface.
#
# Returns:
#   Named numeric vector of confusion-matrix statistics plus class-averaged
#   "ROC" and "logLoss".
multiClassSummaryOld <- cmpfun(function(data, lev = NULL, model = NULL) {
  if (!requireNamespace("Metrics", quietly = TRUE) ||
      !requireNamespace("caret", quietly = TRUE)) {
    stop("packages 'Metrics' and 'caret' are required", call. = FALSE)
  }

  # caret should guarantee matching factor levels, but verify defensively.
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) {
    stop("levels of observed and predicted data do not match", call. = FALSE)
  }

  # One-vs-all AUC and logLoss for each class.
  prob_stats <- lapply(levels(data[, "obs"]), function(cls) {
    obs <- ifelse(data[, "obs"] == cls, 1, 0)
    prob <- data[, as.character(cls)]
    # Cap probabilities away from 0 and 1 so logLoss stays finite.
    cap_prob <- pmin(pmax(prob, .000001), .999999)
    out <- c(Metrics::auc(obs, prob), Metrics::logLoss(obs, cap_prob))
    names(out) <- c("ROC", "logLoss")
    out
  })
  prob_stats <- do.call(rbind, prob_stats)
  rownames(prob_stats) <- paste("Class:", levels(data[, "pred"]))

  # Confusion-matrix-based statistics.
  CM <- caret::confusionMatrix(data[, "pred"], data[, "obs"])

  # TODO: aggregate and average class-wise stats as suggested by
  # Hand and Till (2001).
  class_stats <- cbind(CM$byClass, prob_stats)
  class_stats <- colMeans(class_stats, na.rm = TRUE)
  overall_stats <- c(CM$overall)

  # Combine overall with class-wise stats and drop statistics we don't want.
  stats <- c(overall_stats, class_stats)
  stats <- stats[!names(stats) %in% c("AccuracyNull", "Prevalence",
                                      "Detection Prevalence")]

  # Clean names (blanks -> underscores) and return.
  names(stats) <- gsub("[[:blank:]]+", "_", names(stats))
  stats
})
# Example used in Huang and Ling (2005), p. 305.
# stringsAsFactors = TRUE is required (the R >= 4.0 default is FALSE):
# the summary functions expect "obs" and "pred" to be factors.
exampleData <- data.frame(
  pred = c("c1", "c3", "c2", "c1", "c3", "c1"),
  obs = c("c1", "c1", "c2", "c2", "c3", "c3"),
  c1 = c(0.6, 0.15, 0.3, 0.45, 0.1, 0.8),
  c2 = c(0.15, 0.3, 0.5, 0.25, 0.2, 0.05),
  c3 = c(0.25, 0.55, 0.2, 0.3, 0.7, 0.15),
  stringsAsFactors = TRUE
)
multiClassSummaryAUC(exampleData)
multiClassSummaryOld(exampleData)
# Create dummy data (features are random noise, so the model has nothing
# real to learn; this merely exercises the summary functions).
set.seed(500)
n <- 300
MyData <- data.frame(
  f1 = runif(n, max = 50),
  f2 = runif(n, max = 50),
  f3 = runif(n, min = 50, max = 100),
  f4 = runif(n, max = 10),
  myClass = sample(c("a", "b", "c", "d", "e"), size = n, replace = TRUE)
)

# Create stratified folds so every fold contains every class; otherwise the
# pair-wise AUC computation can fail when a class is missing from a fold.
# Qualified with caret:: because library(caret) is only called further below.
folds <- 10
cvIndex <- caret::createFolds(factor(MyData$myClass), folds, returnTrain = TRUE)
# Fit a k-nearest-neighbours model, tuning k by stratified cross-validation
# and scoring candidates with the Hand-and-Till multi-class AUC ("ROC").
library(caret)
set.seed(20000)
knnGrid <- expand.grid(.k = c(1, 5, 10, 15))
cvControl <- trainControl(
  index = cvIndex,
  method = "cv",
  number = folds,
  classProbs = TRUE,
  summaryFunction = multiClassSummaryAUC
)
model <- train(
  myClass ~ .,
  data = MyData,
  method = "knn",
  tuneGrid = knnGrid,
  metric = "ROC",
  trControl = cvControl
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment