Skip to content

Instantly share code, notes, and snippets.

@abikoushi
Last active December 14, 2022 04:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save abikoushi/b9f41624d2b309da46988522ab6e9535 to your computer and use it in GitHub Desktop.
Save abikoushi/b9f41624d2b309da46988522ab6e9535 to your computer and use it in GitHub Desktop.
R implementation of Polson, Scott & Windle (2013), "Bayesian inference for logistic models using Polya-Gamma latent variables" (multinomial version)
library(BayesLogit)
library(ggplot2)
#########
# Numerically stable softmax of a numeric vector.
# Subtracting the maximum first makes the largest exponent exp(0) = 1, so no
# term can overflow; the common factor cancels in the normalisation.
softmax <- function(x) {
  shifted <- exp(x - max(x))
  shifted / sum(shifted)
}
# Numerically stable log(sum(exp(x))).
# The maximum is factored out so the largest exponent is exp(0) and the sum
# cannot overflow.
logsumexp <- function(x) {
  m <- max(x)
  # If the maximum is not finite (every entry -Inf, empty input, an Inf entry,
  # or NA/NaN), the shifted form would evaluate Inf - Inf = NaN. In all of
  # those cases the maximum itself is already the correct result.
  if (!is.finite(m)) {
    return(m)
  }
  m + log(sum(exp(x - m)))
}
#########
# Gibbs sampler for Bayesian multinomial logistic regression using
# Polya-Gamma data augmentation (Polson, Scott & Windle, 2013).
# Category 1 is the reference category: its coefficient column is fixed at 0.
# Columns 2..K are updated one at a time from their conditional
# (binomial-logit) posteriors.
#
# Y:      N x K matrix of category counts. The row sums are used as the
#         numbers of trials.
# X:      N x D explanatory design matrix.
# iter:   number of Gibbs sweeps to run.
# lambda: prior precision, either a scalar or a length-D vector. Each
#         coefficient column has prior mean 0 and precision diag(lambda, D).
# Returns a D x K x iter array of sampled coefficients. Column 1 of every
# draw is all zeros.
gibbs_mlogit <- function(Y, X, iter=1000, lambda=1) {
  Y <- as.matrix(Y)
  X <- as.matrix(X)
  M <- rowSums(Y)
  K <- ncol(Y)
  N <- nrow(Y)
  D <- ncol(X)
  stopifnot(
    K >= 2,         # need at least one non-reference category
    nrow(X) == N,   # one design row per observation
    iter >= 0
  )
  # kappa = y - M/2 from the Polya-Gamma identity
  ydif <- sweep(Y, 1, 0.5 * M)
  Lambda <- diag(lambda, D)
  W_hist <- array(0, dim = c(D, K, iter))
  W_tilde <- array(0, dim = c(D, K))
  # Random start for the non-reference columns; column 1 stays at zero.
  W_tilde[, -1] <- rnorm(D * (K - 1))
  for (i in seq_len(iter)) {
    for (j in seq.int(2, K)) {
      # c_j = log sum_{k != j} exp(x_i' w_k) is the log-normaliser of all
      # other categories. It is held fixed while column j is updated.
      c_j <- apply(X %*% W_tilde[, -j, drop = FALSE], 1, logsumexp)
      eta <- drop(X %*% W_tilde[, j] - c_j)
      omega <- rpg(N, M, eta)
      # Posterior precision X' diag(omega) X + Lambda, computed without
      # forming the N x N diagonal matrix.
      Vinv <- crossprod(X, X * omega) + Lambda
      U <- chol(Vinv)  # Vinv = U'U, U upper triangular
      # Posterior mean mu = solve(Vinv, X'(kappa_j + c_j * omega)), computed
      # with two triangular solves against the Cholesky factor.
      A <- forwardsolve(t(U), crossprod(X, ydif[, j] + c_j * omega))
      mu <- backsolve(U, A)
      # U^{-1} z with z ~ N(0, I) has covariance (U'U)^{-1} = Vinv^{-1}.
      W_tilde[, j] <- mu + backsolve(U, rnorm(D))
      W_hist[, j, i] <- W_tilde[, j]
    }
  }
  W_hist
}
# Simulate a toy data set.
# - 50 observations with an intercept and one standard-normal covariate.
# - Three categories; category 1 is the reference, with an all-zero
#   coefficient column.
# - 1000 multinomial trials per observation.
set.seed(123456)
W <- matrix(0, nrow = 2, ncol = 3)
W[, 2:3] <- runif(4, min = -1, max = 1)
x <- rnorm(50, mean = 0, sd = 1)
X <- cbind(1, x)
# apply() over rows returns one column of class probabilities per observation.
prob <- apply(X %*% W, 1, softmax)
Y <- t(apply(prob, 2, function(p) rmultinom(1, size = 1000, prob = p)))
# Run the sampler and draw trace plots of the non-reference coefficients.
# The true values are drawn as horizontal lines.
out <- gibbs_mlogit(Y, X, iter = 2000, lambda = 1)
# Derive the grid sizes from the sampler output so they cannot drift out of
# sync with the `iter` argument above.
n_coef <- dim(out)[1]
n_cat <- dim(out)[2]
n_iter <- dim(out)[3]
# as.vector() walks out[, -1, ] in column-major order (row fastest, then
# column, then iteration), which matches the order expand.grid() produces.
# `col` is the actual column index of W, so the facet labels line up with
# print(W). Column 1, the fixed reference, is omitted.
dfs <- expand.grid(row = seq_len(n_coef), col = 2:n_cat,
                   iter = seq_len(n_iter))
dfs$value <- as.vector(out[, -1, ])
dft <- expand.grid(row = seq_len(n_coef), col = 2:n_cat)
dft$value <- c(W[, -1])
ggplot(dfs, aes(x = iter, y = value)) +
  geom_line(colour = "grey") +
  geom_hline(data = dft, aes(yintercept = value), colour = "royalblue") +
  facet_grid(row ~ col, scales = "free", labeller = label_both) +
  theme_classic(14) +
  theme(strip.text.y = element_text(angle = 0),
        axis.text = element_text(colour = "black"))
# ggsave("trace.png")
# Posterior summary and fit check.
# 1. Discard the burn-in draws and average the rest to get the posterior mean
#    (expected a posteriori, EAP) estimate of W.
# 2. Compare plug-in fitted class probabilities with observed proportions.
burnin <- 1:500
# Keep at least one post-burn-in draw to average over.
stopifnot(dim(out)[3] > max(burnin))
What <- apply(out[, , -burnin], 1:2, mean)  # posterior mean (EAP)
print(W)
print(What)
fit <- t(apply(X %*% What, 1, softmax))  # plug-in predictor, N x K
obs <- Y / rowSums(Y)
# c() flattens column-major, so each category's N values are contiguous.
dff <- data.frame(fitted = c(fit),
                  observed = c(obs),
                  k = rep(seq_len(ncol(Y)), each = nrow(Y)))
ggplot(dff, aes(x = fitted, y = observed)) +
  geom_point() +
  geom_abline(slope = 1, intercept = 0, linetype = 2) +
  facet_wrap(~k) +
  theme_classic(16)
# ggsave("fit.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment