Skip to content

Instantly share code, notes, and snippets.

@dewittpe
Last active August 29, 2015 14:05
Show Gist options
  • Save dewittpe/8a27e686ab2159e1ee77 to your computer and use it in GitHub Desktop.
Save dewittpe/8a27e686ab2159e1ee77 to your computer and use it in GitHub Desktop.
# My answer to http://stackoverflow.com/questions/25272387
library(ggplot2)
library(reshape2)
library(dplyr)
library(magrittr)
library(splines)
set.seed(42)
# Define a function with_new_knots for adding new knots to the splines within a formula call.
# Build `iterations` copies of a model formula in which every top-level
# bs()/ns() term receives a freshly drawn, random `knots` argument.
#
# Arguments:
#   frm        A model formula; spline terms are top-level calls to bs()/ns().
#   data       A data.frame (or list) in which the first argument of each
#              spline call can be evaluated.
#   iterations Number of formulas to generate.
#
# Returns:
#   A list of length `iterations`; each element is a formula whose
#   environment is that of `frm`.
#
# Notes:
#   * The right-hand side is rebuilt from the expanded term labels, so
#     `a * b` comes back as `a + b + a:b`. offset() terms are not part of
#     the term labels and are therefore not carried over.
#   * The knot range is taken from the *evaluated* first argument of each
#     spline call, so `ns(log(x))` gets knots on the log scale.
#   * An existing `knots` argument is replaced rather than duplicated.
with_new_knots <- function(frm, data, iterations = 5L) {
  old_terms <- terms(frm, specials = c("bs", "ns"))
  # Unevaluated expressions of every variable in the formula (response first
  # when present); the "specials" and "response" attributes index this list.
  term_vars <- as.list(attr(old_terms, "variables"))[-1L]

  enclos <- environment(frm)
  if (is.null(enclos)) {
    enclos <- parent.frame()
  }

  # Deparse an expression to a single string.
  to_text <- function(expr) {
    paste(deparse(expr, width.cutoff = 500L), collapse = " ")
  }

  # Reconstruct the rhs with interaction terms expanded, keeping a removed
  # intercept removed.
  labels <- attr(old_terms, "term.labels")
  old_rhs <- if (length(labels) > 0L) paste(labels, collapse = " + ") else "1"
  if (attr(old_terms, "intercept") == 0L) {
    old_rhs <- paste(old_rhs, "- 1")
  }

  # Left-hand side; empty for a one-sided formula.
  resp <- attr(old_terms, "response")
  lhs <- if (resp > 0L) to_text(term_vars[[resp]]) else ""

  # Locate the spline terms. as.integer() keeps this well defined when the
  # formula has no spline terms at all.
  idx <- sort(as.integer(unlist(attr(old_terms, "specials"))))
  spline_calls <- term_vars[idx]
  # Text of each spline term exactly as it appears inside the term labels.
  old_spline_terms <- rownames(attr(old_terms, "factors"))[idx]

  # Range of the quantity each spline is built on. Evaluating the call's own
  # first argument (instead of indexing all.vars(frm)) keeps every range
  # aligned with its term regardless of what else is in the formula.
  rngs <- lapply(spline_calls, function(cl) {
    range(eval(cl[[2L]], data, enclos), na.rm = TRUE)
  })

  # For each of the spline terms, randomly generate new knots.
  # This is a silly example, something clever will replace it.
  replicate(iterations,
            {
              new_knots <- lapply(rngs, function(r) {
                sort(runif(sample(1:5, 1), min = r[1], max = r[2]))
              })
              rhs <- old_rhs
              for (i in seq_along(old_spline_terms)) {
                new_call <- spline_calls[[i]]
                new_call$knots <- new_knots[[i]]
                rhs <- gsub(old_spline_terms[i], to_text(new_call), rhs,
                            fixed = TRUE)
              }
              f <- as.formula(paste(lhs, "~", rhs))
              environment(f) <- enclos
              f
            },
            simplify = FALSE)
}
###
### Example use.
###
# This is a silly example, not meaningful from a statistical standpoint, but helpful to illustrate the results.
# Model with two spline terms (neither specifies knots) and an interaction.
f <- price ~ ns(carat) * color + bs(depth, degree = 5) + clarity
# Show one batch of generated formulas.
with_new_knots(f, diamonds)
# Fitted values from the original formula and from each re-knotted formula.
orig_fit <- predict(lm(f, data = diamonds))
new_fits <- with_new_knots(f, diamonds) %>%
  lapply(., function(frm) { predict(lm(frm, data = diamonds)) })
# create a data set for plotting results
dat <- data.frame(orig_fit, new_fits)
# Label the new-fit columns from the number of fits actually produced rather
# than assuming exactly five.
names(dat)[-1] <- paste("new knots", seq_along(new_fits))
dat <- melt(dat, id.vars = NULL)
# diamonds is recycled once per fit so each fitted value lines up with its row.
dat <- cbind(dat, diamonds)
# create the plot
ggplot(dat) +
  aes(x = carat, y = value, color = color, shape = clarity) +
  geom_line() +
  geom_point(aes(y = price), alpha = 0.1) +
  facet_wrap(~ variable, scales = "free")
ggsave(filename = "~/Pictures/SO_25272387.jpg")
# Thank you, MrFlick, http://stackoverflow.com/users/2372064/mrflick
# Return a copy of `form` in which every call to one of the functions named
# in `calls` has its `knots` argument set to (or replaced by) randomly drawn
# knots.
#
# Arguments:
#   form   A formula.
#   data   A data.frame (or list) in which the first argument of each matched
#          call can be evaluated.
#   calls  Character vector of function names to treat as spline calls.
#
# Returns:
#   A formula with the same environment as `form`.
#
# Between 1 and 5 knots are drawn uniformly over the observed range of the
# call's first argument. That argument is evaluated in `data`, so
# `ns(log(x))` gets knots on the log scale instead of a failed column lookup.
newknots <- function(form, data, calls = c("bs", "ns")) {
  enclos <- environment(form)
  if (is.null(enclos)) {
    enclos <- parent.frame()
  }

  # Draw sorted random knots spanning the range of the expression `x`.
  nk <- function(x) {
    vals <- eval(x, data, enclos)
    sort(runif(sample(1:5, 1),
               min = min(vals, na.rm = TRUE),
               max = max(vals, na.rm = TRUE)))
  }

  # Walk the expression tree: rewrite matched calls, recurse into other
  # calls, and return leaves untouched.
  rr <- function(x, nk, calls) {
    # deparse() can return several strings for a long function expression;
    # collapse them so the && condition stays scalar.
    if (is.call(x) && paste(deparse(x[[1]]), collapse = "") %in% calls) {
      x$knots <- nk(x[[2]])
      x
    } else if (is.recursive(x)) {
      as.call(lapply(as.list(x), rr, nk, calls))
    } else {
      x
    }
  }

  z <- lapply(as.list(form), rr, nk, calls)
  # Evaluating the rebuilt `~` call yields a formula object; its environment
  # is then reset to that of the input formula.
  z <- eval(as.call(z))
  environment(z) <- environment(form)
  z
}
# Example: the ns() term here already carries a knots argument; newknots()
# assigns to that argument, so the supplied values are replaced rather than
# a second knots argument being added.
f <- price ~ ns(carat, knots = c(2,3)) * color + bs(depth, degree = 5) + clarity
# Returns the rewritten formula (auto-printed when run at top level in an
# interactive session).
newknots(f, diamonds)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment