Skip to content

Instantly share code, notes, and snippets.

@mbjoseph
Last active September 12, 2016 04:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mbjoseph/25d649e46602a419f9765638d5a2bfbc to your computer and use it in GitHub Desktop.
Piecewise regression with an unknown breakpoint in Stan
data {
  // Observed data: n paired observations (x[i], y[i]).
  int<lower=1> n;  // number of observations; at least one
  vector[n] y;     // response values
  vector[n] x;     // predictor values (the R driver passes c(scale(d$date)))
}
parameters {
  real alpha;            // intercept of the left-hand line (mean at x = 0)
  vector[2] beta;        // beta[1]: slope left of the breakpoint;
                         // beta[2]: change in slope right of the breakpoint
  real<lower=0> sigma;   // residual standard deviation
  // Breakpoint location, restricted to the observed range of x. If the
  // breakpoint falls outside that range, one segment covers no data: its
  // slope (or the intercept) is then informed only by the prior and the
  // posterior develops a non-identified ridge. Bounds may use data in Stan.
  // NOTE: Stan requires lower < upper, so this needs >= 2 distinct x values.
  real<lower=min(x), upper=max(x)> cutpoint;
}
transformed parameters {
  // x2[i] is 0 when x[i] lies strictly left of the breakpoint and 1
  // otherwise (i.e. when x[i] >= cutpoint). It switches on the extra
  // slope beta[2] in the model block.
  vector[n] x2;
  for (i in 1:n) {
    x2[i] = x[i] < cutpoint ? 0 : 1;
  }
}
model {
  // Mean of y at each observation (assigned below, after the priors).
  vector[n] mu;

  // Priors. sigma is declared <lower=0>, so normal(0, 2) acts as a
  // half-normal prior on it.
  alpha ~ normal(0, 1);
  beta ~ normal(0, 1);
  sigma ~ normal(0, 2);
  cutpoint ~ normal(0, 1);

  // Broken-stick mean: slope beta[1] everywhere, plus beta[2] per unit of
  // x past the breakpoint (x2 is the 0/1 indicator of x[i] >= cutpoint).
  // Written as one vectorized expression instead of an element-wise loop:
  // each element is computed in the same order as before, but Stan builds
  // a smaller autodiff expression graph, which is faster.
  mu = alpha + beta[1] * x + beta[2] * (x - cutpoint) .* x2;

  // Gaussian likelihood with residual standard deviation sigma.
  y ~ normal(mu, sigma);
}
library(rstan)
library(scales)
# Load the observations and build the data list expected by lm_segs.stan
# (n, x, y).
d <- read.csv("regression.db.csv")

# Fail early with a clear message instead of passing NULL columns to Stan.
required_cols <- c("date", "direction")
missing_cols <- setdiff(required_cols, names(d))
if (length(missing_cols) > 0) {
  stop("regression.db.csv is missing column(s): ",
       paste(missing_cols, collapse = ", "), call. = FALSE)
}

# `[[` avoids the partial name matching that `$` does on data frames.
# x is standardized so it is on the same unit scale as the normal(0, 1)
# prior on cutpoint; c() drops the attributes that scale() attaches.
# NOTE(review): read.csv() returns text columns as character, which scale()
# rejects; this assumes `date` is stored numerically — confirm.
stan_d <- list(
  y = d[["direction"]],
  x = c(scale(d[["date"]])),
  n = nrow(d)
)

# Compile once with stan_model(), then draw with sampling(). This replaces
# the older pattern of a throwaway 2-iteration stan() run reused via `fit =`.
seg_model <- stan_model("lm_segs.stan")
mfit <- sampling(seg_model,
  data = stan_d,
  iter = 1000,
  chains = 3,
  cores = 3
)

# Explicit print so the summary also appears when this script is source()d.
print(mfit)
# Overlay every posterior draw of the fitted two-segment line on the data.
post <- rstan::extract(mfit)

# Per-draw parameter vectors (as.vector drops the 1-d array dims).
a_draw <- as.vector(post$alpha)
b1_draw <- post$beta[, 1]
b2_draw <- post$beta[, 2]
cut_draw <- as.vector(post$cutpoint)

# Loop-invariant plot limits along x.
x_lo <- min(stan_d$x)
x_hi <- max(stan_d$x)

plot(stan_d$x, stan_d$y,
  xlab = "Time", ylab = "Stuff"
)

# Fitted mean at the breakpoint: the shared endpoint of both segments.
y_cut <- a_draw + b1_draw * cut_draw

# segments() is vectorized over its coordinates, so all draws are drawn in
# one call per segment rather than two calls per draw.
# Left segment: slope beta[1] from min(x) to the breakpoint.
segments(
  x0 = x_lo, x1 = cut_draw,
  y0 = a_draw + b1_draw * x_lo,
  y1 = y_cut,
  col = alpha(3, .1)
)
# Right segment: slope beta[1] + beta[2] from the breakpoint to max(x).
segments(
  x0 = cut_draw, x1 = x_hi,
  y0 = y_cut,
  y1 = a_draw + b2_draw * (x_hi - cut_draw) + b1_draw * x_hi,
  col = alpha(2, .1)
)

# Redraw the observations as filled points on top of the lines, and mark
# the posterior draws of the breakpoint along the x axis.
points(stan_d$x, stan_d$y, pch = 19)
rug(cut_draw)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment