Skip to content

Instantly share code, notes, and snippets.

@zmjones
Last active August 29, 2015 14:25
Show Gist options
  • Save zmjones/332fe7a8d9a760782f01 to your computer and use it in GitHub Desktop.
Save zmjones/332fe7a8d9a760782f01 to your computer and use it in GitHub Desktop.
Calibration plots for mlr
# Example: calibration plots for two probabilistic learners, first on the
# multiclass iris task, then on the binary sonar task.
lrns = list(makeLearner("classif.rpart", predict.type = "prob"),
  makeLearner("classif.nnet", predict.type = "prob"))
# Display names for the learners; must follow the order of `lrns`. Defined once
# so both runs label the same learners the same way.
lrn.names = c("rpart", "nnet")

fit = lapply(lrns, train, task = iris.task)
pred = lapply(fit, predict, task = iris.task)
names(pred) = lrn.names
out = generateCalibrationData(pred)
plotCalibration(out)

fit = lapply(lrns, train, task = sonar.task)
pred = lapply(fit, predict, task = sonar.task)
names(pred) = lrn.names
out = generateCalibrationData(pred)
plotCalibration(out, smooth = TRUE)
#' Generate binned calibration data from probabilistic classification output.
#'
#' S3 generic. This file defines methods for \code{Prediction},
#' \code{ResampleResult}, \code{BenchmarkResult} and lists of predictions or
#' resample results.
#'
#' @param obj Object holding probability predictions to be binned.
#' @param breaks Break specification handed on to the method (the list method
#'   passes it to \code{hist} to compute the bin break points).
#' @return Whatever the dispatched method returns.
#' @export
generateCalibrationData = function(obj, breaks = "Sturges") {
  UseMethod("generateCalibrationData")
}
#' Calibration data for a single prediction.
#'
#' @param obj [\code{Prediction}]\cr Classification prediction with
#'   \code{predict.type = "prob"}.
#' @param breaks Bin break specification, forwarded to the list method.
#' @export
generateCalibrationData.Prediction = function(obj, breaks = "Sturges") {
  # Require a classification prediction that carries probabilities.
  checkPrediction(obj, task.type = "classif", predict.type = "prob")
  # Wrap the prediction in a one-element list named "prediction" so the list
  # method can do the actual work.
  wrapped = list(prediction = obj)
  generateCalibrationData.list(wrapped, breaks = breaks)
}
#' Calibration data for a resample result.
#'
#' @param obj [\code{ResampleResult}]\cr Resample result whose prediction is
#'   binned.
#' @param breaks Bin break specification, forwarded to the Prediction method.
#' @export
generateCalibrationData.ResampleResult = function(obj, breaks = "Sturges") {
  # Extract the prediction from the resample result and validate it before
  # delegating to the Prediction method.
  rr.pred = getRRPredictions(obj)
  checkPrediction(rr.pred, task.type = "classif", predict.type = "prob")
  generateCalibrationData.Prediction(rr.pred, breaks = breaks)
}
#' Calibration data for one task of a benchmark result.
#'
#' @param obj [\code{BenchmarkResult}]\cr Benchmark result to extract
#'   predictions from.
#' @param breaks Bin break specification, forwarded to the list method.
#' @param task.id [\code{character(1)} | \code{NULL}]\cr Task whose predictions
#'   are used. \code{NULL} (default) selects the first task id of \code{obj}.
#'   NOTE(review): the generic has no \code{...}, so this argument is only
#'   reachable by calling this method directly.
#' @export
generateCalibrationData.BenchmarkResult = function(obj, breaks = "Sturges", task.id = NULL) {
  tids = getBMRTaskIds(obj)
  if (is.null(task.id)) {
    task.id = tids[1L]
  } else {
    assertChoice(task.id, tids)
  }
  # Predictions for the selected task; presumably one per learner, named by
  # learner id (the list method requires unique names) -- confirm.
  preds = getBMRPredictions(obj, task.ids = task.id, as.df = FALSE)[[1L]]
  for (p in preds) {
    checkPrediction(p, task.type = "classif", predict.type = "prob")
  }
  generateCalibrationData.list(preds, breaks)
}
#' Calibration data for a list of predictions and/or resample results.
#'
#' For each element, the predicted probabilities of every class are binned
#' (bin break points from \code{hist(..., breaks = breaks)}, computed per
#' element). Within each bin and for each class, the result is the share of
#' observations whose true class is that class. Bins where a class has no
#' observations get proportion 0. For binary tasks only the positive class is
#' kept.
#'
#' @param obj [\code{list}]\cr Non-empty list of \code{Prediction} and/or
#'   \code{ResampleResult} objects. It must end up with unique names. If it is
#'   unnamed and every element is a \code{ResampleResult}, the learner ids are
#'   used as names.
#' @param breaks Break specification passed to \code{hist}.
#' @return \code{CalibrationData} object with elements \code{data} (long data
#'   frame with columns Learner, bin, Class, Proportion) and \code{task} (the
#'   task description of the first element).
#' @export
generateCalibrationData.list = function(obj, breaks = "Sturges") {
  assertList(obj, c("Prediction", "ResampleResult"), min.len = 1L)
  # Unwrap every ResampleResult to its prediction, so lists that mix both
  # types also work. Default names come from the learner ids, but only when
  # every element is a ResampleResult.
  is.rr = vapply(obj, inherits, logical(1L), what = "ResampleResult")
  if (any(is.rr)) {
    if (is.null(names(obj)) && all(is.rr))
      names(obj) = extractSubList(obj, "learner.id")
    obj[is.rr] = extractSubList(obj[is.rr], "pred", simplify = FALSE)
  }
  assertList(obj, names = "unique")
  td = obj[[1L]]$task.desc
  out = lapply(obj, function(pred) {
    # check.names = FALSE keeps class names verbatim, so they still match the
    # truth levels and td$negative further down.
    df = data.frame(truth = getPredictionTruth(pred),
      getPredictionProbabilities(pred, cl = getTaskClassLevels(td)),
      check.names = FALSE)
    df = reshape2::melt(df, id.vars = "truth", value.name = "Probability", variable.name = "Class")
    break.points = hist(df$Probability, breaks = breaks, plot = FALSE)$breaks
    df$bin = cut(df$Probability, break.points, include.lowest = TRUE, ordered_results = TRUE)
    plyr::ddply(df, "bin", function(x) {
      # Rows: class whose probability fell into this bin. Columns: true class.
      # The diagonal counts how often that class was also the truth.
      tab = table(x$Class, x$truth)
      s = rowSums(tab)
      ifelse(s == 0, 0, diag(tab) / s)
    })
  })
  names(out) = names(obj)
  out = plyr::ldply(out, .id = "Learner")
  # Binary task: keep only the positive class. A logical mask is used instead
  # of -which(), which would drop every column if the name were not found.
  if (length(td$class.levels) == 2L)
    out = out[, colnames(out) != td$negative, drop = FALSE]
  # Break points are computed per element, so order the combined bin levels
  # numerically by their upper bound (the text between "," and "]").
  out$bin = as.factor(out$bin)
  max.bin = vapply(strsplit(levels(out$bin), ",|]"),
    function(x) as.numeric(x[2L]), numeric(1L))
  out$bin = ordered(out$bin, levels = levels(out$bin)[order(max.bin)])
  out = reshape2::melt(out, id.vars = c("Learner", "bin"), value.name = "Proportion", variable.name = "Class")
  makeS3Obj("CalibrationData",
    data = out,
    task = td)
}
#' Plot calibration data: observed class proportion per probability bin.
#'
#' @param obj [\code{CalibrationData}]\cr Result of
#'   \code{generateCalibrationData}.
#' @param smooth [\code{logical(1)}]\cr If \code{TRUE}, draw a loess smooth per
#'   class. Otherwise draw points joined by lines.
#' @return A ggplot object, faceted by learner when there is more than one.
#' @export
plotCalibration = function(obj, smooth = FALSE) {
  assertClass(obj, "CalibrationData")
  p = ggplot(obj$data, aes_string("bin", "Proportion", color = "Class", group = "Class"))
  if (smooth) {
    p = p + stat_smooth(se = FALSE, span = 2, method = "loess")
  } else {
    p = p + geom_point() + geom_line()
  }
  # Facet only when several distinct learners are present. The row count of
  # the Learner column is not the number of learners.
  if (length(unique(obj$data$Learner)) > 1L)
    p = p + facet_wrap(~ Learner)
  p = p + labs(x = "Probability Bin", y = "Class Proportion")
  p + theme(axis.text.x = element_text(angle = 90, hjust = 1))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment