Skip to content

Instantly share code, notes, and snippets.

@MatsuuraKentaro
Created October 26, 2016 08:03
Show Gist options
  • Save MatsuuraKentaro/7e8b483c843e03c4f472662c46f7f719 to your computer and use it in GitHub Desktop.
Save MatsuuraKentaro/7e8b483c843e03c4f472662c46f7f719 to your computer and use it in GitHub Desktop.
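Bayesian GPLVM (Gaussian process latent variable model): a Stan model followed by the R script that fits it with variational inference and plots the result next to a PCA baseline.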
// Bayesian GPLVM: each observed output dimension Y[k] is modelled as a
// zero-mean multivariate normal over the N latent points x, with its own
// kernel hyperparameters theta[, k]. Given x and theta, the K dimensions are
// conditionally independent.
data {
  int<lower=1> N;   // number of data points (columns of x; length of each Y[k])
  int<lower=1> K;   // number of observed output dimensions
  int<lower=1> D;   // latent dimension (rows of x)
  vector[N] Y[K];   // Y[k]: values of output dimension k at all N points
}
transformed data {
  vector[N] Mu;     // zero mean vector shared by every output dimension
  Mu = rep_vector(0, N);
}
parameters {
  matrix[D,N] x;              // latent coordinates, one column per data point
  // Per-dimension kernel hyperparameters (all positive; the lower=0 bound
  // turns the student_t priors below into half-t priors):
  //   theta[1,k]  magnitude passed to cov_exp_quad
  //   theta[2,k]  length-scale passed to cov_exp_quad
  //   theta[3,k]  constant offset added to every covariance entry
  //   theta[4,k]  weight of the linear kernel x' * x
  //   theta[5,k]  value added to the diagonal (noise variance)
  vector<lower=0>[K] theta[5];
}
model {
  vector[D] x_copy[N];   // columns of x as an array, the form cov_exp_quad takes
  matrix[N,N] xtx;       // linear-kernel Gram matrix x' * x, identical for every k

  for (n in 1:N)
    x_copy[n] = x[,n];
  // Hoisted out of the k-loop: it depends only on x, not on theta.
  xtx = crossprod(x);

  // Priors.
  to_vector(x) ~ normal(0, 1);
  for (i in 1:5)
    theta[i] ~ student_t(4, 0, 5);

  // Likelihood: build each covariance and use it immediately, instead of
  // keeping K dense N x N matrices alive at once.
  for (k in 1:K) {
    matrix[N,N] cov_k;
    cov_k = cov_exp_quad(x_copy, theta[1,k], theta[2,k])
            + theta[3,k]
            + theta[4,k] * xtx
            + diag_matrix(rep_vector(theta[5,k], N));
    Y[k] ~ multi_normal(Mu, cov_k);
  }
}
library(R.matlab)
library(rstan)
library(ggplot2)
# Load the data ----
# NOTE(review): "input/3Class.mat" and "output/" are resolved relative to the
# working directory.
d <- readMat("input/3Class.mat")
N <- nrow(d$DataTrn)  # number of data points
K <- ncol(d$DataTrn)  # number of observed dimensions
D <- 2L               # latent dimension for both PCA and the GPLVM

## PCA
res_pca <- prcomp(d$DataTrn)
d_pca <- data.frame(X1 = res_pca$x[, 1], X2 = res_pca$x[, 2])

# DataTrnLbls is expected to hold one row per point with a single 1 marking
# its class. vapply() guarantees exactly one integer per row; a malformed row
# stops with a clear message instead of apply() silently returning a list.
lbls <- d$DataTrnLbls
class_idx <- vapply(seq_len(nrow(lbls)), function(i) {
  hit <- which(lbls[i, ] == 1)
  if (length(hit) != 1L) {
    stop("Label row ", i, " does not contain exactly one 1.", call. = FALSE)
  }
  unname(hit)
}, integer(1))
d_pca$class <- as.factor(class_idx)

# ggsave() fails if the target directory does not exist.
dir.create("output", showWarnings = FALSE)
p <- ggplot(data = d_pca, aes(x = X1, y = X2, color = class)) +
  geom_point(size = 1.5, alpha = 0.5)
# Arguments spelled out: the original relied on partial matching
# (file -> filename, w -> width, h -> height).
ggsave("output/result-pca.png", plot = p, dpi = 300, width = 5, height = 4)
## Bayesian GPLVM
# The Stan program declares `vector[N] Y[K]`, so the column-centred data is
# passed transposed (K x N). Named `stan_data` to avoid shadowing utils::data.
stan_data <- list(
  N = N, K = K, D = D,
  Y = t(scale(d$DataTrn, scale = FALSE))
)
stanmodel <- stan_model(file = "model/model.stan")

# Latent positions start at the PCA scores, transposed to D x N to match the
# `matrix[D,N] x` parameter.
fit_vb <- vb(
  stanmodel,
  data = stan_data,
  init = function() list(x = t(res_pca$x[, 1:D])),
  seed = 123, eta = 1, adapt_engaged = FALSE
)

# Namespaced so another attached package (e.g. tidyr) cannot mask extract().
ms <- rstan::extract(fit_vb)
# ms$x is draws x D x N; take the posterior median of each element and
# transpose to N x D, giving data.frame columns X1, X2, ...
x_est <- t(apply(ms$x, c(2, 3), median))
d_gplvm <- data.frame(x_est, class = d_pca$class)

# ggsave() fails if the target directory does not exist.
dir.create("output", showWarnings = FALSE)
p <- ggplot(data = d_gplvm, aes(x = X1, y = X2, color = class)) +
  geom_point(size = 1.5, alpha = 0.5)
# Arguments spelled out: the original relied on partial matching
# (file -> filename, w -> width, h -> height).
ggsave("output/result-bgplvm.png", plot = p, dpi = 300, width = 5, height = 4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment