-
-
Save quattro/c0cc02e1e86c6fe8a75e6428dd2a2e5d to your computer and use it in GitHub Desktop.
TraceEnum_ELBO for SuSiE
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
import argparse as ap | |
import logging | |
import sys | |
import jax | |
import jax.numpy as jnp | |
from jax import jit, nn, random | |
from jax.example_libraries.optimizers import exponential_decay | |
from jax.random import PRNGKey | |
import numpyro | |
import numpyro.distributions as dist | |
from numpyro import optim | |
from numpyro.distributions import constraints | |
from numpyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO | |
from numpyro.ops.indexing import Vindex | |
def simulation(rng_key, n_dim, p_dim, l_dim=2, h2g=0.1):
    """Simulate a sparse linear model with ``l_dim`` distinct causal variables.

    Args:
        rng_key: JAX PRNG key.
        n_dim: number of samples (rows of ``X``).
        p_dim: number of variables (columns of ``X``).
        l_dim: number of causal variables; must not exceed ``p_dim``.
        h2g: proportion of the variance of ``y`` explained by the causal
            effects; must lie in (0, 1].

    Returns:
        ``(y, X, S, b)``:
        - ``y``: centred response, shape (n_dim,);
        - ``X``: column-centred Gaussian design, shape (n_dim, p_dim);
        - ``S``: one-hot causal indicators, shape (p_dim, l_dim);
        - ``b``: causal effect sizes, shape (l_dim,).

    Raises:
        ValueError: if ``h2g`` is outside (0, 1] or ``l_dim > p_dim``.
    """
    if not 0.0 < h2g <= 1.0:
        raise ValueError(f"h2g must be in (0, 1], got {h2g}")
    if l_dim > p_dim:
        raise ValueError(f"l_dim ({l_dim}) cannot exceed p_dim ({p_dim})")

    # Split into five keys (the first is unused) so the X / b / noise streams
    # match the original key layout.
    _, x_key, b_key, s_key, obs_key = random.split(rng_key, 5)

    X = random.normal(x_key, shape=(n_dim, p_dim))
    X = X - jnp.mean(X, axis=0)

    b = jnp.sqrt(h2g / l_dim) * random.normal(b_key, shape=(l_dim,))

    # Causal variables must be distinct; with replacement, two effects could
    # collapse onto the same column.
    causal_idx = random.choice(s_key, p_dim, shape=(l_dim,), replace=False)
    S = nn.one_hot(causal_idx, p_dim).T

    m = X @ (S @ b)
    s2g = jnp.var(m)
    # Choose the noise variance so that var(signal) / var(y) is about h2g.
    s2e = ((1 / h2g) - 1) * s2g
    y = m + jnp.sqrt(s2e) * random.normal(obs_key, shape=(n_dim,))
    y = y - jnp.mean(y)
    return y, X, S, b
def get_logger(name, path=None):
    """Return a named logger that writes formatted records to the console and,
    optionally, to a file.

    Handlers are attached only the first time a given ``name`` is configured,
    so repeated calls do not duplicate output.

    Args:
        name: logger name passed to ``logging.getLogger``.
        path: optional path prefix. When given, records are also written to
            ``"<path>.log"``, which is truncated on first configuration.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    if logger.handlers:
        # Already configured by an earlier call.
        return logger

    # Prevent logging from propagating to the root logger.
    logger.propagate = False

    formatter = logging.Formatter(
        fmt="[%(asctime)s - %(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console = logging.StreamHandler()
    console.setFormatter(formatter)
    logger.addHandler(console)

    if path is not None:
        # FileHandler owns its stream and closes it (e.g. on logging.shutdown()),
        # unlike a bare open() handed to StreamHandler, which was never closed.
        disk_handler = logging.FileHandler("{}.log".format(path), mode="w")
        disk_handler.setFormatter(formatter)
        logger.addHandler(disk_handler)

    return logger
def ser(X, y, s2, s20):
    """Fit a single-effect regression (SER) with a uniform prior over variables.

    For each column j, the marginal least-squares estimate ``b_hat_j`` has
    sampling variance ``s2hat_j = s2 / ||X[:, j]||^2``. The effect prior is
    ``N(0, s20)``.

    Args:
        X: design matrix (rows are samples, columns are variables).
        y: response / residual vector.
        s2: residual variance.
        s20: prior variance of the effect size.

    Returns:
        ``(alpha, post_mu, post_var)``:
        - ``alpha``: posterior probability that each variable carries the
          effect (softmax of the log Bayes factors);
        - ``post_mu``, ``post_var``: posterior mean and variance of the effect,
          conditional on it belonging to that variable.
    """
    XtX_diag = jnp.sum(X**2.0, axis=0)
    b_hat = (X.T @ y) / XtX_diag
    s2hat = s2 / XtX_diag

    post_var = 1 / ((1 / s2hat) + (1 / s20))
    post_mu = post_var / s2hat * b_hat

    C = s2hat / (s2hat + s20)
    z = b_hat / jnp.sqrt(s2hat)
    # log BF = z**2*(1 - s/(s + s0))/2 + log(s)/2 - log(s + s0)/2.
    # The factor 1/2 matters: softmax is not scale-invariant, so dropping it
    # squares the Bayes factors and over-sharpens alpha.
    log_bf = 0.5 * (z**2 * (1 - C) + jnp.log(s2hat) - jnp.log(s2hat + s20))
    alpha = nn.softmax(log_bf)
    return alpha, post_mu, post_var
def ess(X, y, b_l, b_l_2):
    """Expected residual sum of squares ``E||y - X sum_l b_l||^2`` under the
    posterior.

    Args:
        X: design matrix (rows are samples, columns are variables).
        y: response vector.
        b_l: posterior first moments of each effect vector; one column per
            effect (as built in ``susie``, shape (p_dim, l_dim)).
        b_l_2: posterior second moments of each effect vector, same layout
            as ``b_l``.

    Returns:
        Scalar expected RSS: the squared residual at the posterior mean plus
        the posterior variance of every effect's fitted values.
    """
    mu_l = X @ b_l
    # E[(x_i . b_l)^2] = sum_j x_ij^2 E[b_lj^2]. This assumes each effect vector
    # has a single non-zero entry (the single-effect structure fitted by ser).
    mu_l_2 = (X**2) @ b_l_2
    # Per-sample, per-effect variance: second moment minus squared mean.
    var_l = mu_l_2 - mu_l**2
    return jnp.sum((y - jnp.sum(mu_l, axis=-1)) ** 2) + jnp.sum(var_l)
def susie(X, y, l_dim=1, max_iter=100, tol=1e-3):
    """Fit SuSiE with a simple iterative Bayesian stepwise selection (IBSS) loop.

    Each sweep refits every single-effect component against the residual left
    by all the other components, then re-estimates the residual variance from
    the expected residual sum of squares.

    Args:
        X: design matrix (rows are samples, columns are variables).
        y: response vector.
        l_dim: number of single-effect components L.
        max_iter: maximum number of sweeps.
        tol: stop once the residual-variance estimate changes by less than this.

    Returns:
        ``(alpha_l, b_l, s2e, s20)``:
        - ``alpha_l``: inclusion probabilities, shape (p_dim, l_dim);
        - ``b_l``: posterior mean effects, shape (p_dim, l_dim);
        - ``s2e``: final residual variance estimate;
        - ``s20``: prior effect variances, shape (l_dim,), held fixed here.
    """
    import numpy as np

    n_samples, n_vars = jnp.shape(X)
    prior_var = 1e-2 * np.ones(l_dim)
    resid_var = 1.0

    grid = (n_vars, l_dim)
    incl_prob = np.zeros(grid)
    first_moment = np.zeros(grid)
    second_moment = np.zeros(grid)

    for _sweep in range(max_iter):
        resid = y - X @ np.sum(first_moment, axis=-1)
        for k in range(l_dim):
            # Add component k's fit back so it is refit against all the others.
            partial = resid + X @ first_moment[:, k]
            probs, mean_k, var_k = ser(X, partial, resid_var, prior_var[k])
            incl_prob[:, k] = probs
            first_moment[:, k] = probs * mean_k
            second_moment[:, k] = probs * (var_k + mean_k**2)
            resid = partial - X @ first_moment[:, k]

        previous = resid_var
        resid_var = ess(X, y, first_moment, second_moment) / n_samples
        print(f"s2e = {resid_var}")
        if jnp.abs(resid_var - previous) < tol:
            break

    return incl_prob, first_moment, resid_var, prior_var
def model(X, y, l_dim=1):
    """Single-effect regression model: one column of X carries a Gaussian effect.

    Args:
        X: design matrix; ``X.shape`` is unpacked as ``(n_dim, p_dim)``.
        y: observed response, one value per row of ``X``.
        l_dim: accepted for interface compatibility but not used; only one
            effect is modelled. NOTE(review): confirm whether L > 1 was intended.
    """
    n_obs, n_vars = X.shape

    # Uniform prior over which variable carries the effect.
    prior_probs = jnp.ones(n_vars) / float(n_vars)
    # Fixed (not learned) prior scale of the effect size for every variable.
    prior_scale = jnp.ones(n_vars) * 1e-3

    # The guide marks "gamma" for parallel enumeration, so `which` may be a
    # batched tensor of indices rather than a single integer at inference time.
    which = numpyro.sample("gamma", dist.Categorical(prior_probs))
    effect = numpyro.sample("b", dist.Normal(0.0, prior_scale[..., which]))

    noise_scale = numpyro.param("sigma_e", 0.9, constraint=constraints.positive)

    with numpyro.plate("N", n_obs):
        fitted = Vindex(X)[:, which] * effect
        numpyro.sample("y", dist.Normal(fitted, noise_scale), obs=y)
def guide(X, y, l_dim=1):
    """Variational family for the single-effect model.

    ``gamma`` is enumerated in parallel over all variables with learned
    probabilities ``alpha``. Given ``gamma``, the effect ``b`` is Gaussian
    with a per-variable learned mean and scale.

    Args:
        X: design matrix; only ``X.shape[1]`` (the number of variables) is used.
        y: observed response (unused here; present to match ``model``).
        l_dim: unused; present to match ``model``.
    """
    _, p_dim = X.shape
    # NumPyro's keyword is `constraint` (singular). The plural form is silently
    # ignored, which would leave alpha off the simplex and the scale unbounded.
    alpha = numpyro.param(
        "alpha", jnp.ones(p_dim) / float(p_dim), constraint=constraints.simplex
    )
    gamma = numpyro.sample(
        "gamma", dist.Categorical(alpha), infer={"enumerate": "parallel"}
    )
    post_sigma_b = numpyro.param(
        "post_sigma_b", jnp.ones(p_dim) * 1e-3, constraint=constraints.positive
    )
    post_mu_b = numpyro.param("post_mu_b", jnp.zeros(p_dim))
    # Vindex keeps the index broadcasting correct when gamma is enumerated.
    numpyro.sample(
        "b",
        dist.Normal(Vindex(post_mu_b)[..., gamma], Vindex(post_sigma_b)[..., gamma]),
    )
def _get_credset(post_p, rho=0.9):
    """Build a level-``rho`` credible set for each effect.

    For each effect, variables are taken in decreasing order of posterior
    probability until their cumulative mass reaches ``rho``.

    Args:
        post_p: per-effect probabilities, shape (l_dim, p_dim); a 1-D vector is
            treated as a single effect.
        rho: target cumulative probability.

    Returns:
        A list with one list of integer variable indices per effect.
    """
    post_p = jnp.atleast_2d(post_p)
    # Sorting the negated values gives a decreasing-probability order per row.
    order = jnp.argsort(-post_p, axis=-1)
    cred_sets = []
    for probs, ranked in zip(post_p, order):
        members = []
        mass = 0.0
        for idx in ranked.tolist():
            if mass >= rho:
                break
            members.append(idx)
            mass += float(probs[idx])
        cred_sets.append(members)
    return cred_sets


def main(args):
    """Simulate SuSiE data, fit it with NumPyro SVI and with IBSS, and log results.

    Args:
        args: command-line arguments, excluding the program name.

    Returns:
        0 on success; used as the process exit status.
    """
    argp = ap.ArgumentParser(description="")
    argp.add_argument("-e", "--epochs", type=int, default=10)
    argp.add_argument("--l-dim", default=3, type=int, help="size of latent dim")
    argp.add_argument("-s", "--seed", type=int, default=0)
    argp.add_argument("-d", "--debug", action="store_true", default=False)
    argp.add_argument("-r", "--render", action="store_true", default=False)
    argp.add_argument("-v", "--verbose", action="store_true", default=False)
    # NOTE(review): --device is parsed but not applied anywhere below.
    argp.add_argument("--device", choices=["cpu", "gpu"], default="gpu")
    argp.add_argument("-l", "--learning-rate", default=1e-1, type=float)
    args = argp.parse_args(args)

    # set up logging
    log = get_logger(__name__)
    log.setLevel(logging.DEBUG if args.verbose else logging.INFO)

    # set up debugging info if needed
    if args.debug:
        numpyro.enable_validation()

    # ensure 64bit precision
    numpyro.enable_x64()

    # init key
    rng_key = PRNGKey(args.seed)
    rng_key, rng_key_init, rng_key_run = random.split(rng_key, 3)

    # load data
    log.info("Loading simulation data.")
    n_dim = 400
    p_dim = 20
    # Number of causal effects in the simulated data; also the L given to IBSS.
    l_dim = 1
    y, X, S, beta = simulation(rng_key_init, n_dim, p_dim, l_dim)

    # initialize model and state; the step size comes from --learning-rate
    log.info("Constructing SVI model.")
    adam = optim.Adam(step_size=args.learning_rate)
    svi = SVI(model, guide, adam, TraceEnum_ELBO(max_plate_nesting=10))

    # construct illustration to check dependencies in implementation
    # (same model arguments as the trace and SVI calls below)
    if args.render:
        numpyro.render_model(
            model,
            model_args=(X, y, args.l_dim),
            filename="susie.png",
            render_distributions=True,
        )

    # output shapes to check batch/event for variables
    rng_key, rng_key_trace = random.split(rng_key, 2)
    trace = numpyro.handlers.trace(
        numpyro.handlers.seed(model, rng_key_trace)
    ).get_trace(X=X, y=y, l_dim=args.l_dim)
    shapes = numpyro.util.format_shapes(trace, compute_log_prob=True)
    print(shapes)

    # run inference
    results = svi.run(
        rng_key_run,
        args.epochs,
        X=X,
        y=y,
        l_dim=args.l_dim,
        progress_bar=True,
        stable_update=True,
    )

    # The guide's "alpha" is a single (p_dim,) vector; view it as (1, p_dim)
    # so the per-effect code below (PIPs, credible sets) handles it uniformly.
    post_p = jnp.atleast_2d(results.params["alpha"])
    # convert to overall posterior inclusion probabilities
    PIP_np = 1.0 - jnp.prod(1 - post_p, axis=0)

    # perform inference using crude IBSS implementation
    alpha_l, b_l, s2e, s20 = susie(X, y, l_dim)
    # convert to overall posterior inclusion probabilities
    PIP = 1.0 - jnp.prod(1 - alpha_l.T, axis=0)

    log.info(f"Numpyro PIPs = {PIP_np}")
    log.info(f"IBSS PIPs = {PIP}")

    # get CS for true values, numpyro, and ibss
    cs_true = _get_credset(S.T)
    cs_np = _get_credset(post_p)
    cs_ibss = _get_credset(alpha_l.T)
    log.info(f"True vars = {cs_true}")
    log.info(f"Numpyro 90% credible set = {cs_np}")
    log.info(f"IBSS 90% credible set = {cs_ibss}")
    log.info("Finished Inference")
    return 0
if __name__ == "__main__":
    # Hand the CLI arguments (minus the program name) to main(); its return
    # value becomes the process exit status.
    raise SystemExit(main(sys.argv[1:]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment