Skip to content

Instantly share code, notes, and snippets.

@markgrujic
Created December 19, 2018 22:38
Show Gist options
  • Save markgrujic/a46a4466b618164e77e0abda14d97590 to your computer and use it in GitHub Desktop.
Save markgrujic/a46a4466b618164e77e0abda14d97590 to your computer and use it in GitHub Desktop.
missRanger, modified to optionally return the out-of-bag (OOB) prediction errors
# missRanger variant that can also return the out-of-bag (OOB) prediction
# errors of the chained random forests.
#
# Variables with missing values are visited from the smallest to the largest
# share of missing values. Each one is predicted by a `ranger` forest on the
# fully observed variables plus those already imputed. If no such predictor
# exists yet, `imputeUnivariate()` is used instead. Predictive mean matching
# is optionally applied (`pmm.k > 0`). Chaining stops when the mean OOB error
# no longer decreases, or after `maxiter` iterations.
#
# Arguments:
#   data      A data.frame. Only numeric and factor columns with at least one
#             non-missing value take part in the imputation.
#   maxiter   Maximum number of chaining iterations (>= 1).
#   pmm.k     Number of candidate donors for predictive mean matching;
#             0 disables PMM.
#   seed      Optional seed passed to set.seed().
#   verbose   0 = silent, 1 = one dot per visited variable,
#             >= 2 = per-variable OOB errors.
#   returnOOB If TRUE, return list(ximp, oob) instead of the imputed data.
#   ...       Further arguments passed to ranger(). The arguments `formula`,
#             `data`, `write.forest`, `probability`, `split.select.weights`,
#             `dependent.variable.name` and `classification` are rejected.
#
# Value:
#   The imputed data.frame. If `returnOOB` is TRUE, the result is instead a
#   list with:
#     ximp  the imputed data.frame
#     oob   a named numeric vector with one OOB error per imputed variable,
#           taken from the same iteration that produced `ximp`. For
#           regression forests, ranger's OOB prediction error is divided by
#           the variance of the observed values.
function(data, maxiter = 10L, pmm.k = 0L, seed = NULL, verbose = 1,
         returnOOB = FALSE, ...) {
  # Shape the result according to `returnOOB` on every exit path.
  finish <- function(ximp, oob) {
    if (returnOOB) list(ximp = ximp, oob = oob) else ximp
  }

  if (verbose > 0) {
    cat("\nMissing value imputation by chained tree ensembles\n")
  }
  stopifnot(
    is.data.frame(data), dim(data) >= 1L,
    is.numeric(maxiter), length(maxiter) == 1L, maxiter >= 1L,
    is.numeric(pmm.k), length(pmm.k) == 1L, pmm.k >= 0L,
    !(c("formula", "data", "write.forest", "probability",
        "split.select.weights", "dependent.variable.name",
        "classification") %in% names(list(...)))
  )
  if (!is.null(seed)) {
    set.seed(seed)
  }

  # Candidate variables: numeric or factor columns that are not entirely NA.
  allVars <- names(which(vapply(data, function(z) {
    (is.factor(z) || is.numeric(z)) && any(!is.na(z))
  }, TRUE)))
  if (verbose > 0 && length(allVars) < ncol(data)) {
    cat("\n Variables ignored in imputation (wrong data type or all values missing: ")
    cat(setdiff(names(data), allVars), sep = ", ")
  }
  stopifnot(length(allVars) > 1L)

  data.na <- is.na(data[, allVars, drop = FALSE])
  count.seq <- sort(colMeans(data.na))
  visit.seq <- names(count.seq)[count.seq > 0]
  if (length(visit.seq) == 0L) {
    # Nothing to impute. Still honour `returnOOB` so callers always get the
    # same return type; no forest was fit, so the OOB vector is empty.
    return(finish(data, structure(numeric(0), names = character(0))))
  }

  verboseDigits <- 4
  j <- 1L
  predError <- rep(1, length(visit.seq))
  names(predError) <- visit.seq
  crit <- TRUE
  completed <- setdiff(allVars, visit.seq)

  if (verbose >= 2) {
    cat("\n", abbreviate(visit.seq, minlength = verboseDigits + 2), sep = "\t")
  }

  while (crit && j <= maxiter) {
    if (verbose > 0) {
      cat("\niter ", j, ":\t", sep = "")
    }
    # State before this iteration, kept in case this iteration gets worse.
    data.last <- data
    predErrorLast <- predError

    for (v in visit.seq) {
      v.na <- data.na[, v]
      # From the second iteration on, `completed` also contains `v` itself.
      # Keep the response out of its own right-hand side explicitly instead
      # of relying on how ranger parses such a formula.
      predictors <- setdiff(completed, v)
      if (length(predictors) == 0L) {
        data[[v]] <- imputeUnivariate(data[[v]])
      } else {
        fit <- ranger(
          formula = reformulate(predictors, response = v),
          data = data[!v.na, c(v, predictors), drop = FALSE],
          ...
        )
        pred <- predict(fit, data[v.na, allVars, drop = FALSE])$predictions
        data[v.na, v] <- if (pmm.k > 0) {
          pmm(xtrain = fit$predictions, xtest = pred,
              ytrain = data[[v]][!v.na], k = pmm.k)
        } else {
          pred
        }
        # Regression errors are normalised by the variance of the observed
        # values so that they are comparable across variables.
        predError[[v]] <- fit$prediction.error /
          (if (fit$treetype == "Regression") var(data[[v]][!v.na]) else 1)
        # 0/0 (constant variable) gives NaN; a single observed value gives NA
        # from var(). Either would make `crit` NA and break the while loop.
        if (is.na(predError[[v]])) {
          predError[[v]] <- 0
        }
      }
      completed <- union(completed, v)

      if (verbose == 1) {
        cat(".")
      } else if (verbose >= 2) {
        cat(format(round(predError[[v]], verboseDigits),
                   nsmall = verboseDigits), "\t")
      }
    }
    j <- j + 1L
    crit <- mean(predError) < mean(predErrorLast)
  }

  if (verbose > 0) {
    cat("\n")
  }

  # The loop ends for one of two reasons:
  #   - `crit` became FALSE: the last iteration made the mean OOB error worse,
  #     so the previous state (data.last / predErrorLast) is better.
  #   - `maxiter` iterations ran while still improving (`crit` is TRUE): the
  #     current state is the best one.
  # After a single iteration (`j == 2`), `data.last` is the unimputed input,
  # so the current state is kept either way. Each dataset is returned
  # together with the OOB errors of the iteration that produced it.
  if (j == 2L || crit) {
    finish(data, predError)
  } else {
    finish(data.last, predErrorLast)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment