-
-
Save quattro/cb0e7ecd10bbd1c41119b5fec311fd11 to your computer and use it in GitHub Desktop.
Comparison of a NumPyro SVI implementation of SuSiE against a simple IBSS implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
import argparse as ap | |
import logging | |
import sys | |
import jax | |
import jax.numpy as jnp | |
import numpyro | |
import numpyro.distributions as dist | |
from jax import jit, random, nn | |
from jax.experimental.optimizers import exponential_decay | |
from jax.random import PRNGKey | |
from numpyro import optim | |
from numpyro.distributions import constraints | |
from numpyro.infer import TraceGraph_ELBO, Trace_ELBO, SVI | |
def softplus_inv(x):
    """Return the inverse of softplus, so ``softplus_inv(softplus(v)) == v``.

    Mathematically this is::

        softplus_inv(x) = log(exp(x) - 1)

    evaluated piecewise so that it stays finite for very small and very
    large inputs. Ported from TensorFlow.

    Args:
        x: floating-point array; ``x.dtype`` is used to look up the machine
            epsilon that sets the cut-offs.

    Returns:
        Array of the same shape as ``x``.
    """
    # Cut-offs are derived from the machine epsilon of the input dtype.
    cutoff = jnp.log(jnp.finfo(x.dtype).eps) + 2.0
    below = x < jnp.exp(cutoff)
    above = x > -cutoff

    # Values substituted in the two extreme regimes.
    log_branch = jnp.log(x)
    identity_branch = x

    # Swap the extreme entries for a harmless 1 so the generic formula below
    # is never evaluated on them.
    safe = jnp.where(below | above, jnp.ones([], x.dtype), x)
    generic = safe + jnp.log(-jnp.expm1(-safe))  # == log(expm1(safe))

    upper_or_generic = jnp.where(above, identity_branch, generic)
    return jnp.where(below, log_branch, upper_or_generic)
def simulation(rng_key, n_dim, p_dim, l_dim=2, h2g=0.1):
    """Simulate a centered response driven by a few columns of a random design.

    Args:
        rng_key: JAX PRNG key; it is split into five sub-keys, the first of
            which is not used.
        n_dim: number of rows (observations) to simulate.
        p_dim: number of columns (variables) in the design matrix.
        l_dim: number of effects drawn. Their column indices come from
            ``random.choice`` with its default arguments.
        h2g: target ratio ``var(X @ w) / (var(X @ w) + noise variance)``;
            the noise variance is derived from it.

    Returns:
        Tuple ``(y, X, S, b)``: the centered response of shape ``(n_dim,)``,
        the column-centered design of shape ``(n_dim, p_dim)``, the one-hot
        selection matrix of shape ``(p_dim, l_dim)`` and the effect sizes of
        shape ``(l_dim,)``.
    """
    # Keep the five-way split so each stream matches its key position.
    _, x_key, b_key, s_key, obs_key = random.split(rng_key, 5)

    # Column-centered standard-normal design matrix.
    design = random.normal(x_key, shape=(n_dim, p_dim))
    X = design - jnp.mean(design, axis=0)

    # Effect sizes scaled so their total variance is h2g.
    effect_scale = jnp.sqrt(h2g / l_dim)
    b = effect_scale * random.normal(b_key, shape=(l_dim,))

    # One one-hot column per effect marking the variable it acts through.
    chosen = random.choice(s_key, p_dim, shape=(l_dim,))
    S = nn.one_hot(chosen, p_dim).T

    signal = X @ (S @ b)

    # Noise variance chosen so the signal explains a fraction h2g of y.
    noise_var = ((1 / h2g) - 1) * jnp.var(signal)
    noise = random.normal(obs_key, shape=(n_dim,))
    y = signal + jnp.sqrt(noise_var) * noise

    return y - jnp.mean(y), X, S, b
def get_logger(name, path=None):
    """Return the logger called ``name``, configuring it on first use.

    On the first call for a given name the logger gets a console handler
    and, when ``path`` is given, a file handler writing to ``<path>.log``.
    Later calls see the existing handlers and return the logger unchanged,
    so handlers are never duplicated.

    Args:
        name: logger name passed to ``logging.getLogger``.
        path: optional file prefix; ``.log`` is appended and the file is
            truncated when it is opened.

    Returns:
        The ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    if logger.handlers:
        # Already configured by an earlier call.
        return logger

    # Prevent logging from propagating to the root logger.
    logger.propagate = False

    formatter = logging.Formatter(
        fmt="[%(asctime)s - %(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console = logging.StreamHandler()
    console.setFormatter(formatter)
    logger.addHandler(console)

    if path is not None:
        # FileHandler owns the file and closes it when the handler is closed
        # (e.g. by logging.shutdown); a StreamHandler wrapped around a bare
        # open() would leave the file object unclosed.
        disk_handler = logging.FileHandler("{}.log".format(path), mode="w")
        disk_handler.setFormatter(formatter)
        logger.addHandler(disk_handler)

    return logger
def ser(X, y, s2, s20):
    """Fit a single-effect regression and return its posterior summaries.

    Each column of ``X`` is regressed on ``y`` separately under a normal
    prior with variance ``s20`` on the effect. The per-column log Bayes
    factors are turned into selection probabilities with a softmax, i.e.
    every column gets the same prior weight.

    Args:
        X: design matrix; columns are reduced over axis 0, so rows index
            observations.
        y: response vector with one entry per row of ``X``.
        s2: residual variance.
        s20: prior variance of the single effect.

    Returns:
        Tuple ``(alpha, post_mu, post_var)`` with one entry per column of
        ``X``: the selection probabilities (summing to 1), and the posterior
        mean and variance of the effect given that the column is selected.
    """
    XtX_diag = jnp.sum(X ** 2.0, axis=0)

    # Per-column least-squares estimate and its sampling variance.
    b_hat = (X.T @ y) / XtX_diag
    s2hat = s2 / XtX_diag

    # Conjugate normal update of the effect for each column.
    post_var = 1 / ((1 / s2hat) + (1 / s20))
    post_mu = post_var / s2hat * b_hat

    shrink = s2hat / (s2hat + s20)
    z = b_hat / jnp.sqrt(s2hat)

    # log BF = z**2 * (1 - s/(s + s0)) / 2 + log(s)/2 - log(s + s0)/2.
    # The 1/2 factor is required: without it the softmax runs over twice
    # the log Bayes factor and alpha is over-concentrated.
    log_bf = 0.5 * (z ** 2 * (1 - shrink) + jnp.log(s2hat) - jnp.log(s2hat + s20))

    alpha = nn.softmax(log_bf)
    return alpha, post_mu, post_var
def ess(X, y, b_l, b_l_2):
    """Return the expected residual sum of squares under the fitted effects.

    Args:
        X: design matrix, rows are observations.
        y: response vector with one entry per row of ``X``.
        b_l: first moments of the effects, one column per effect.
        b_l_2: second moments of the effects, same layout as ``b_l``.

    Returns:
        Scalar equal to the squared residual around the summed mean
        prediction plus, for every effect, the variance of its prediction.
    """
    # Mean prediction of every effect, one column per effect.
    mu_l = X @ b_l
    # Second moment of every effect's prediction.
    mu_l_2 = (X ** 2) @ b_l_2
    # Variance = second moment - squared mean. Subtracting mu_l_2 from its
    # own column mean always sums to zero and would drop this term.
    var_l = mu_l_2 - mu_l ** 2
    residual = y - jnp.sum(mu_l, axis=-1)
    return jnp.sum(residual ** 2) + jnp.sum(var_l)
def susie(X, y, l_dim=1, max_iter=100, tol=1e-3):
    """Fit ``l_dim`` single effects by iterating ``ser`` on partial residuals.

    Each sweep refits every effect against the residual that excludes its
    own current contribution, then re-estimates the residual variance with
    ``ess``. Iteration stops once that variance changes by less than
    ``tol`` or after ``max_iter`` sweeps. The prior variances are held at
    their initial value of ``1e-2``.

    Args:
        X: design matrix, rows are observations.
        y: response vector with one entry per row of ``X``.
        l_dim: number of single effects to fit.
        max_iter: maximum number of sweeps over the effects.
        tol: convergence threshold on the change in residual variance.

    Returns:
        Tuple ``(alpha_l, b_l, s2e, s20)``: per-effect selection
        probabilities and posterior mean effects (both of shape
        ``(p_dim, l_dim)``), the final residual variance, and the prior
        variances.
    """
    import numpy as np

    n_dim, p_dim = jnp.shape(X)
    moment_shape = (p_dim, l_dim)

    s2e = 1.0
    prior_var = 1e-2 * np.ones(l_dim)

    # Mutable NumPy buffers, updated one effect (column) at a time.
    first_moment = np.zeros(moment_shape)
    second_moment = np.zeros(moment_shape)
    select_prob = np.zeros(moment_shape)

    for _ in range(max_iter):
        resid = y - X @ np.sum(first_moment, axis=-1)

        for l in range(l_dim):
            # Put this effect's current fit back before refitting it.
            partial = resid + X @ first_moment[:, l]
            a_new, mu_new, var_new = ser(X, partial, s2e, prior_var[l])

            select_prob[:, l] = a_new
            first_moment[:, l] = a_new * mu_new
            second_moment[:, l] = a_new * (var_new + mu_new ** 2)

            resid = partial - X @ first_moment[:, l]

        previous_s2e = s2e
        s2e = ess(X, y, first_moment, second_moment) / n_dim
        print(f"s2e = {s2e}")

        if jnp.abs(s2e - previous_s2e) < tol:
            break

    return select_prob, first_moment, s2e, prior_var
# define the model | |
def model(X, y, l_dim) -> None:
    """NumPyro model: a sum of ``l_dim`` single effects with normal noise.

    Sample sites: ``gamma`` and ``beta_l`` (inside plate ``susie_plate``),
    the deterministic ``beta``, and the observed site ``obs`` (inside plate
    ``N``). The noise scale is the positive parameter ``std``.

    Args:
        X: design matrix of shape ``(n_dim, p_dim)``.
        y: observed response passed as ``obs``.
        l_dim: number of single effects.
    """
    n_dim, p_dim = jnp.shape(X)

    # Fixed prior probability for now: equal logits for every variable.
    prior_logits = jnp.ones(p_dim)
    # Round trip through softplus_inv/softplus gives the prior effect scale.
    prior_scale = nn.softplus(softplus_inv(jnp.ones(l_dim)))

    with numpyro.plate("susie_plate", l_dim):
        # Relies on Multinomial's default total_count; presumably each draw
        # is a one-hot selection vector. TODO confirm.
        gamma = numpyro.sample("gamma", dist.Multinomial(logits=prior_logits))
        beta_l = numpyro.sample("beta_l", dist.Normal(0.0, prior_scale))

    # Compose the selections with the sampled effects.
    beta = numpyro.deterministic("beta", beta_l @ gamma)

    std = numpyro.param("std", 0.5, constraint=constraints.positive)
    with numpyro.plate("N", n_dim):
        numpyro.sample("obs", dist.Normal(X @ beta, std), obs=y)
def guide(X, y, l_dim) -> None:
    """Variational guide providing the ``gamma`` and ``beta_l`` sample sites.

    Parameters registered: ``gamma_post_logit`` of shape ``(l_dim, p_dim)``,
    and ``beta_loc`` / ``beta_log_std`` of shape ``(p_dim, l_dim)``. The
    deterministic sites ``post_mu`` and ``post_std`` record the location and
    scale actually used for ``beta_l``.

    Args:
        X: design matrix; only its shape is used.
        y: unused, kept so the signature matches the model's.
        l_dim: number of single effects.
    """
    _, p_dim = jnp.shape(X)
    effect_shape = (p_dim, l_dim)

    # Posterior for gamma.
    gamma_logits = numpyro.param("gamma_post_logit", jnp.ones((l_dim, p_dim)))

    # Posterior for the betas; the scale is stored on the inverse-softplus
    # scale and mapped back with softplus where it is used.
    b_loc = numpyro.param("beta_loc", jnp.zeros(effect_shape))
    b_log_std = numpyro.param("beta_log_std", softplus_inv(jnp.ones(effect_shape)))

    with numpyro.plate("susie_plate", l_dim):
        gamma = numpyro.sample("gamma", dist.Multinomial(logits=gamma_logits))

        # Weight the per-variable estimates by the sampled selection.
        selection = gamma.T
        post_mu = numpyro.deterministic(
            "post_mu", jnp.sum(selection * b_loc, axis=0)
        )
        post_std = numpyro.deterministic(
            "post_std", jnp.sum(selection * nn.softplus(b_log_std), axis=0)
        )

        numpyro.sample("beta_l", dist.Normal(post_mu, post_std))
def main(args):
    """Simulate data, fit the SVI model and IBSS, and log credible sets.

    Args:
        args: list of command-line argument strings (without the program
            name).

    Returns:
        ``0`` once inference has finished.
    """
    argp = ap.ArgumentParser(description="")
    argp.add_argument("-e", "--epochs", type=int, default=6000)
    argp.add_argument("--l-dim", default=3, type=int, help="size of latent dim")
    argp.add_argument("-s", "--seed", type=int, default=0)
    argp.add_argument("-d", "--debug", action="store_true")
    argp.add_argument("-v", "--verbose", action="store_true")
    # NOTE(review): --device is parsed but never read below.
    argp.add_argument("--device", choices=["cpu", "gpu"], default="gpu")
    argp.add_argument("-l", "--learning-rate", default=1e-1, type=float)
    args = argp.parse_args(args)

    # Set up logging.
    log = get_logger(__name__)
    log.setLevel(logging.DEBUG if args.verbose else logging.INFO)

    # Set up debugging info if needed.
    if args.debug:
        numpyro.enable_validation()

    # Ensure 64bit precision.
    numpyro.enable_x64()

    # Derive independent keys for simulation and for the SVI run.
    rng_key, rng_key_init, rng_key_run = random.split(PRNGKey(args.seed), 3)

    # The simulated problem size is fixed here; the command line only
    # controls the number of effects given to the SVI model.
    log.info("Loading simulation data.")
    n_dim, p_dim, l_dim = 400, 10, 1
    y, X, S, beta = simulation(rng_key_init, n_dim, p_dim, l_dim)

    # Initialize model and state.
    log.info("Constructing SVI model.")
    scheduler = exponential_decay(args.learning_rate, 5000, 0.9)
    svi = SVI(model, guide, optim.Adam(scheduler), TraceGraph_ELBO())

    # Disabled: render an illustration to check dependencies in the
    # implementation.
    #
    # numpyro.render_model(
    #     model,
    #     model_args=(X, y, l_dim),
    #     filename="susie.png",
    #     render_distributions=True,
    #     num_tries=10,
    # )

    # Output shapes to check batch/event dimensions of the variables.
    rng_key, rng_key_trace = random.split(rng_key, 2)
    seeded_model = numpyro.handlers.seed(model, rng_key_trace)
    trace = numpyro.handlers.trace(seeded_model).get_trace(
        X=X, y=y, l_dim=args.l_dim
    )
    print(numpyro.util.format_shapes(trace, compute_log_prob=True))

    # Run inference.
    results = svi.run(
        rng_key_run,
        args.epochs,
        X=X,
        y=y,
        l_dim=args.l_dim,
        progress_bar=True,
        stable_update=False,
    )

    # Convert logits to posterior probabilities per effect (l_dim).
    post_p = nn.softmax(results.params["gamma_post_logit"])
    # Overall posterior inclusion probabilities.
    # NOTE(review): PIP_np and PIP are computed but not used afterwards.
    PIP_np = 1.0 - jnp.prod(1 - post_p, axis=0)

    # Perform inference using the crude IBSS implementation. It is given the
    # simulated l_dim, not args.l_dim.
    alpha_l, b_l, s2e, s20 = susie(X, y, l_dim)
    PIP = 1.0 - jnp.prod(1 - alpha_l.T, axis=0)

    def get_credset(post_p, rho=0.9):
        """Collect, per effect, the top indices until their mass reaches rho."""
        n_effects, n_vars = post_p.shape
        ranked = jnp.argsort(-post_p)  # negate for a decreasing sort
        cred_sets = []
        for ldx in range(n_effects):
            members = []
            mass = 0.0
            for pdx in range(n_vars):
                if mass >= rho:
                    break
                idx = ranked[ldx][pdx]
                members.append(idx)
                mass += post_p[ldx, idx]
            cred_sets.append(members)
        return cred_sets

    # Credible sets for the true values, numpyro, and IBSS.
    cs_true = get_credset(S.T)
    cs_np = get_credset(post_p)
    cs_ibss = get_credset(alpha_l.T)

    log.info(f"True vars = {cs_true}")
    log.info(f"Numpyro 90% credible set = {cs_np}")
    log.info(f"IBSS 90% credible set = {cs_ibss}")
    log.info("Finished Inference")

    return 0
if __name__ == "__main__":
    # Run the driver on the command-line arguments (program name dropped)
    # and use its return value as the process exit status.
    raise SystemExit(main(sys.argv[1:]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment