Skip to content

Instantly share code, notes, and snippets.

@gavinsimpson
Created November 8, 2019 01:36
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gavinsimpson/3d9a8ddccb34560451e0c8c6fd68ec02 to your computer and use it in GitHub Desktop.
Save gavinsimpson/3d9a8ddccb34560451e0c8c6fd68ec02 to your computer and use it in GitHub Desktop.
Animation showing how weighted basis functions combine to yield a cubic regression spline
## Procuces an animation showing how weighted basis functions combine to produce a sline
## - this version is for a cubic regression spline basis
## packages required
library('ggplot2')
library('tibble')
library('tidyr')
library('dplyr')
library('mgcv')
library('mvnfast')
library('purrr')
library('gganimate')
library('gifski')
library('transformr')
library('png')
theme_set(theme_minimal())
##' wrapper around mvnfast::rmvn for coefs for basis functions
##'
##' @param n numeric, number of draws from MVN
##' @param k numeric, number of basis functions in basis
##' @param mu, sigma numeric vectors of means and **variances** for $\beta$
draw_beta <- function(n, k, mu = 1, sigma = 1) {
mvnfast::rmvn(n = n, mu = rep(mu, k), sigma = diag(rep(sigma, k)))
}
##' Given a set of basis functions evaluated at `x`, weight them by coefs
##'
##' @param bf matrix od basis functions evaluated at `x`
##' @param x covariate values used to evaluated basis
##' @param n numeric, number of draws from MVN; probably should be 1
##' @param k, numeric; number of basis functions
##' @param ... additional arguments passed to `draw_beta`
weight_basis <- function(bf, x, n = 1, k, ...) {
beta <- draw_beta(n = n, k = k, ...)
out <- sweep(bf, 2L, beta, '*')
colnames(out) <- paste0('f', seq_along(beta))
out <- as_tibble(out)
out <- add_column(out, x = x)
out <- pivot_longer(out, -x, names_to = 'bf', values_to = 'y')
out
}
##' Generate a set of draws from prior of a spline
##'
##' @param bf matrix of basis functions evaluated at `x`
##' @param x covariate values at which basis was evaluated
##' @param draws numeric; how many draws from the prior to make
##' @param k numeric; number of basis functions in basis
##' @param ... arguments passed to `weight_basis` and then on to `draw_beta`
random_bases <- function(bf, x, draws = 10, k, ...) {
out <- rerun(draws, weight_basis(bf, x = x, k = k, ...))
out <- bind_rows(out)
out <- add_column(out, draw = rep(seq_len(draws), each = length(x) * k),
.before = 1L)
class(out) <- c("random_bases", class(out))
out
}
##' Plot method butchery for a set of draws
##'
##' Use `facet = FALSE` for animations
plot.random_bases <- function(x, facet = FALSE) {
plt <- ggplot(x, aes(x = x, y = y, colour = bf)) +
geom_line(lwd = 1) +
guides(colour = FALSE)
if (facet) {
plt + facet_wrap(~ draw)
}
plt
}
## generate some data/covariate values to work with
set.seed(1)
N <- 500
data <- tibble(x = runif(N))
## Generate the basis
k <- 10 # number of basis functions
## CRS needs knot locations given `k`, arrange evenly over `x`
knots <- with(data, list(x = seq(min(x), max(x), length = k)))
## generate a CRS basis using *mgcv* evaluated at `x`
sm <- smoothCon(s(x, k = k, bs = "cr"), data = data, knots = knots)[[1]]$X
colnames(sm) <- levs <- paste0("f", seq_len(k))
## convert to tidy form
basis <- pivot_longer(cbind(sm, data), -x, names_to = 'bf')
basis
## generate draws from basis functions
set.seed(2)
bfuns <- random_bases(sm, data$x, draws = 20, k = k)
## given draws from basis functions & random betas, summarise
## to a spline for each draw
smooth <- bfuns %>%
group_by(draw, x) %>%
summarise(spline = sum(y)) %>%
ungroup()
## base plot for animation
p <- plot(bfuns) + geom_line(data = smooth, aes(x = x, y = spline),
inherit.aes = FALSE, lwd = 1.5)
## animate
animate(
p + transition_states(draw, transition_length = 4, state_length = 2) +
ease_aes('cubic-in-out'),
nframes = 200)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment