Skip to content

Instantly share code, notes, and snippets.

@devmotion
Created July 31, 2022 00:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save devmotion/a6c3561f6c593160744147e7c5165f62 to your computer and use it in GitHub Desktop.
Save devmotion/a6c3561f6c593160744147e7c5165f62 to your computer and use it in GitHub Desktop.
/*
 * One-pass algorithm of `log_sum_exp`.
 *
 * - x: Vector of log-values.
 *
 * Returns: `mx + log1p(r)`, where `mx` is the maximum element and `r` is the
 * sum of `exp(x[i] - mx)` over all elements minus one; `-inf` if `x` is
 * empty.
 *
 * Each partial result of the reduction is a pair `(m, r)`: the running
 * maximum and the accumulated sum of exponentials relative to `m`, with the
 * maximum's own contribution of one left out. That one is added back inside
 * `log1p` at the end, so contributions much smaller than one are not rounded
 * away.
 *
 * NaN elements are skipped. This assumes `nan_exp` returns zero for a NaN
 * argument (the reduction already relies on that for `-inf - -inf`)
 * -- TODO confirm against the standard library.
 */
function log_sum_exp_onepass(x:Real[_]) -> Real {
  if length(x) > 0 {
    let (mx, r) <- transform_reduce(x, (-inf, 0.0),
        \(x:(Real, Real), y:(Real, Real)) -> {
          let (xa, xb) <- x;
          let (ya, yb) <- y;
          // `xa > ya` is false when `ya` is NaN; without the explicit check
          // NaN would become the running maximum and the sum accumulated in
          // `xb` would be dropped by the next combination
          if xa > ya || isnan(ya) {
            // rescale the right-hand sum to the left-hand maximum
            return (xa, xb + (yb + 1.0) * nan_exp(ya - xa));
          } else {
            // rescale the left-hand sum to the right-hand maximum
            return (ya, yb + (xb + 1.0) * nan_exp(xa - ya));
          }
        },
        \(x:Real) -> {
          // a single element is its own maximum, with nothing accumulated
          return (x, 0.0);
        });
    return mx + log1p(r);
  } else {
    return -inf;
  }
}
/*
 * One-pass algorithm of `resample_reduce`.
 *
 * - w: Vector of log-weights.
 *
 * Returns: a pair `(ess, log_sum_weights)`, where `ess` is the squared sum of
 * the exponentiated weights divided by the sum of their squares, and
 * `log_sum_weights` is the logarithm of the sum of the exponentiated
 * weights. Returns `(0.0, 0.0)` if `w` is empty.
 *
 * Each partial result of the reduction is a triple `(m, r, rsq)`: the running
 * maximum, the sum of `exp(w[i] - m)` minus one, and the sum of
 * `exp(w[i] - m)^2` minus one. The maximum's own contribution of one is left
 * out of both sums and added back at the end.
 *
 * NaN elements are skipped. This assumes `nan_exp` returns zero for a NaN
 * argument (the reduction already relies on that for `-inf - -inf`)
 * -- TODO confirm against the standard library.
 */
function resample_reduce_onepass(w:Real[_]) -> (Real, Real) {
  if length(w) == 0 {
    return (0.0, 0.0);
  } else {
    let (mw, r, rsq) <- transform_reduce(w, (-inf, 0.0, 0.0),
        \(x:(Real, Real, Real), y:(Real, Real, Real)) -> {
          let (xa, xb, xc) <- x;
          let (ya, yb, yc) <- y;
          // `xa > ya` is false when `ya` is NaN; without the explicit check
          // NaN would become the running maximum and the sums accumulated in
          // `xb` and `xc` would be dropped by the next combination
          if xa > ya || isnan(ya) {
            // rescale the right-hand sums to the left-hand maximum
            let v <- nan_exp(ya - xa);
            return (xa, xb + (yb + 1.0)*v, xc + (yc + 1.0)*v*v);
          } else {
            // rescale the left-hand sums to the right-hand maximum
            let v <- nan_exp(xa - ya);
            return (ya, yb + (xb + 1.0)*v, yc + (xc + 1.0)*v*v);
          }
        },
        \(x:Real) -> {
          // a single element is its own maximum, with nothing accumulated
          return (x, 0.0, 0.0);
        });
    let rp1 <- r + 1.0;  // restore the maximum's contribution to the sum
    let ess <- rp1*rp1/(rsq + 1.0);
    let log_sum_weights <- mw + log1p(r);
    return (ess, log_sum_weights);
  }
}
/*
* Print scalar.
*
* - x: Value to write to `stdout`.
*
* Overload used by `underflow_example` and `timings` when the function under
* test returns a single `Real`.
*/
function print_result(x:Real) {
stdout.print(x);
}
/*
 * Print tuple of scalars.
 *
 * - x: Pair of values to write to `stdout`, formatted as `(a, b)`.
 */
function print_result(x:(Real, Real)) {
  let (first, second) <- x;
  stdout.print("(");
  stdout.print(first);
  stdout.print(", ");
  stdout.print(second);
  stdout.print(")");
}
// Underflow example
/*
 * Apply `f` to the fixed input `[1e-20, log(1e-20)]` and print the result,
 * prefixed with a label describing the call.
 *
 * - f: Function taking a `Real[_]`; its return value must be accepted by one
 *   of the `print_result` overloads.
 */
function underflow_example<F>(f:F) {
  stdout.print("f([1e-20, log(1e-20)]) = ");
  input:Real[_] <- [1e-20, log(1e-20)];
  print_result(f(input));
}
/*
 * Run the underflow example on `log_sum_exp` or on its one-pass variant, and
 * print the reference value next to the result.
 *
 * - onepass: If true, use `log_sum_exp_onepass`; otherwise `log_sum_exp`.
 */
program log_sum_exp_underflow(onepass:Boolean) {
  if !onepass {
    stdout.print("f: log_sum_exp\n");
    underflow_example(log_sum_exp);
  } else {
    stdout.print("f: log_sum_exp_onepass\n");
    underflow_example(log_sum_exp_onepass);
  }
  stdout.print(" (correct: ~1.999999999999999999985e-20)\n");
}
/*
 * Run the underflow example on `resample_reduce` or on its one-pass variant,
 * and print the reference value for the second component next to the result.
 *
 * - onepass: If true, use `resample_reduce_onepass`; otherwise
 *   `resample_reduce`.
 */
program resample_reduce_underflow(onepass:Boolean) {
  if !onepass {
    stdout.print("f: resample_reduce\n");
    underflow_example(resample_reduce);
  } else {
    stdout.print("f: resample_reduce_onepass\n");
    underflow_example(resample_reduce_onepass);
  }
  stdout.print(" (correct: (_, ~1.999999999999999999985e-20))\n");
}
// Timings
/*
 * Time two implementations on the same input of 1000 draws from
 * `Gaussian(0.0, 1.0)`, printing the result and the elapsed time of each.
 *
 * - f: Current implementation, reported under the label `current`.
 * - g: One-pass implementation, reported under the label `onepass`. Its
 *   return value is assigned to the variable that held the result of `f`.
 */
function timings<F,G>(f:F, g:G) {
  draws:Real[1000];
  for i in 1..1000 {
    draws[i] <~ Gaussian(0.0, 1.0);
  }

  // current implementation
  tic();
  let res <- f(draws);
  let duration <- toc();
  stdout.print("current: ");
  print_result(res);
  stdout.print(" (result), ");
  stdout.print(duration);
  stdout.print(" (time)\n");

  // one-pass implementation, on the same draws
  tic();
  res <- g(draws);
  duration <- toc();
  stdout.print("onepass: ");
  print_result(res);
  stdout.print(" (result), ");
  stdout.print(duration);
  stdout.print(" (time)\n");
}
/*
* Compare `log_sum_exp` (reported as "current") with `log_sum_exp_onepass`
* (reported as "onepass") using `timings`.
*/
program log_sum_exp_timings() {
timings(log_sum_exp, log_sum_exp_onepass);
}
/*
* Compare `resample_reduce` (reported as "current") with
* `resample_reduce_onepass` (reported as "onepass") using `timings`.
*/
program resample_reduce_timings() {
timings(resample_reduce, resample_reduce_onepass);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment