Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
GP Regression example
functions{
  // GP: builds a noiseless (latent) Gaussian Process from standard-normal helpers.
  //   volatility: inverse length-scale of the squared-exponential kernel
  //   amplitude:  marginal standard deviation passed to the kernel
  //   normal01:   helper vector of length n_x (given a normal(0,1) prior in the model block)
  //   n_x:        number of input locations
  //   x:          input locations
  // Returns the Cholesky factor of the kernel matrix multiplied by normal01.
  vector GP(real volatility, real amplitude, vector normal01, int n_x, real[] x ) {
    // squared-exponential kernel with length-scale 1/volatility
    matrix[n_x,n_x] K = cov_exp_quad(x, amplitude, 1/volatility) ;
    // value written onto the diagonal: amplitude^2 plus a small jitter (1e-6)
    real diag_value = amplitude^2 + 1e-6 ;
    for(k in 1:n_x){
      K[k,k] = diag_value ;
    }
    return cholesky_decompose(K) * normal01 ;
  }
}
data {
  // n_y: number of observations in y
  int<lower=0> n_y ;
  // y: vector of observations for y
  // should be scaled to mean=0,sd=1
  vector[n_y] y ;
  // n_x: number of unique x values
  int<lower=0> n_x ;
  // x: unique values of x
  // should be scaled to min=0,max=1
  real x[n_x] ;
  // x_index: vector indicating which x is associated with each y
  // (left unconstrained: it is not read by the model block)
  int x_index[n_y] ;
  // n_z: number of columns in predictor matrix z
  int<lower=0> n_z ;
  // rows_z_unique: number of unique rows in predictor matrix z
  int<lower=0> rows_z_unique ;
  // z_unique: unique rows of the predictor matrix (each column gets its own GP)
  matrix[rows_z_unique,n_z] z_unique ;
  // z_by_f_index: for each observation, its position in the column-major
  // flattening of the n_x by rows_z_unique matrix f*z_unique', i.e.
  // x_index + (index of the observation's row in z_unique - 1)*n_x.
  // The bounds reject out-of-range indices when the data are read instead of
  // failing later inside the model block.
  int<lower=1,upper=n_x*rows_z_unique> z_by_f_index[n_y] ;
}
transformed data{
  // tz: transpose of z_unique (n_z by rows_z_unique), computed once here
  // so the model block can form f*tz without re-transposing
  matrix[n_z,rows_z_unique] tz = z_unique' ;
}
parameters {
// noise: measurement noise (used as the standard deviation of the normal
// likelihood in the model block)
real<lower=0> noise ;
// volatility_helper: helper for cauchy-distributed volatility (see transformed parameters).
// Bounded to (0, pi/2) and given no sampling statement in the model block,
// so it is implicitly uniform on that interval.
vector<lower=0,upper=pi()/2>[n_z] volatility_helper ;
// amplitude: amplitude of GPs (one entry per column of z_unique)
vector<lower=0>[n_z] amplitude ;
// f_normal01: helper variable for GPs (see transformed parameters).
// Column zi holds the n_x values that GP() multiplies by the Cholesky factor
// to produce f[,zi]; the model block gives every entry a normal(0,1) prior.
matrix[n_x,n_z] f_normal01 ;
}
transformed parameters{
  // volatility: volatility of GPs (a.k.a. inverse-lengthscale).
  // volatility_helper is uniform on (0, pi/2), so 10*tan(volatility_helper)
  // implies volatility ~ cauchy(0,10) restricted to positive values.
  vector[n_z] volatility = 10*tan(volatility_helper) ;
  // f: latent GPs, one column per predictor
  matrix[n_x,n_z] f ;
  // build the GP for each predictor from its own hyperparameters and helper column
  for(j in 1:n_z){
    f[,j] = GP( volatility[j] , amplitude[j] , col(f_normal01, j) , n_x , x ) ;
  }
}
model {
  // mu: expected value of every observation.
  // f*tz is n_x by rows_z_unique; to_vector() flattens it column-major and
  // z_by_f_index picks the entry belonging to each observation.
  vector[n_y] mu = to_vector(f*tz)[z_by_f_index] ;
  // noise prior: weibull(2,1) has its mode at sqrt(1/2), about 0.71
  noise ~ weibull(2,1) ;
  // amplitude prior: same weibull(2,1) shape
  amplitude ~ weibull(2,1) ;
  // normal(0,1) priors on GP helpers
  to_vector(f_normal01) ~ normal(0,1) ;
  // likelihood, vectorized over all observations
  y ~ normal( mu , noise ) ;
}
# load packages
library(tidyverse)
library(rstan)
rstan_options(auto_write = TRUE)
# load ezStan (if you don't have it, install via: devtools::install_github('mike-lawrence/ezStan') )
library(ezStan)
#ezStan has some nice functions for starting & watching parallel chains,
# as well as a nicer summary table of the posterior samples
# Make some fake data ----
# n_x: number of unique samples on x-axis
n_x <- 100
# n_reps: number of repeated observations per-x per-condition
n_reps <- 10
# tibble holding every combination of x, repetition & condition
dat <- as_tibble(
  expand.grid(
    x = seq(-10, 10, length.out = n_x),
    rep = seq_len(n_reps),
    condition = c(-.5, .5)
  )
)
# set random seed for reproducibility
set.seed(1)
# add the latent functions, then the noisy observations built from them
dat <- dplyr::mutate(
  dat,
  # arbitrary wiggly curve
  intercept = sin(x) * dnorm(x, mean = 5, sd = 8),
  # another arbitrary curve
  effect = dnorm(x - 5, mean = 2) / 5,
  # noiseless outcome: intercept plus condition-weighted effect
  true = intercept + effect * condition,
  # standardized noiseless outcome plus unit-variance normal noise
  obs = scale(true)[, 1] + rnorm(n(), mean = 0, sd = 1)
)
# show the intercept function
ggplot(data = dat, mapping = aes(x = x, y = intercept)) +
  geom_line()

# show the effect function
ggplot(data = dat, mapping = aes(x = x, y = effect)) +
  geom_line()

# show the condition functions (one line per condition)
ggplot(
  data = dat,
  mapping = aes(
    x = x,
    y = true,
    colour = factor(condition),
    group = factor(condition)
  )
) +
  geom_line()

# show the noisy observations: one faint line per repetition,
# one panel per condition
ggplot(data = dat, mapping = aes(x = x, y = obs, group = rep)) +
  geom_line(alpha = .1) +
  facet_grid(condition ~ .)

# show the noisy observations collapsed to means per x & condition
dat %>%
  dplyr::group_by(x, condition) %>%
  dplyr::summarise(m = mean(obs)) %>%
  ggplot(
    mapping = aes(
      x = x,
      y = m,
      colour = factor(condition),
      group = factor(condition)
    )
  ) +
  geom_line()
# drop the columns that only exist because the data were simulated;
# real data would not come with them
dat <- dplyr::select(dat, -true, -intercept, -effect)
# Prep the data for stan ----
# sorted unique values of x
x <- sort(unique(dat$x))
# for each value in dat$x, its position in x
x_index <- match(dat$x, x)
# model matrix: one column per predictor
z <- model.matrix(object = ~ condition, data = dat)
# collapse every row of z into one string key so whole rows can be compared
# (base R replacement for the deprecated tidyr::unite_())
z_row_key <- do.call(
  paste,
  c(unname(as.list(as.data.frame(z))), sep = "_")
)
# first occurrence of each distinct row, in order of appearance
z_is_first <- !duplicated(z_row_key)
# drop = FALSE keeps z_unique a matrix even when a single row or a single
# column remains, so nrow(z_unique) below is always defined
z_unique <- z[z_is_first, , drop = FALSE]
# for each row in z, its row position in z_unique
z_unique_index <- match(z_row_key, z_row_key[z_is_first])
# combine the two indices into the position within the column-major
# flattening of the length(x) by nrow(z_unique) matrix built in the Stan model
z_by_f_index <- x_index + (z_unique_index - 1) * length(x)
# create the data list for stan
data_for_stan <- list(
  n_y = nrow(dat),
  # scaled to mean=0,sd=1
  y = scale(dat$obs)[, 1],
  n_x = length(x),
  # scaled to min=0,max=1
  x = (x - min(x)) / (max(x) - min(x)),
  x_index = x_index,
  n_z = ncol(z),
  rows_z_unique = nrow(z_unique),
  z_unique = z_unique,
  z_by_f_index = z_by_f_index
)
# model ----
# compile the Stan program
gp_regression_mod <- rstan::stan_model(file = "gp_regression.stan")

# start the parallel chains; with include = FALSE the parameters named in
# pars are the ones left out of the saved output
ezStan::start_stan(
  mod = gp_regression_mod,
  data = data_for_stan,
  include = FALSE,
  pars = c("f_normal01", "volatility_helper"),
  # GPs tend to need higher-than-default adapt_delta
  control = list(adapt_delta = .99)
)

# watch the chains' progress
ezStan::watch_stan()

# collect results
post <- ezStan::collect_stan()

# kill just in case
ezStan::kill_stan()

# delete temp folder
ezStan::clean_stan()

# how long did it take? (total minutes per chain, sorted)
sort(rowSums(rstan::get_elapsed_time(post) / 60))

# check noise & GP parameters
ezStan::stan_summary(
  from_stan = post,
  par = c("noise", "volatility", "amplitude")
)

# check the rhats for the latent functions (taken from the last column
# of the returned summary array)
fstats <- ezStan::stan_summary(
  from_stan = post,
  par = "f",
  return_array = TRUE
)
summary(fstats[, ncol(fstats)])
# visualize latent functions
# posterior draws of f as an array: samples by x locations by parameters
f <- rstan::extract(post, pars = "f")[[1]]
# the x labels below are only right if f has one slice per unique x value
stopifnot(length(x) == dim(f)[2])
# Long format, one row per (sample, x, parameter). as.vector() unrolls the
# array with the sample index varying fastest, then x, then parameter, and
# the rep() calls below follow that same order. Building the table straight
# from the array avoids the round trip through a wide data frame and its
# auto-generated "X<n>" column names.
fdat <- tibble::as_tibble(data.frame(
  sample = rep(seq_len(dim(f)[1]), times = dim(f)[2] * dim(f)[3]),
  value = as.vector(f),
  parameter = rep(seq_len(dim(f)[3]), each = dim(f)[1] * dim(f)[2]),
  x = rep(x, each = dim(f)[1], times = dim(f)[3])
))
# posterior median with 50% and 95% intervals of each latent function,
# computed per x & parameter and drawn with one panel per parameter
fdat %>%
  dplyr::group_by(x, parameter) %>%
  dplyr::summarise(
    med = median(value),
    lo95 = quantile(value, .025),
    hi95 = quantile(value, .975),
    lo50 = quantile(value, .25),
    hi50 = quantile(value, .75)
  ) %>%
  ggplot() +
  # reference line at zero
  geom_hline(yintercept = 0) +
  # 95% interval
  geom_ribbon(
    mapping = aes(x = x, ymin = lo95, ymax = hi95),
    alpha = .5
  ) +
  # 50% interval
  geom_ribbon(
    mapping = aes(x = x, ymin = lo50, ymax = hi50),
    alpha = .5
  ) +
  # posterior median
  geom_line(
    mapping = aes(x = x, y = med),
    alpha = .5
  ) +
  # the argument is spelled out in full ("scales") rather than relying on
  # partial matching of "scale"
  facet_grid(parameter ~ ., scales = "free_y")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.