Skip to content

Instantly share code, notes, and snippets.

@alstat
Created March 20, 2019 13:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alstat/2b4fb3c560cc878921e0a5516415394b to your computer and use it in GitHub Desktop.
Save alstat/2b4fb3c560cc878921e0a5516415394b to your computer and use it in GitHub Desktop.
# ---- Simulated data ----
# Linear model with Student-t noise: y = alpha + beta * x + sig * t_nu.
# `true` records the generating values (alpha, beta, nu) for later comparison.
set.seed(73735911)
# Alternative seed (left disabled): set.seed(737377911)
n <- 100
nu <- 5
alpha <- 2
beta <- 2
sig <- 1
true <- c(alpha, beta, nu)
x <- rnorm(n, mean = 1, sd = 1)
y <- alpha + beta * x + sig * rt(n, df = nu)
par(mfrow = c(1, 1))
plot(x, y, col = "blue3", pch = 19)

# ---- Prior hyperparameters ----
# Normal priors on the intercept and slope; chain() passes taua / taub to
# dnorm() as standard deviations.
alpha0 <- 0
taua <- 10
beta0 <- 0
taub <- 10

# ---- Sampler starting values ----
# `nu` is reused: from here on it holds the starting degrees of freedom (15),
# while the generating value (5) remains available as true[3].
a <- 0
b <- 0
nu <- 15
count <- 0
count1 <- 0
# ---- Random-walk Metropolis sampler ----
#
# chain(): draw M samples of (alpha, beta, nu) for the Student-t regression
#   y = alpha + beta * x + sig * t_nu.
#
# Reads the data (x, y, sig), the prior hyperparameters
# (alpha0, taua, beta0, taub) and the starting values (a, b, nu, count,
# count1) from the enclosing (global) environment. Assignments inside the
# function create local copies, so those globals are never modified.
#
# Arguments:
#   M       number of iterations (rows of the returned matrix).
#   sda     random-walk standard deviation for the intercept proposal.
#   sdb     random-walk standard deviation for the slope proposal.
#   stepdf  random-walk standard deviation for the degrees-of-freedom proposal.
#
# Returns an M x 3 matrix whose row i is (a, b, nu) after iteration i.
# Prints progress and acceptance rates every 500 iterations.
chain <- function(M = 6000, sda = 0.2, sdb = 0.2, stepdf = 1.9) {
  # Support of nu: proposals outside (df_min, df_max) are redrawn, so the
  # implicit prior on nu is uniform on this interval.
  df_min <- 4
  df_max <- 40

  # Log-likelihood of the t regression. Summing log densities instead of
  # multiplying densities keeps the product of n terms from underflowing
  # to 0, which would otherwise freeze the chain.
  loglik <- function(a, b, df) {
    sum(dt((y - a - b * x) / sig, df, log = TRUE))
  }
  # Log of the probability mass a N(m, stepdf) proposal puts on
  # (df_min, df_max). The truncated proposal is not symmetric, so the
  # Metropolis-Hastings ratio needs this normalising term.
  log_trunc_mass <- function(m) {
    log(pnorm(df_max, m, stepdf) - pnorm(df_min, m, stepdf))
  }
  log_prior <- function(a, b) {
    dnorm(a, alpha0, taua, log = TRUE) + dnorm(b, beta0, taub, log = TRUE)
  }

  draws <- matrix(0, M, 3)
  for (i in seq_len(M)) {
    # (alpha, beta) update: symmetric random walk, plain Metropolis ratio.
    a1 <- rnorm(1, a, sda)
    b1 <- rnorm(1, b, sdb)
    log_num <- loglik(a1, b1, nu) + log_prior(a1, b1)
    log_den <- loglik(a, b, nu) + log_prior(a, b)
    if (log(runif(1)) < log_num - log_den) {
      a <- a1
      b <- b1
      count <- count + 1
    }

    # nu update: random walk truncated to (df_min, df_max) by rejection.
    repeat {
      df_2 <- nu + rnorm(1) * stepdf
      if (df_2 > df_min && df_2 < df_max) break
    }
    # Hastings correction q(nu | df_2) / q(df_2 | nu) = Z(nu) / Z(df_2).
    log_num1 <- loglik(a, b, df_2) + log_trunc_mass(nu)
    log_den1 <- loglik(a, b, nu) + log_trunc_mass(df_2)
    if (log(runif(1)) < log_num1 - log_den1) {
      nu <- df_2
      count1 <- count1 + 1
    }

    draws[i, ] <- c(a, b, nu)
    if (i %% 500 == 0) {
      cat(i, "\n")
      cat("beta", a, b, "\n")
      cat("accept. rate", 100 * count / i, "\n")
      cat("df", nu, "\n")
      cat("accept. rate", 100 * count1 / i, "\n")
    }
  }
  draws
}
# Run n independent chains via chain() and summarise each one.
#
# Arguments:
#   n        number of chains to run (>= 1).
#   M        iterations per chain, passed to chain().
#   burn_in  first iteration kept; draws burn_in..M are summarised.
#
# Returns a list of two n x 3 matrices with rows "chain-1", ..., "chain-n":
#   [[1]] / $means  per-chain means of (alpha, beta, df)
#   [[2]] / $sds    per-chain standard deviations, computed on the same
#                   post-burn-in draws as the means.
draw_nchains <- function(n = 100, M = 6000, burn_in = 2001) {
  stopifnot(n >= 1, burn_in >= 1, burn_in <= M)
  kept <- burn_in:M

  # Chains are run sequentially, one call to chain() per chain.
  # drop = FALSE keeps a matrix even when a single row is kept.
  summaries <- lapply(seq_len(n), function(k) {
    post <- chain(M)[kept, , drop = FALSE]
    list(mean = colMeans(post), sd = apply(post, 2, sd))
  })

  # do.call(rbind, ...) yields an n x 3 matrix even for n = 1.
  means <- do.call(rbind, lapply(summaries, `[[`, "mean"))
  sds <- do.call(rbind, lapply(summaries, `[[`, "sd"))
  colnames(means) <- c("mean_alpha", "mean_beta", "mean_df")
  colnames(sds) <- c("sd_alpha", "sd_beta", "sd_df")
  rownames(means) <- rownames(sds) <- paste0("chain-", seq_len(n))
  list(means = means, sds = sds)
}
# ---- Run the sampler ----
# Quick check: five chains; the bare `output` line below is shown by
# top-level auto-printing.
output <- draw_nchains(n = 5, M = 6000, burn_in = 2001)
output
# Full run: 100 chains with the same settings. This overwrites the
# five-chain summary held in `output` and is not printed here.
output <- draw_nchains(n = 100, M = 6000, burn_in = 2001)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment