Skip to content

Instantly share code, notes, and snippets.

@abikoushi
Created June 6, 2024 15:57
Show Gist options
  • Save abikoushi/9b1d9a8703b180d027be1e826cab547a to your computer and use it in GitHub Desktop.
Save abikoushi/9b1d9a8703b180d027be1e826cab547a to your computer and use it in GitHub Desktop.
Doubly-sparse variational Poisson regression (prototype)
#' Coordinate-wise fit of a multiplicative Poisson rate model
#'
#' The rate of row `i` is computed as `exp(X[i, ] %*% log(lambda))`, i.e. the
#' product of `lambda[k]^X[i, k]`. Each sweep visits every column `j`, holds
#' the other entries of `lambda` fixed, refreshes `bhat[j]`, and then sets
#' `lambda[j]` to `ahat[j] / bhat[j]` (the mean of a Gamma distribution with
#' shape `ahat[j]` and rate `bhat[j]`).
#'
#' @param y Numeric vector of counts, one per row of `X`.
#' @param X Numeric design matrix (must be a matrix; it is used with `%*%`).
#' @param a Constant added to `t(X) %*% y` to form `ahat`.
#' @param b Constant added to the rate sum to form `bhat`.
#' @param iter Number of full sweeps over the columns. `iter = 0` performs no
#'   sweep: `lambda` is the random starting point and `bhat` is `rep(b, K)`.
#' @return A list with numeric vectors `lambda`, `ahat` and `bhat`, each of
#'   length `ncol(X)`.
doVB <- function(y, X, a = 0.5, b = 0.001, iter = 200) {
  K <- ncol(X)

  # ahat depends only on the data, never on lambda, so it is computed once
  # instead of being re-assigned on every sweep. as.vector() keeps it a plain
  # unnamed numeric vector.
  ahat <- as.vector(crossprod(X, y)) + a
  bhat <- rep(b, K)

  # Random starting point; this is the only place the function uses the RNG.
  lambda <- rgamma(K, 1, 1)

  # seq_len() rather than 1:n, so iter = 0 (or K = 0) runs zero times
  # instead of iterating over c(1, 0).
  for (i in seq_len(iter)) {
    for (j in seq_len(K)) {
      # Per-row rate contributed by every column except j.
      rate_others <- exp(X[, -j, drop = FALSE] %*% log(lambda[-j]))
      bhat[j] <- drop(X[, j] %*% rate_others) + b
      lambda[j] <- ahat[j] / bhat[j]
    }
  }

  list(lambda = lambda, ahat = ahat, bhat = bhat)
}
#' Coordinate-wise fit with simulated pseudo-rows added to the rate sum
#'
#' Works like a plain coordinate sweep over `lambda` on `(y, X)`, but on every
#' outer iteration it draws `m` extra design rows, each made of a constant 1
#' and a Bernoulli(`p`) draw, and adds their rate contribution to `bhat`.
#' The extra rows add nothing to `ahat`.
#' NOTE(review): the pseudo-rows presumably stand in for observations that
#' were left out of `X` (e.g. zero counts); confirm against the caller.
#'
#' @param y Numeric vector of counts, one per row of `X`.
#' @param X Numeric design matrix with exactly 2 columns, matching the
#'   layout of the simulated rows (constant, binary covariate).
#' @param m Number of pseudo-rows drawn per outer iteration; may be 0.
#' @param p Success probability of the simulated binary covariate.
#' @param a Constant added to `t(X) %*% y` to form `ahat`.
#' @param b Constant added to the rate sum to form `bhat`.
#' @param iter Number of full sweeps over the columns; `iter = 0` performs
#'   no sweep.
#' @return A list with numeric vectors `lambda`, `ahat` and `bhat`, each of
#'   length `ncol(X)`.
doVB_sp <- function(y, X, m, p, a = 0.5, b = 0.001, iter = 200) {
  K <- ncol(X)
  # The pseudo-rows are hard-wired to two columns; fail with a clear message
  # instead of an obscure subscript / conformability error further down.
  if (K != 2L) {
    stop(
      "doVB_sp() simulates a constant plus one binary covariate, ",
      "so `X` must have exactly 2 columns",
      call. = FALSE
    )
  }

  # ahat depends only on the observed data, so it is computed once.
  ahat <- as.vector(crossprod(X, y)) + a
  bhat <- rep(b, K)
  lambda <- rgamma(K, 1, 1)

  # seq_len() so that iter = 0 runs zero sweeps (1:0 would run two).
  for (i in seq_len(iter)) {
    # Fresh pseudo-rows each sweep. Built with matrix() so that m = 0 gives a
    # proper 0 x 2 matrix; cbind(1, numeric(0)) would give a 1 x 1 matrix.
    xs <- matrix(c(rep(1, m), rbinom(m, 1, p)), nrow = m, ncol = 2L)

    for (j in seq_len(K)) {
      log_others <- log(lambda[-j])
      observed <- X[, j] %*% exp(X[, -j, drop = FALSE] %*% log_others)
      simulated <- xs[, j] %*% exp(xs[, -j, drop = FALSE] %*% log_others)
      bhat[j] <- drop(observed) + drop(simulated) + b
      lambda[j] <- ahat[j] / bhat[j]
    }
  }

  list(lambda = lambda, ahat = ahat, bhat = bhat)
}
# Demo: fit the same simulated data with doVB() on all rows and with
# doVB_sp() on the positive-count rows only, then compare draws of
# log(lambda) from the two fitted Gamma factors.

# The plotting code below uses ggplot2, which was never loaded.
library(ggplot2)

# ---- Simulate data ----
set.seed(1)
x <- rbinom(100, 1, 0.5)
X <- model.matrix(~x)
y <- rpois(100, exp(X %*% c(-0.2, 1)))

# Rows with a positive count. drop = FALSE keeps X1 a matrix even if only
# one row qualifies, so nrow(X1) below stays valid.
y1 <- y[y > 0]
X1 <- X[y > 0, , drop = FALSE]
# Number of zero-count rows, and the mean of the covariate among them.
# NOTE: p is NaN if there are no zero counts.
m <- nrow(X) - nrow(X1)
p <- mean(X[y == 0, 2])

# ---- Fit ----
outVB <- doVB(y, X)
outVB_sp <- doVB_sp(y1, X1, m, p)

# ---- Draw log(lambda) from each fitted Gamma(shape = ahat, rate = bhat) ----
w1 <- log(rgamma(5000, outVB$ahat[1], outVB$bhat[1]))
w2 <- log(rgamma(5000, outVB$ahat[2], outVB$bhat[2]))
w1s <- log(rgamma(5000, outVB_sp$ahat[1], outVB_sp$bhat[1]))
w2s <- log(rgamma(5000, outVB_sp$ahat[2], outVB_sp$bhat[2]))

# ---- Plots ----
plot_scatter <- ggplot(data = NULL) +
  geom_point(aes(x = w1, y = w2, colour = "dens"), alpha = 0.1) +
  geom_point(aes(x = w1s, y = w2s, colour = "sparse"), alpha = 0.1) +
  theme_bw(15) +
  guides(colour = guide_legend(override.aes = list(alpha = 1))) +
  scale_color_brewer(palette = "Set2") +
  labs(colour = "method")
plot_scatter

plot_contour <- ggplot(data = NULL) +
  geom_density2d(aes(x = w1, y = w2, colour = "dens")) +
  geom_density2d(aes(x = w1s, y = w2s, colour = "sparse")) +
  theme_bw(15) +
  scale_color_brewer(palette = "Set2") +
  labs(colour = "method")
plot_contour

# Name the plot explicitly rather than relying on "the last plot".
ggsave("dens.png", plot = plot_contour)

# ---- Scratch check ----
# These lines originally sat above the simulation and used X, y and xs before
# they existed (xs was never defined at top level). They now run last, so the
# seeded RNG stream that feeds the fits and plots above is unchanged.
# NOTE(review): xs is rebuilt here the way doVB_sp() builds it; the apparent
# intent is to compare the zero-count rows' rate sum with its simulated
# stand-in. Confirm.
lambda <- rexp(2)
X0 <- X[y <= 0, , drop = FALSE]
xs <- matrix(c(rep(1, m), rbinom(m, 1, p)), nrow = m, ncol = 2L)
t(X0) %*% exp(X0 %*% log(lambda))
t(xs) %*% exp(xs %*% log(lambda))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment