Skip to content

Instantly share code, notes, and snippets.

@roualdes
Created March 21, 2023 20:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save roualdes/b7aeb3f73cf421bdd502dfa17f7b04ed to your computer and use it in GitHub Desktop.
Save roualdes/b7aeb3f73cf421bdd502dfa17f7b04ed to your computer and use it in GitHub Desktop.
Stan to autodiff effective sample size
functions {
/**
 * Smallest integer greater than or equal to N whose only prime factors
 * are 2, 3 and 5. The result is used by callers as a padded FFT length.
 *
 * @param N requested minimum size
 * @return 2 when N <= 2, otherwise the smallest integer n >= N of the
 *         form 2^a * 3^b * 5^c
 */
int fft_nextgoodsize(int N) {
  if (N <= 2) {
    return 2;
  }
  int n = N;
  while (1) {
    // Start the factor stripping from the *current* candidate. Keeping the
    // remainder of a previous candidate here would make the search spin
    // forever as soon as one candidate has a prime factor other than 2, 3, 5.
    int m = n;
    while (m % 2 == 0) {
      m = m %/% 2;
    }
    while (m % 3 == 0) {
      m = m %/% 3;
    }
    while (m % 5 == 0) {
      m = m %/% 5;
    }
    // Nothing left after removing all 2s, 3s and 5s: n is a good size.
    if (m <= 1) {
      return n;
    }
    n += 1;
  }
  // Not reachable; kept so every control path ends in a return statement.
  return n;
}
/**
 * Biased (divide-by-N) sample autocovariance of x at lags 0 .. N-1,
 * computed through zero-padded FFTs.
 *
 * @param x series of draws
 * @param N number of draws; must equal the length of x, because the
 *          centred series is assigned into the first N padded slots
 * @return vector of length N whose k-th element is
 *         sum_t (x[t] - mean(x)) * (x[t + k - 1] - mean(x)) / N
 */
vector autocovariance(vector x, int N) {
  // Pad to twice a "good" FFT size (>= 2 * N) so the circular correlation
  // produced by the FFT does not wrap around for lags below N.
  int Mt2 = 2 * fft_nextgoodsize(N);
  vector[Mt2] yc = rep_vector(0, Mt2);
  yc[1 : N] = x - mean(x);
  complex_vector[Mt2] t = inv_fft(to_complex(yc, 0));
  complex_vector[Mt2] ac = inv_fft(conj(t) .* t);
  // Each inv_fft carries a 1 / Mt2 normalisation and the first one enters
  // squared through conj(t) .* t, so ac holds the raw lagged sums divided
  // by Mt2^2. Multiplying by Mt2^2 / N yields the 1/N-normalised
  // autocovariance for any padded length (this equals 4 * N only when the
  // padded length is exactly 2 * N). Computed in reals to avoid int overflow.
  real scale = (1.0 * Mt2) * Mt2 / N;
  return get_real(ac)[1 : N] * scale;
}
/**
 * Effective sample size of a set of chains stored end to end in one vector.
 *
 * @param X draws; elements (c - 1) * iterations + 1 .. c * iterations
 *          are read as chain c
 * @param iterations number of draws per chain
 * @param chains number of chains
 * @return chains * iterations divided by the estimated autocorrelation
 *         time, the latter floored at 1 / log10(chains * iterations)
 */
real ess(vector X, int iterations, int chains) {
  // Per-chain autocovariances (one column per chain) and per-chain means.
  matrix[iterations, chains] acov;
  vector[chains] chain_mean;
  for (c in 1:chains) {
    int lo = 1 + (c - 1) * iterations;
    int hi = c * iterations;
    vector[iterations] draws = X[lo:hi];
    acov[:, c] = autocovariance(draws, iterations);
    chain_mean[c] = mean(draws);
  }

  // Average within-chain variance (lag-0 autocovariance rescaled by
  // iterations / (iterations - 1)) and the combined variance estimate,
  // which adds the variance of the chain means when there are several chains.
  real within_var = mean(acov[1, :]) * iterations / (iterations - 1);
  real var_plus = within_var * (iterations - 1) / iterations;
  if (chains > 1) {
    var_plus += variance(chain_mean);
  }

  // Autocorrelation estimates, filled in pairs (even lag, odd lag) for as
  // long as the sum of the latest pair stays positive and is not NaN.
  vector[iterations] rho = zeros_vector(iterations);
  int lag = 0;
  real rho_even = 1.0;
  real rho_odd = 1.0 - (within_var - mean(acov[2, :])) / var_plus;
  rho[1] = rho_even;
  rho[2] = rho_odd;
  real pair_sum = rho_even + rho_odd;
  while (lag < (iterations - 5) && !is_nan(pair_sum) && pair_sum > 0) {
    lag += 2;
    rho_even = 1.0 - (within_var - mean(acov[lag + 1, :])) / var_plus;
    rho_odd = 1.0 - (within_var - mean(acov[lag + 2, :])) / var_plus;
    pair_sum = rho_even + rho_odd;
    if (pair_sum >= 0) {
      rho[lag + 1] = rho_even;
      rho[lag + 2] = rho_odd;
    }
  }
  int max_lag = lag;
  // Keep the last even-lag estimate when positive; it is added once more
  // in the final sum below.
  if (rho_even > 0) {
    rho[max_lag + 1] = rho_even;
  }

  // Force the pair sums to be non-increasing: a pair larger than the one
  // before it is replaced by that previous pair's average.
  int s = 0;
  while (s <= (max_lag - 4)) {
    s += 2;
    real prev_pair = rho[s - 1] + rho[s];
    if ((rho[s + 1] + rho[s + 2]) > prev_pair) {
      rho[s + 1] = prev_pair / 2.0;
      rho[s + 2] = rho[s + 1];
    }
  }

  real total_draws = chains * iterations;
  real tau = -1.0 + 2.0 * sum(rho[1:max(1, max_lag)]) + rho[max_lag + 1];
  // The lower bound on tau caps the result at total_draws * log10(total_draws).
  tau = fmax(tau, 1.0 / log10(total_draws));
  return total_draws / tau;
}
}
data {
  // Number of draws per chain. ess() divides by (N - 1) and reads the
  // lag-1 autocovariance row, so fewer than 2 draws cannot be processed.
  int<lower=2> N;
  // Number of chains. ess() averages over chains, so at least one is needed.
  int<lower=1> M;
}
transformed data {
// Total number of values across all chains: N draws times M chains.
int NM = N * M;
}
parameters {
// Unconstrained vector of NM values. The model block passes it to ess(),
// which reads it as M consecutive segments of N values each.
vector[NM] x;
}
model {
// Add the value returned by ess(x, N, M) to the log density accumulator,
// so the function is evaluated as part of the target.
target += ess(x, N, M);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment