Skip to content

Instantly share code, notes, and snippets.

@benjaminrich
Created September 25, 2023 00:58
Show Gist options
  • Save benjaminrich/a0b5b1e6cbd269678cd5e90a90268aa6 to your computer and use it in GitHub Desktop.
Save benjaminrich/a0b5b1e6cbd269678cd5e90a90268aa6 to your computer and use it in GitHub Desktop.
Stepwise regression
# S3 generics ----
# Each generic dispatches on the class of its first argument.

#' Repeated forward selection (generic).
stepwise_forward <- function(base_fit, candidates, alpha = 0.05, ...) {
  UseMethod("stepwise_forward")
}

#' Repeated backward elimination (generic).
stepwise_backward <- function(base_fit, candidates, alpha = 0.01, ...) {
  UseMethod("stepwise_backward")
}

#' A single forward step (generic).
forward_step <- function(base_fit, candidates, alpha = 0.05, ...) {
  UseMethod("forward_step")
}

#' A single backward step (generic).
backward_step <- function(base_fit, candidates, alpha = 0.01, ...) {
  UseMethod("backward_step")
}

#' Fit one model per formula in `all_formulas` (generic).
fit_all_models <- function(base_fit, all_formulas, ...) {
  UseMethod("fit_all_models")
}

#' Build a comparison table from a set of fits (generic).
model_table <- function(obj, ...) {
  UseMethod("model_table")
}

#' Accessor generic for p-values carried by `obj`.
pvalue <- function(obj, ...) {
  UseMethod("pvalue")
}

#' Accessor generic for the selected term(s) carried by `obj`.
selected <- function(obj, ...) {
  UseMethod("selected")
}

#' Accessor generic for the final model carried by `obj`.
final_model <- function(obj, ...) {
  UseMethod("final_model")
}
# Default accessors ----
# Each one reads the attribute of the same name from `obj`, using exact
# (non-partial) name matching, and yields NULL when the attribute is absent.

#' Read the "pvalue" attribute of `obj`.
pvalue.default <- function(obj, ...) {
  attr(obj, which = "pvalue", exact = TRUE)
}

#' Read the "selected" attribute of `obj`.
selected.default <- function(obj, ...) {
  attr(obj, which = "selected", exact = TRUE)
}

#' Read the "final_model" attribute of `obj`.
final_model.default <- function(obj, ...) {
  attr(obj, which = "final_model", exact = TRUE)
}
#' Drive a stepwise search by calling `step_fn` until nothing is selected.
#'
#' @param base_fit Starting model; passed to `step_fn` on the first iteration.
#' @param candidates Terms still available to `step_fn`; the selected term is
#'   removed after every successful step.
#' @param alpha Threshold handed to `step_fn` unchanged.
#' @param step_fn Function called as `step_fn(fit, candidates, alpha)`; its
#'   result is queried with `selected()` and `final_model()`.
#' @return A one-element list holding the list of step results, with class
#'   "stepwise_search" and attributes `selected` (all selections, in order)
#'   and `final_model` (the fit after the last successful step).
stepwise_search <- function(base_fit, candidates, alpha, step_fn) {
  current_fit <- base_fit
  remaining <- candidates
  steps <- list()

  # Stop when the candidate pool is exhausted or a step selects nothing.
  while (length(remaining) > 0) {
    this_step <- step_fn(current_fit, remaining, alpha)
    steps <- append(steps, list(this_step))

    if (is.null(selected(this_step))) {
      break
    }

    current_fit <- final_model(this_step)
    remaining <- setdiff(remaining, selected(this_step))
  }

  all_selected <- unlist(lapply(steps, selected))
  structure(
    list(steps),
    class = "stepwise_search",
    selected = all_selected,
    final_model = current_fit
  )
}
#' Default forward stepwise selection.
#'
#' Runs `stepwise_search()` with `forward_step` as the step function, names the
#' single list element "forward" and re-classes the result.
#'
#' @param base_fit Starting model.
#' @param candidates Terms that may be added.
#' @param alpha Threshold passed on to the search (default 0.05).
#' @return Object of class "stepwise_forward" whose `selected` attribute is
#'   `list(forward = <selected terms>)`.
stepwise_forward.default <- function(base_fit, candidates, alpha=0.05) {
  search <- stepwise_search(base_fit, candidates, alpha, step_fn = forward_step)
  chosen <- selected(search)
  labelled <- setNames(search, "forward")
  structure(
    labelled,
    class = "stepwise_forward",
    selected = list(forward = chosen)
  )
}
#' Default backward stepwise elimination.
#'
#' Runs `stepwise_search()` with `backward_step` as the step function, names
#' the single list element "backward" and re-classes the result.
#'
#' @param base_fit Starting model.
#' @param candidates Terms that may be removed.
#' @param alpha Threshold passed on to the search (default 0.01).
#' @return Object of class "stepwise_backward" whose `selected` attribute is
#'   `list(backward = <selected terms>)`.
stepwise_backward.default <- function(base_fit, candidates, alpha=0.01) {
  search <- stepwise_search(base_fit, candidates, alpha, step_fn = backward_step)
  chosen <- selected(search)
  labelled <- setNames(search, "backward")
  structure(
    labelled,
    class = "stepwise_backward",
    selected = list(backward = chosen)
  )
}
#' Backward elimination applied to the result of a forward search.
#'
#' @param base_fit A "stepwise_forward" object; its `final_model()` is the
#'   starting point of the backward search.
#' @param candidates Terms that may be removed. When omitted, the terms
#'   reported by `selected(base_fit)` are used.
#' @param alpha Threshold passed to `stepwise_backward()` (default 0.01).
#' @return Object of class "stepwise_forward_backward": the forward and
#'   backward results concatenated, with `selected` holding both selection
#'   lists and `final_model` taken from the backward search.
stepwise_backward.stepwise_forward <- function(base_fit, candidates, alpha=0.01) {
  if (missing(candidates)) {
    # Spell out FALSE: `F` is an ordinary variable that can be reassigned.
    candidates <- unlist(selected(base_fit), use.names = FALSE)
  }

  back_fit <- stepwise_backward(final_model(base_fit), candidates, alpha)

  structure(
    c(base_fit, back_fit),
    class = "stepwise_forward_backward",
    selected = c(selected(base_fit), selected(back_fit)),
    final_model = final_model(back_fit)
  )
}
#' Perform one forward or backward step.
#'
#' Derives one formula per candidate (added for "forward", removed for
#' "backward"), fits them all, tabulates them against `base_fit` and picks the
#' candidate with the smallest (forward) or largest (backward) p-value. The
#' candidate is accepted when `op(p, alpha)` is TRUE.
#'
#' @param base_fit Current model.
#' @param candidates Terms to try adding or removing.
#' @param alpha Threshold compared with the best candidate's p-value.
#' @param direction "forward" or "backward".
#' @param op Comparison used for acceptance; defaults to `<` for forward and
#'   `>=` for backward. The default is evaluated lazily, i.e. after
#'   `direction` has been resolved by `match.arg()`.
#' @param ... Passed on to `fit_all_models()`.
#' @return The model table with extra attributes `base_fit`, `all_fits`,
#'   `selected` (name of the accepted candidate, or NULL) and `final_model`
#'   (the accepted fit, or `base_fit` when nothing was accepted).
generic_step <- function(
  base_fit,
  candidates,
  alpha,
  direction = c("forward", "backward"),
  op = if (direction=="forward") `<` else `>=`,
  ...
) {
  direction <- match.arg(direction)

  all_formulas <- derive_all_formulas(
    base_fit,
    base_formula = formula(base_fit),
    add = if (direction == "forward") candidates else NULL,
    subtract = if (direction == "backward") candidates else NULL
  )
  all_fits <- fit_all_models(base_fit, all_formulas, data = base_fit$data, ...)
  mtab <- model_table(all_fits, base_fit, direction = direction, sort = TRUE)
  pval <- pvalue(mtab)

  best <- if (direction == "forward") which.min(pval) else which.max(pval)

  # which.min()/which.max() return integer(0) when no p-value is usable
  # (e.g. all NA/NaN); treat that as "nothing selected" instead of failing
  # with "argument is of length zero" in the condition below.
  if (length(best) > 0L && op(pval[best], alpha)) {
    chosen <- names(pval)[best]
    chosen_fit <- all_fits[[chosen]]
  } else {
    chosen <- NULL
    chosen_fit <- base_fit
  }

  structure(
    mtab,
    base_fit = base_fit,
    all_fits = all_fits,
    selected = chosen,
    final_model = chosen_fit
  )
}
#' Default single forward step.
#'
#' Delegates to `generic_step()` with `direction = "forward"` and prepends the
#' class "forward_step" to the result.
#'
#' @param base_fit Current model.
#' @param candidates Terms that may be added.
#' @param alpha Threshold forwarded to `generic_step()` (default 0.05).
#' @param ... Forwarded to `generic_step()`.
forward_step.default <- function(base_fit, candidates, alpha=0.05, ...) {
  step_result <- generic_step(
    base_fit = base_fit,
    candidates = candidates,
    alpha = alpha,
    direction = "forward",
    ...
  )
  class(step_result) <- c("forward_step", class(step_result))
  step_result
}
#' Default single backward step.
#'
#' Delegates to `generic_step()` with `direction = "backward"` and prepends the
#' class "backward_step" to the result.
#'
#' @param base_fit Current model.
#' @param candidates Terms that may be removed.
#' @param alpha Threshold forwarded to `generic_step()` (default 0.01).
#' @param ... Forwarded to `generic_step()`.
backward_step.default <- function(base_fit, candidates, alpha=0.01, ...) {
  step_result <- generic_step(
    base_fit = base_fit,
    candidates = candidates,
    alpha = alpha,
    direction = "backward",
    ...
  )
  class(step_result) <- c("backward_step", class(step_result))
  step_result
}
#' Labels for each argument: its names if it has any, otherwise its values
#' coerced to character.
#'
#' @param ... Any number of vectors (NULL is allowed).
#' @return A list with one character vector per argument.
get_names <- function(...) {
  label_of <- function(x) {
    nm <- names(x)
    if (is.null(nm)) {
      as.character(x)
    } else {
      nm
    }
  }
  args <- list(...)
  lapply(args, label_of)
}
#' Build one updated formula per term to add or subtract.
#'
#' Each term becomes an update specification (".~.+term" or ".~.-term") that is
#' applied to `base_formula` with `update.formula()`.
#'
#' @param base_fit Model used only to supply the default `base_formula`.
#' @param base_formula Formula to update.
#' @param add Terms to add (one formula per term), or NULL.
#' @param subtract Terms to remove (one formula per term), or NULL.
#' @param formula_names Names for the returned list; defaults to the labels
#'   that `get_names()` derives from `add` and `subtract`.
#' @return A list of formulas named by `formula_names`, additions first.
derive_all_formulas <- function(
  base_fit,
  base_formula = formula(base_fit),
  add = NULL,
  subtract = NULL,
  formula_names = unlist(get_names(add, subtract))
) {
  plus_terms <- NULL
  if (!is.null(add)) {
    plus_terms <- paste0("+", add)
  }

  minus_terms <- NULL
  if (!is.null(subtract)) {
    minus_terms <- paste0("-", subtract)
  }

  update_specs <- paste0(".~.", c(plus_terms, minus_terms))
  formulas <- lapply(update_specs, function(spec) {
    update.formula(base_formula, spec)
  })
  setNames(formulas, formula_names)
}
#' Refit `base_fit` once per formula.
#'
#' @param base_fit Model to refit via `update()`.
#' @param all_formulas List of formulas, one per model to fit.
#' @param data Data passed to `update()`; defaults to `base_fit$data`.
#' @param model_names Names for the returned list; defaults to the names of
#'   `all_formulas`.
#' @param ... Extra arguments forwarded to `update()`.
#' @return A list of fitted models named by `model_names`.
fit_all_models.default <- function(
  base_fit,
  all_formulas,
  data = base_fit$data,
  model_names = names(all_formulas),
  ...
) {
  # Defined inside the method so that `data` and `...` resolve to this
  # function's arguments when update() re-evaluates the model call.
  refit <- function(new_formula) {
    update(base_fit, formula. = new_formula, data = data, ...)
  }
  fits <- lapply(all_formulas, refit)
  setNames(fits, model_names)
}
#' Likelihood-ratio comparison of each candidate fit against a base fit.
#'
#' For every element of `all_fits` the table reports -2*logLik and its df, the
#' same quantities for `base_fit`, their differences, and the upper-tail
#' chi-square p-value of the difference. For "forward" the differences are
#' base minus candidate (-2*logLik) and candidate minus base (df); for
#' "backward" both signs are flipped.
#'
#' @param all_fits Named list of fitted models supporting `logLik()`.
#' @param base_fit Reference model supporting `logLik()`.
#' @param alpha Unused by this method; kept in the signature for callers.
#' @param direction "forward" or "backward"; resolved with `match.arg()`, so
#'   the default is "forward".
#' @param sort If TRUE, rows are ordered by p-value.
#' @param decreasing Sort order used when `sort` is TRUE.
#' @param ... Ignored.
#' @return A data frame with class "model_table" prepended, carrying the
#'   attributes `all_fits`, `base_fit` and `pvalue` (p-values named by model).
model_table.default <- function(
  all_fits,
  base_fit,
  alpha,
  direction = c("forward", "backward"),
  sort = TRUE,
  decreasing = FALSE,
  ...
) {
  # Without match.arg() the default is a length-2 vector, which made every
  # model contribute two rows (one of them NaN) with mismatched labels.
  direction <- match.arg(direction)
  sgn <- if (direction == "forward") 1 else -1

  # The base model is the same for every row, so evaluate it only once.
  base_ll <- logLik(base_fit)
  base_dev <- -2 * as.numeric(base_ll)
  base_df <- attr(base_ll, "df", exact = TRUE)

  rows <- lapply(all_fits, function(fit) {
    ll <- logLik(fit)
    dev <- -2 * as.numeric(ll)
    df <- attr(ll, "df", exact = TRUE)
    delta_dev <- sgn * (base_dev - dev)
    delta_df <- sgn * (df - base_df)
    data.frame(
      check.names = FALSE,
      `Model` = NA,
      `-2*loglik` = dev,
      `df` = df,
      `Base(-2*loglik)` = base_dev,
      `Base(df)` = base_df,
      `Δ(-2*loglik)` = delta_dev,
      `Δdf` = delta_df,
      `P-value` = pchisq(delta_dev, delta_df, lower.tail = FALSE)
    )
  })
  mtab <- do.call(rbind, rows)
  mtab$`Model` <- names(all_fits)

  if (sort) {
    mtab <- mtab[order(mtab$`P-value`, decreasing = decreasing), ]
  }

  structure(
    mtab,
    class = c("model_table", class(mtab)),
    all_fits = all_fits,
    base_fit = base_fit,
    pvalue = setNames(mtab$`P-value`, mtab$`Model`)
  )
}
# Usage example ----
# Guarded by `if (FALSE)` so nothing below runs when the file is sourced.
if (FALSE) {
  library(mvtnorm)

  # Simulated data: four correlated predictors plus five noise predictors.
  set.seed(123)
  n <- 100
  p <- 4
  S <- rWishart(1, p, toeplitz(1 / (1:p)))[, , 1]
  x <- rmvnorm(n, rep(0, p), S)
  dat <- data.frame(
    x1 = x[, 1], x2 = x[, 2], x3 = x[, 3], x4 = x[, 4],
    x5 = rnorm(n), x6 = rnorm(n), x7 = rnorm(n), x8 = rnorm(n), x9 = rnorm(n)
  )
  dat$y <- with(dat, 6 + 0.3 * x1 + 0.1 * x2 + 0.4 * x3 + 0.2 * x4 + rnorm(n, 0, 1.3))

  base_fit <- glm(y ~ x1, data = dat)
  candidates <- paste0("x", 2:9)

  # Single forward step.
  x <- forward_step(base_fit, candidates, alpha = 0.05)
  x
  final_model(x)

  # Single backward step.
  x <- backward_step(base_fit, "x1", alpha = 0.01)
  x
  final_model(x)

  # Full forward search.
  x <- stepwise_forward(base_fit, candidates, alpha = 0.05)
  x
  selected(x)
  final_model(x)

  # Backward search from the base model with one candidate.
  y <- stepwise_backward(base_fit, c("x3"), alpha = 0.01)
  y
  selected(y)
  final_model(y)

  # Backward search on top of the forward result.
  y <- stepwise_backward(x, alpha = 0.01)
  y
  selected(y)
  final_model(y)

  # The same forward-then-backward search written as a pipeline.
  x <- base_fit |>
    stepwise_forward(candidates, alpha = 0.05) |>
    stepwise_backward(alpha = 0.01)
  x
  selected(x)
  final_model(x)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment