Skip to content

Instantly share code, notes, and snippets.

@abikoushi
Created June 10, 2024 07:48
Show Gist options
  • Save abikoushi/e63182afec6afe5a6fce34fb27833ea4 to your computer and use it in GitHub Desktop.
Save abikoushi/e63182afec6afe5a6fce34fb27833ea4 to your computer and use it in GitHub Desktop.
Doubly-sparse variational Poisson regression (binary sampling)
library(ggplot2)
library(dplyr)
#' Coordinate-wise gamma updates for a multiplicative Poisson regression
#'
#' For each column `j` of `X` the function keeps a pair `(ahat[j], bhat[j])`
#' and the plug-in value `lambda[j] = ahat[j] / bhat[j]`. One sweep updates
#' every `j` in turn using the current values of the other `lambda`s:
#'
#' * `ahat[j] = sum(X[, j] * y) + a` (does not change between sweeps),
#' * `bhat[j] = sum(X[, j] * exp(X[, -j] %*% log(lambda[-j]))) + b`.
#'
#' NOTE(review): the `bhat` update looks like it assumes a 0/1 design matrix
#' and a Gamma(a, b) (shape, rate) prior on each `lambda[j]` -- confirm
#' against the derivation.
#'
#' `lambda` is initialised with `rgamma(K, 1, 1)`, so results depend on the
#' RNG state; call `set.seed()` first for reproducibility.
#'
#' @param y Numeric response vector with one entry per row of `X`.
#' @param X Numeric design matrix (rows = observations, columns = effects).
#' @param a,b Constants added to the `ahat` and `bhat` updates.
#' @param iter Number of full sweeps over the columns. With `iter = 0` no
#'   sweep is run and `lambda` is the random initial draw.
#' @return A list with numeric vectors `lambda`, `ahat` and `bhat`, each of
#'   length `ncol(X)`.
doVB <- function(y, X, a = 0.5, b = 0.01, iter = 100) {
  K <- ncol(X)
  sxy <- t(X) %*% y
  lambda <- rgamma(K, 1, 1)
  # ahat depends only on the data, so compute it once instead of per sweep.
  ahat <- as.vector(sxy) + a
  bhat <- rep(b, K)
  # seq_len() (not 1:n) so that iter = 0 or K = 0 runs zero iterations.
  for (i in seq_len(iter)) {
    for (j in seq_len(K)) {
      # Contribution of every column except j, evaluated at current lambda.
      others <- exp(X[, -j, drop = FALSE] %*% log(lambda[-j]))
      bhat[j] <- drop(X[, j] %*% others) + b
      lambda[j] <- ahat[j] / bhat[j]
    }
  }
  list(lambda = lambda, ahat = ahat, bhat = bhat)
}
#' Coordinate-wise gamma updates using sampled binary pseudo-rows
#'
#' Same update scheme as a plain sweep over the rows of `X`, except that
#' `bhat[j]` receives an additional contribution from `m` pseudo-rows. Each
#' pseudo-row `s` is drawn as `rbinom(K, 1, p)` and adds
#' `s[j] * exp(sum(s[-j] * log(lambda[-j])))` to `bhat[j]`. Fresh pseudo-rows
#' are drawn for every column `j` in every sweep.
#'
#' NOTE(review): presumably `X`/`y` hold only the rows with a non-zero
#' response and the pseudo-rows stand in for the `m` omitted zero-response
#' rows, with `p` their column-wise inclusion frequencies -- confirm against
#' the caller.
#'
#' `lambda` is initialised with `rgamma(K, 1, 1)` and the pseudo-rows are
#' random, so results depend on the RNG state.
#'
#' @param y Numeric response vector with one entry per row of `X`.
#' @param X Numeric design matrix (rows = observations, columns = effects).
#' @param m Number of pseudo-rows drawn per column update. `m = 0` adds no
#'   pseudo-row contribution.
#' @param p Success probabilities passed to `rbinom()`; length `ncol(X)`
#'   (or recyclable to it).
#' @param a,b Constants added to the `ahat` and `bhat` updates.
#' @param iter Number of full sweeps over the columns. With `iter = 0` no
#'   sweep is run and `lambda` is the random initial draw.
#' @return A list with numeric vectors `lambda`, `ahat` and `bhat`, each of
#'   length `ncol(X)`.
doVB_sp <- function(y, X, m, p, a = 0.5, b = 0.01, iter = 100) {
  K <- ncol(X)
  sxy <- t(X) %*% y
  lambda <- rgamma(K, 1, 1)
  # ahat depends only on the data, so compute it once instead of per sweep.
  ahat <- as.vector(sxy) + a
  bhat <- rep(b, K)
  # seq_len() (not 1:n): 1:0 is c(1, 0), which previously ran two sweeps for
  # iter = 0 and, worse, drew two spurious pseudo-rows for m = 0.
  for (i in seq_len(iter)) {
    for (j in seq_len(K)) {
      log_others <- log(lambda[-j])
      # Contribution of the observed rows.
      rate <- drop(X[, j] %*% exp(X[, -j, drop = FALSE] %*% log_others))
      # Contribution of the m sampled pseudo-rows.
      for (k in seq_len(m)) {
        s <- rbinom(K, 1, p)
        rate <- rate + s[j] * drop(exp(s[-j] %*% log_others))
      }
      bhat[j] <- rate + b
      lambda[j] <- ahat[j] / bhat[j]
    }
  }
  list(lambda = lambda, ahat = ahat, bhat = bhat)
}
# Simulate data and fit both estimators ----
# 200 observations, 20 binary covariates, multiplicative Poisson mean
# exp(X %*% log(lambda)).
set.seed(1234)
lambda <- rexp(20, 1)
X <- matrix(rbinom(200 * 20, 1, 0.5), 200, 20)
y <- rpois(200, exp(X %*% log(lambda)))

# Split rows by whether the response is zero. drop = FALSE keeps the
# subsets as matrices even when exactly one row matches; otherwise nrow()
# and colMeans() below would fail on a plain vector.
X0 <- X[y == 0L, , drop = FALSE]
y1 <- y[y > 0]
X1 <- X[y > 0, , drop = FALSE]

# Number of zero-response rows (printed), and the column-wise frequency of
# 1s among them.
(m <- nrow(X0))
p <- colMeans(X0)

# Conventional fit on all rows; sparse fit on the non-zero rows plus m
# pseudo-rows sampled with probabilities p.
outVB <- doVB(y, X)
outVB_sp <- doVB_sp(y1, X1, m, p)
# Collect the true rates and both sets of estimates, one row per coefficient.
#   est*    : ahat / bhat
#   logest* : digamma(ahat) - log(bhat), i.e. the mean of log(lambda) under a
#             Gamma(shape = ahat, rate = bhat) distribution
#   win     : which method has the smaller squared error on the log scale
#             (ties are labelled "proposed")
df <- data.frame(
  true = lambda,
  est1 = outVB$ahat / outVB$bhat,
  est2 = outVB_sp$ahat / outVB_sp$bhat,
  logest1 = digamma(outVB$ahat) - log(outVB$bhat),
  logest2 = digamma(outVB_sp$ahat) - log(outVB_sp$bhat)
) %>%
  mutate(
    win = if_else(
      (logest2 - log(true))^2 > (logest1 - log(true))^2,
      "conventional",
      "proposed"
    )
  )
# Estimates vs. truth on the natural scale. Each vertical segment joins the
# two estimates of one coefficient and is coloured by the `win` column; the
# dashed line is y = x.
ggplot(data = df, mapping = aes(x = true)) +
  geom_linerange(
    mapping = aes(ymin = est1, ymax = est2, colour = win),
    alpha = 0.7
  ) +
  geom_point(mapping = aes(y = est1, colour = "conventional"), alpha = 0.7) +
  geom_point(mapping = aes(y = est2, colour = "proposed"), alpha = 0.7) +
  geom_abline(intercept = 0, slope = 1, linetype = 2) +
  theme_bw() +
  labs(y = "estimates", colour = "method")

# Same comparison on the log scale.
ggplot(data = df, mapping = aes(x = log(true))) +
  geom_linerange(
    mapping = aes(ymin = logest1, ymax = logest2, colour = win),
    alpha = 0.7
  ) +
  geom_point(mapping = aes(y = logest1, colour = "conventional"), alpha = 0.7) +
  geom_point(mapping = aes(y = logest2, colour = "proposed"), alpha = 0.7) +
  geom_abline(intercept = 0, slope = 1, linetype = 2) +
  theme_bw() +
  labs(y = "estimates (log)", colour = "method")
# Root-mean-square error of both methods against the true rates.
# The values are square roots of mean squared errors, so they are labelled
# RMSE. Each report ends with a newline so the two do not run together.

# Log scale: compare log(lambda) with digamma(ahat) - log(bhat).
cat("RMSE (log)\n",
    "conventional: ",
    sqrt(mean((log(lambda) - (digamma(outVB$ahat) - log(outVB$bhat)))^2)),
    " sparse: ",
    sqrt(mean((log(lambda) - (digamma(outVB_sp$ahat) - log(outVB_sp$bhat)))^2)),
    "\n")

# Natural scale: compare lambda with ahat / bhat. The square must be applied
# to each residual before averaging; squaring the mean residual instead
# yields |mean error|, which lets positive and negative errors cancel.
cat("RMSE\n",
    "conventional: ",
    sqrt(mean((lambda - outVB$ahat / outVB$bhat)^2)),
    " sparse: ",
    sqrt(mean((lambda - outVB_sp$ahat / outVB_sp$bhat)^2)),
    "\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment