# numpyro susie comparison
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse as ap
import logging
import sys
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import jit, random, nn
from jax.experimental.optimizers import exponential_decay
from jax.random import PRNGKey
from numpyro import optim
from numpyro.distributions import constraints
from numpyro.infer import TraceGraph_ELBO, Trace_ELBO, SVI
def softplus_inv(x):
    """Compute the inverse softplus, i.e. ``x == softplus_inv(softplus(x))``.

    Mathematically this is ``log(exp(x) - 1)``. Ported from TensorFlow.

    The direct formula is only evaluated for inputs in a numerically safe
    range; outside of it the asymptotic forms are used instead:

    * very small ``x``  -> ``log(x)``
    * very large ``x``  -> ``x``

    Args:
        x: floating-point array (or Python float) of softplus outputs.

    Returns:
        Array with the same shape as ``x`` holding the pre-softplus values.
    """
    # Accept Python scalars as well as arrays; ``.dtype`` is needed below.
    x = jnp.asarray(x)

    # Cut-offs derived from the machine epsilon of the input dtype.
    threshold = jnp.log(jnp.finfo(x.dtype).eps) + 2.0
    is_too_small = x < jnp.exp(threshold)
    is_too_large = x > -threshold
    too_small_value = jnp.log(x)
    too_large_value = x

    # Replace out-of-range entries by 1 so the stable formula below never
    # sees them; their results are discarded by the final ``where``.
    x = jnp.where(is_too_small | is_too_large, jnp.ones([], x.dtype), x)
    y = x + jnp.log(-jnp.expm1(-x))  # == log(expm1(x))
    return jnp.where(
        is_too_small, too_small_value, jnp.where(is_too_large, too_large_value, y)
    )
def simulation(rng_key, n_dim, p_dim, l_dim=2, h2g=0.1):
    """Simulate a sparse linear model with ``l_dim`` causal variables.

    Args:
        rng_key: JAX PRNG key; it is split internally for each random draw.
        n_dim: number of observations (rows of ``X``).
        p_dim: number of variables (columns of ``X``).
        l_dim: number of causal variables; must not exceed ``p_dim``.
        h2g: fraction of the variance of ``y`` explained by ``X``; used to
            scale both the effects and the noise.

    Returns:
        Tuple ``(y, X, S, b)`` where ``y`` has shape ``(n_dim,)`` and is
        mean-centered, ``X`` has shape ``(n_dim, p_dim)`` with mean-centered
        columns, ``S`` has shape ``(p_dim, l_dim)`` with one indicator column
        per causal variable, and ``b`` has shape ``(l_dim,)``.

    Raises:
        ValueError: raised by ``jax.random.choice`` when ``l_dim > p_dim``.
    """
    rng_key, x_key, b_key, s_key, obs_key = random.split(rng_key, 5)

    X = random.normal(x_key, shape=(n_dim, p_dim))
    X = X - jnp.mean(X, axis=0)

    b = jnp.sqrt(h2g / l_dim) * random.normal(b_key, shape=(l_dim,))

    # Draw the causal indices WITHOUT replacement so that the l_dim effects
    # land on l_dim distinct variables (with replacement they can collide).
    S = random.choice(s_key, p_dim, shape=(l_dim,), replace=False)
    S = nn.one_hot(S, p_dim).T

    # Per-variable effect vector and the resulting noiseless signal.
    w = S @ b
    m = X @ w

    # Choose the noise variance so that var(m) / var(y) is about h2g.
    s2g = jnp.var(m)
    s2e = ((1 / h2g) - 1) * s2g
    y = m + jnp.sqrt(s2e) * random.normal(obs_key, shape=(n_dim,))
    y = y - jnp.mean(y)

    return y, X, S, b
def get_logger(name, path=None):
    """Return the logger called ``name``, configuring it on first use.

    Configuration only happens when the logger has no handlers yet, so
    repeated calls with the same name return the same logger without
    stacking duplicate handlers (and a ``path`` given on a later call is
    ignored).

    Args:
        name: logger name passed to ``logging.getLogger``.
        path: optional path prefix; when given, records are also written to
            ``"<path>.log"``, which is truncated on open.

    Returns:
        The ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    if not logger.handlers:
        # Prevent logging from propagating to the root logger
        logger.propagate = False

        log_format = "[%(asctime)s - %(levelname)s] %(message)s"
        date_format = "%Y-%m-%d %H:%M:%S"
        formatter = logging.Formatter(fmt=log_format, datefmt=date_format)

        console = logging.StreamHandler()
        console.setFormatter(formatter)
        logger.addHandler(console)

        if path is not None:
            # FileHandler owns the stream and closes it on logging shutdown,
            # unlike a bare ``open`` wrapped in a StreamHandler.
            disk_handler = logging.FileHandler("{}.log".format(path), mode="w")
            disk_handler.setFormatter(formatter)
            logger.addHandler(disk_handler)

    return logger
def ser(X, y, s2, s20):
    """Fit a single-effect regression of ``y`` on the columns of ``X``.

    Each column is treated as the only possible non-zero effect, with a
    zero-mean normal prior of variance ``s20`` on that effect and residual
    variance ``s2``.

    Args:
        X: design matrix; statistics are computed per column (``axis=0``).
        y: response vector, conformable with ``X.T @ y``.
        s2: residual variance.
        s20: prior variance of the single effect.

    Returns:
        Tuple ``(alpha, post_mu, post_var)`` with one entry per column of
        ``X``: the softmax-normalized inclusion weights, and the posterior
        mean and variance of the effect given that the column is selected.
    """
    XtX_diag = jnp.sum(X ** 2.0, axis=0)

    # Per-column least-squares estimate and its sampling variance.
    b_hat = (X.T @ y) / XtX_diag
    s2hat = s2 / XtX_diag

    # Conjugate normal update.
    post_var = 1 / ((1 / s2hat) + (1 / s20))
    post_mu = post_var / s2hat * b_hat

    shrink = s2hat / (s2hat + s20)
    z = b_hat / jnp.sqrt(s2hat)

    # log BF = z**2 * (1 - s/(s + s0)) / 2 + log(s) / 2 - log(s + s0) / 2
    # The overall factor 1/2 matters: dropping it makes alpha the softmax of
    # twice the log Bayes factor, i.e. over-confident.
    log_bf = 0.5 * (z ** 2 * (1 - shrink) + jnp.log(s2hat) - jnp.log(s2hat + s20))

    alpha = nn.softmax(log_bf)
    return alpha, post_mu, post_var
def ess(X, y, b_l, b_l_2):
    """Expected residual sum of squares given per-effect posterior moments.

    Computes ``sum((y - X @ sum_l b_l) ** 2)`` plus, for every effect, the
    spread term ``(X ** 2) @ b_l_2 - (X @ b_l) ** 2`` summed over all
    observations and effects.

    Args:
        X: design matrix.
        y: response vector.
        b_l: per-effect first moments; columns index effects (they are
            summed over the last axis to form the total fit).
        b_l_2: per-effect second moments, same layout as ``b_l``.

    Returns:
        Scalar expected residual sum of squares.
    """
    mu_l = X @ b_l
    mu_l_2 = (X ** 2) @ b_l_2

    # Second moment minus squared first moment. The previous expression
    # (column mean of mu_l_2 minus mu_l_2) sums to zero over observations,
    # so the spread term never contributed.
    var_l = mu_l_2 - mu_l ** 2

    return jnp.sum((y - jnp.sum(mu_l, axis=-1)) ** 2) + jnp.sum(var_l)
def susie(X, y, l_dim=1, max_iter=100, tol=1e-3):
    """Crude iterative fit of a sum of ``l_dim`` single-effect regressions.

    Each sweep refits every effect with ``ser`` against the residual that
    excludes that effect, then re-estimates the residual variance with
    ``ess``. The current residual variance is printed after every sweep.

    Args:
        X: design matrix of shape ``(n_dim, p_dim)``.
        y: response vector.
        l_dim: number of single effects.
        max_iter: maximum number of sweeps.
        tol: stop once the residual variance changes by less than this
            between consecutive sweeps.

    Returns:
        Tuple ``(alpha_l, b_l, s2e, s20)``: inclusion weights and posterior
        mean effects (both of shape ``(p_dim, l_dim)``), the final residual
        variance, and the per-effect prior variances (never updated here).
        The tuple is returned whether or not the tolerance was reached.
    """
    import numpy as np

    n_dim, p_dim = jnp.shape(X)

    s2e = 1.0
    s20 = 1e-2 * np.ones(l_dim)

    # Mutable numpy buffers: columns are overwritten in place each sweep.
    b_l = np.zeros((p_dim, l_dim))
    b_l_2 = np.zeros((p_dim, l_dim))
    alpha_l = np.zeros((p_dim, l_dim))

    for _ in range(max_iter):
        r = y - X @ np.sum(b_l, axis=-1)
        for ldx in range(l_dim):
            # Add this effect back in to get the residual it should explain.
            r_l = r + X @ b_l[:, ldx]
            a_tmp, mu_tmp, var_tmp = ser(X, r_l, s2e, s20[ldx])

            alpha_l[:, ldx] = a_tmp
            b_l[:, ldx] = a_tmp * mu_tmp
            b_l_2[:, ldx] = a_tmp * (var_tmp + mu_tmp ** 2)

            # Remove the refreshed effect again before the next one.
            r = r_l - X @ b_l[:, ldx]

        s2e_old = s2e
        s2e = ess(X, y, b_l, b_l_2) / n_dim
        print(f"s2e = {s2e}")

        if jnp.abs(s2e - s2e_old) < tol:
            break

    # Always return, so callers can unpack even when max_iter was exhausted.
    return alpha_l, b_l, s2e, s20
# define the model
def model(X, y, l_dim) -> None:
    """NumPyro model: ``y`` is normal around ``X @ beta`` with sparse ``beta``.

    Inside the ``"susie_plate"`` plate of size ``l_dim`` two sites are drawn:
    ``"gamma"`` from a Multinomial with equal logits over the columns of
    ``X``, and ``"beta_l"`` from a zero-mean Normal whose scale is
    ``softplus(softplus_inv(1))``. They are combined into the deterministic
    site ``"beta"`` as ``beta_l @ gamma``. The observation scale is the
    positive parameter ``"std"`` (initial value 0.5).

    Args:
        X: two-dimensional design matrix.
        y: observed responses, attached to the ``"obs"`` site.
        l_dim: size of the ``"susie_plate"`` plate.
    """
    num_obs, num_vars = jnp.shape(X)

    # fixed prior prob for now
    selection_prior = dist.Multinomial(logits=jnp.ones(num_vars))

    raw_prior_scale = softplus_inv(jnp.ones(l_dim))
    effect_prior = dist.Normal(0.0, nn.softplus(raw_prior_scale))

    with numpyro.plate("susie_plate", l_dim):
        gamma = numpyro.sample("gamma", selection_prior)
        beta_l = numpyro.sample("beta_l", effect_prior)

    # compose the categorical with sampled effects
    beta = numpyro.deterministic("beta", beta_l @ gamma)

    noise_std = numpyro.param("std", 0.5, constraint=constraints.positive)
    likelihood = dist.Normal(X @ beta, noise_std)

    with numpyro.plate("N", num_obs):
        numpyro.sample("obs", likelihood, obs=y)
def guide(X, y, l_dim) -> None:
    """Variational guide matching ``model``'s ``"gamma"`` and ``"beta_l"`` sites.

    Learnable parameters:

    * ``"gamma_post_logit"`` of shape ``(l_dim, p_dim)``, initialized to ones;
    * ``"beta_loc"`` of shape ``(p_dim, l_dim)``, initialized to zeros;
    * ``"beta_log_std"`` of shape ``(p_dim, l_dim)``, initialized to
      ``softplus_inv(1)`` and mapped through softplus before use.

    ``"gamma"`` is drawn from a Multinomial with the learned logits, then used
    to weight the per-variable locations and scales, which are summed over
    the variable axis to give the Normal from which ``"beta_l"`` is drawn.

    Args:
        X: two-dimensional design matrix; only its column count is used.
        y: unused; present so the signature matches ``model``.
        l_dim: size of the ``"susie_plate"`` plate.
    """
    _, num_vars = jnp.shape(X)

    # posterior for gamma
    selection_logits = numpyro.param(
        "gamma_post_logit", jnp.ones((l_dim, num_vars))
    )

    # posterior for betas
    effect_shape = (num_vars, l_dim)
    effect_loc = numpyro.param("beta_loc", jnp.zeros(effect_shape))
    effect_raw_std = numpyro.param(
        "beta_log_std", softplus_inv(jnp.ones(effect_shape))
    )

    with numpyro.plate("susie_plate", l_dim):
        gamma = numpyro.sample("gamma", dist.Multinomial(logits=selection_logits))

        # average across individual posterior estimates
        weights = gamma.T
        post_mu = numpyro.deterministic(
            "post_mu", jnp.sum(weights * effect_loc, axis=0)
        )
        post_std = numpyro.deterministic(
            "post_std", jnp.sum(weights * nn.softplus(effect_raw_std), axis=0)
        )

        numpyro.sample("beta_l", dist.Normal(post_mu, post_std))
def main(args):
    """Simulate data, fit it with NumPyro SVI and with ``susie``, and compare.

    Several statements of this function were lost from the source text
    (bodies of the ``verbose``/``debug`` branches, the precision/platform
    setup, the model rendering call, the ``svi.run`` call, the PIP
    expressions and the logging calls). They are reconstructed below and
    marked ``NOTE(review)`` wherever the exact original call is a guess.

    Args:
        args: list of command-line arguments (without the program name).

    Returns:
        ``0`` on success.
    """
    argp = ap.ArgumentParser(description="")
    argp.add_argument("-e", "--epochs", type=int, default=6000)
    argp.add_argument("--l-dim", default=3, type=int, help="size of latent dim")
    argp.add_argument("-s", "--seed", type=int, default=0)
    argp.add_argument("-d", "--debug", action="store_true", default=False)
    argp.add_argument("-v", "--verbose", action="store_true", default=False)
    argp.add_argument("--device", choices=["cpu", "gpu"], default="gpu")
    argp.add_argument("-l", "--learning-rate", default=1e-1, type=float)

    args = argp.parse_args(args)

    # set up logging
    log = get_logger(__name__)
    # NOTE(review): branch bodies reconstructed; confirm the intended levels.
    if args.verbose:
        log.setLevel(logging.DEBUG)
    else:
        log.setLevel(logging.INFO)

    # set up debugging info if needed
    if args.debug:
        # NOTE(review): reconstructed; the original debug setup is not visible.
        jax.config.update("jax_debug_nans", True)

    # ensure 64bit precision
    # NOTE(review): reconstructed calls; confirm against the original script.
    jax.config.update("jax_enable_x64", True)
    numpyro.set_platform(args.device)

    # init key
    rng_key = PRNGKey(args.seed)
    rng_key, rng_key_init, rng_key_run = random.split(rng_key, 3)

    # load data
    log.info("Loading simulation data.")
    n_dim = 400
    p_dim = 10
    # NOTE(review): --l-dim is parsed but the run uses this hard-coded value.
    l_dim = 1
    y, X, S, beta = simulation(rng_key_init, n_dim, p_dim, l_dim)

    # initialize model and state
    log.info("Constructing SVI model.")
    scheduler = exponential_decay(args.learning_rate, 5000, 0.9)
    adam = optim.Adam(scheduler)
    svi = SVI(model, guide, adam, TraceGraph_ELBO())

    # construct illustration to check dependencies in implementation
    # NOTE(review): only ``model_args=(X, y, l_dim),`` survived in the source;
    # any further keyword arguments (e.g. an output filename) are unknown.
    numpyro.render_model(model, model_args=(X, y, l_dim))

    # output shapes to check batch/event for variables
    rng_key, rng_key_trace = random.split(rng_key, 2)
    seeded_model = numpyro.handlers.seed(model, rng_key_trace)
    # Trace with the same l_dim that is fitted below (previously args.l_dim),
    # so the reported shapes describe the model that is actually run.
    trace = numpyro.handlers.trace(seeded_model).get_trace(X=X, y=y, l_dim=l_dim)
    shapes = numpyro.util.format_shapes(trace, compute_log_prob=True)
    log.info(shapes)

    # run inference
    # NOTE(review): right-hand side reconstructed from the later use of
    # ``results.params``.
    results = svi.run(rng_key_run, args.epochs, X=X, y=y, l_dim=l_dim)

    # convert logits to posterior probabilities per effect (l_dim)
    post_p = nn.softmax(results.params["gamma_post_logit"])

    # convert to overall posterior inclusion probabilities
    PIP_np = 1.0 - jnp.prod(1.0 - post_p, axis=0)

    # perform inference using crude IBSS implementation
    alpha_l, b_l, s2e, s20 = susie(X, y, l_dim)

    # convert to overall posterior inclusion probabilities
    PIP = 1.0 - jnp.prod(1.0 - alpha_l.T, axis=0)

    # helper function to pull credible sets from posteriors
    def get_credset(post_p, rho=0.9):
        """Return, per row of ``post_p``, the indices needed to reach ``rho``.

        Indices are taken in order of decreasing probability until their
        cumulative probability is at least ``rho``.
        """
        n_effects, n_vars = post_p.shape
        idxs = jnp.argsort(-post_p)  # flip for decreasing sort
        cs_s = []
        for ldx in range(n_effects):
            cs = []
            local = 0.0
            for pdx in range(n_vars):
                if local >= rho:
                    break
                idx = int(idxs[ldx][pdx])
                cs.append(idx)
                local += post_p[ldx, idx]
            cs_s.append(cs)
        return cs_s

    # get CS for true values, numpyro, and ibss
    cs_true = get_credset(S.T)
    cs_np = get_credset(post_p)
    cs_ibss = get_credset(alpha_l.T)

    log.info(f"True vars = {cs_true}")
    log.info(f"Numpyro 90% credible set = {cs_np}")
    log.info(f"IBSS 90% credible set = {cs_ibss}")
    log.info("Finished Inference")

    return 0
if __name__ == "__main__":
