Skip to content

Instantly share code, notes, and snippets.

@quattro
Created February 3, 2023 17:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save quattro/c0cc02e1e86c6fe8a75e6428dd2a2e5d to your computer and use it in GitHub Desktop.
Save quattro/c0cc02e1e86c6fe8a75e6428dd2a2e5d to your computer and use it in GitHub Desktop.
TraceEnum_ELBO for SuSiE
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse as ap
import logging
import sys
import jax
import jax.numpy as jnp
from jax import jit, nn, random
from jax.example_libraries.optimizers import exponential_decay
from jax.random import PRNGKey
import numpyro
import numpyro.distributions as dist
from numpyro import optim
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO
from numpyro.ops.indexing import Vindex
def simulation(rng_key, n_dim, p_dim, l_dim=2, h2g=0.1):
    """Simulate a sparse linear model with ``l_dim`` distinct causal columns.

    Args:
        rng_key: JAX PRNG key.
        n_dim: Number of samples (rows of ``X``).
        p_dim: Number of variants (columns of ``X``).
        l_dim: Number of causal effects. Causal columns are drawn without
            replacement, so this must not exceed ``p_dim``.
        h2g: Target fraction of the variance of ``y`` explained by ``X``. The
            noise variance is ``(1 / h2g - 1) * var(X @ w)``, so ``h2g`` must
            lie in (0, 1].

    Returns:
        y: Centered responses, shape ``(n_dim,)``.
        X: Column-centered design matrix, shape ``(n_dim, p_dim)``.
        S: One-hot selection matrix, shape ``(p_dim, l_dim)``; column ``l``
            marks the causal variant of effect ``l``.
        b: Effect sizes, shape ``(l_dim,)``.

    Raises:
        ValueError: If ``l_dim > p_dim``.
    """
    if l_dim > p_dim:
        raise ValueError(
            f"l_dim ({l_dim}) cannot exceed p_dim ({p_dim}): causal variants "
            "are sampled without replacement"
        )
    # Keep the original 5-way split so X, b and the noise draws are unchanged
    # for a given key.
    _, x_key, b_key, s_key, obs_key = random.split(rng_key, 5)
    X = random.normal(x_key, shape=(n_dim, p_dim))
    X = X - jnp.mean(X, axis=0)
    b = jnp.sqrt(h2g / l_dim) * random.normal(b_key, shape=(l_dim,))
    # Sample WITHOUT replacement: with replacement two effects could share a
    # variant, so S would no longer describe l_dim distinct causal variants.
    causal = random.choice(s_key, p_dim, shape=(l_dim,), replace=False)
    S = nn.one_hot(causal, p_dim).T
    w = S @ b
    m = X @ w
    s2g = jnp.var(m)
    s2e = ((1 / h2g) - 1) * s2g
    y = m + jnp.sqrt(s2e) * random.normal(obs_key, shape=(n_dim,))
    y = y - jnp.mean(y)
    return y, X, S, b
def get_logger(name, path=None):
    """Return a named logger that writes formatted records to the console.

    The console handler is attached only the first time a given logger is
    configured, so repeated calls do not duplicate console output. When
    ``path`` is given, records are also written to ``"{path}.log"``, which is
    truncated on open.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        path: Optional path prefix; ``".log"`` is appended to form the file name.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    log_format = "[%(asctime)s - %(levelname)s] %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    # Built unconditionally so the file handler can use it even when the
    # logger was already configured by an earlier call.
    formatter = logging.Formatter(fmt=log_format, datefmt=date_format)

    if not logger.handlers:
        # Prevent logging from propagating to the root logger
        logger.propagate = False
        console = logging.StreamHandler()
        console.setFormatter(formatter)
        logger.addHandler(console)

    if path is not None:
        # FileHandler owns (and closes) its stream and uses an explicit
        # encoding, unlike a bare open() handed to a StreamHandler.
        disk_handler = logging.FileHandler(f"{path}.log", mode="w", encoding="utf-8")
        disk_handler.setFormatter(formatter)
        logger.addHandler(disk_handler)

    return logger
def ser(X, y, s2, s20):
    """Fit a single-effect regression (SER) of ``y`` on each column of ``X``.

    For each column j, this computes:

    - the marginal least-squares estimate and its sampling variance;
    - the conjugate Normal posterior of the effect under an N(0, s20) prior;
    - the approximate log Bayes factor given in the comment below.

    ``alpha`` is the softmax of the log Bayes factors, i.e. the posterior over
    which column carries the effect under a uniform prior.

    Args:
        X: Design matrix; ``X.T @ y`` requires its rows to match ``y``.
        y: Response (or partial-residual) vector.
        s2: Residual variance.
        s20: Prior variance of the effect size.

    Returns:
        alpha: Posterior probability that each column is the selected one.
        post_mu: Posterior mean of the effect given column j is selected.
        post_var: Posterior variance of the effect given column j is selected.
    """
    XtX_diag = jnp.sum(X**2.0, axis=0)
    b_hat = (X.T @ y) / XtX_diag
    s2hat = s2 / XtX_diag
    post_var = 1 / ((1 / s2hat) + (1 / s20))
    post_mu = post_var / s2hat * b_hat
    C = s2hat / (s2hat + s20)
    z = b_hat / jnp.sqrt(s2hat)
    # log BF = z**2*(-s/(s + s0) + 1)/2 + log(s)/2 - log(s + s0)/2
    # The 1/2 matters: softmax is shift- but not scale-invariant, so dropping
    # it would make alpha proportional to BF**2 rather than BF.
    log_bf = 0.5 * (z**2 * (1 - C) + jnp.log(s2hat) - jnp.log(s2hat + s20))
    alpha = nn.softmax(log_bf)
    return alpha, post_mu, post_var
def ess(X, y, b_l, b_l_2):
    """Expected residual sum of squares E||y - X @ sum_l b_l||^2.

    Args:
        X: Design matrix, shape (n, p).
        y: Responses, shape (n,).
        b_l: Posterior means, shape (p, L); column l is effect l.
        b_l_2: Posterior second moments E[b_lj^2], shape (p, L).

    Returns:
        Scalar: squared residual of the posterior-mean fit plus the summed
        posterior variance of each effect's fitted values.
    """
    # Posterior-mean fitted values per effect, shape (n, L).
    mu_l = X @ b_l
    # sum_j X_ij^2 E[b_lj^2]; equals E[(X b_l)_i^2] when each effect has at
    # most one nonzero coordinate (single-effect posterior).
    mu_l_2 = (X**2) @ b_l_2
    # Per-sample posterior variance of each effect's fit: E[f^2] - E[f]^2.
    var_l = mu_l_2 - mu_l**2
    return jnp.sum((y - jnp.sum(mu_l, axis=-1)) ** 2) + jnp.sum(var_l)
def susie(X, y, l_dim=1, max_iter=100, tol=1e-3):
    """Crude IBSS (iterative Bayesian stepwise selection) fit of SuSiE.

    Each sweep refits every single effect against the partial residual that
    excludes it (via ``ser``). It then re-estimates the residual variance as
    ``ess(...) / n``. Sweeping stops after ``max_iter`` sweeps, or once the
    residual variance changes by less than ``tol``. The prior effect variance
    is held fixed at 1e-2 for every effect.

    Args:
        X: Design matrix; ``jnp.shape(X)`` is unpacked as ``(n, p)``.
        y: Response vector.
        l_dim: Number of single effects.
        max_iter: Maximum number of sweeps.
        tol: Convergence threshold on the change in residual variance.

    Returns:
        Tuple ``(alpha_l, b_l, s2e, s20)`` containing:

        - ``alpha_l``: per-effect selection probabilities, shape (p, l_dim);
        - ``b_l``: per-effect posterior means, shape (p, l_dim);
        - ``s2e``: the last residual-variance estimate;
        - ``s20``: the fixed prior variances, shape (l_dim,).
    """
    import numpy as np

    n_dim, p_dim = jnp.shape(X)
    dims = (p_dim, l_dim)
    s20 = np.full(l_dim, 1e-2)
    alpha_l, b_l, b_l_2 = np.zeros(dims), np.zeros(dims), np.zeros(dims)
    s2e = 1.0

    for _sweep in range(max_iter):
        resid = y - X @ b_l.sum(axis=-1)
        for eff in range(l_dim):
            # Add this effect back so it is refit against all the others.
            partial = resid + X @ b_l[:, eff]
            probs, mean, var = ser(X, partial, s2e, s20[eff])
            alpha_l[:, eff] = probs
            b_l[:, eff] = probs * mean
            b_l_2[:, eff] = probs * (var + mean**2)
            resid = partial - X @ b_l[:, eff]

        previous_s2e = s2e
        s2e = ess(X, y, b_l, b_l_2) / n_dim
        print(f"s2e = {s2e}")
        if jnp.abs(s2e - previous_s2e) < tol:
            break

    return alpha_l, b_l, s2e, s20
def model(X, y, l_dim=1):
    """Single-effect Bayesian linear model for ``y`` given design matrix ``X``.

    Generative process, as written below:

        gamma ~ Categorical(uniform over the p_dim columns of X)
        b     ~ Normal(0, sigma_b[gamma]), with sigma_b = 1e-3 for every column
        y_i   ~ Normal(X[i, gamma] * b, sigma_e), for i in the "N" plate

    ``sigma_e`` is a learnable point estimate (``numpyro.param``, constrained
    positive, initialized at 0.9), not a latent variable.

    Args:
        X: Design matrix; ``X.shape`` is unpacked as ``(n_dim, p_dim)``.
        y: Observed values for the ``"y"`` site. Presumably shape ``(n_dim,)``
            to match the ``"N"`` plate — TODO confirm.
        l_dim: Not referenced in this body; only one effect is modeled.
    """
    n_dim, p_dim = X.shape
    # Uniform prior over which column carries the effect.
    pi = jnp.ones(p_dim) / float(p_dim)
    # Prior scale (standard deviation) of the effect size, per column.
    sigma_b = jnp.ones(p_dim) * 1e-3
    gamma = numpyro.sample("gamma", dist.Categorical(pi))
    # `...` indexing tolerates extra leading batch dims on gamma, presumably
    # from parallel enumeration under TraceEnum_ELBO — TODO confirm shapes.
    b = numpyro.sample("b", dist.Normal(0.0, sigma_b[..., gamma]))
    sigma_e = numpyro.param("sigma_e", 0.9, constraint=constraints.positive)
    with numpyro.plate("N", n_dim):
        # Column gamma of X scaled by the effect size; Vindex is used so the
        # index can be batched.
        g = Vindex(X)[:, gamma] * b
        numpyro.sample("y", dist.Normal(g, sigma_e), obs=y)
    return
def guide(X, y, l_dim=1):
    """Variational guide for the single-effect model.

    Learns the following parameters:

    - ``alpha``: a probability vector over the ``p_dim`` columns of ``X``,
      used for the selected column ``gamma``;
    - ``post_mu_b`` and ``post_sigma_b``: per column, the mean and scale of a
      Normal over the effect size ``b``, given that column is selected.

    Args:
        X: Design matrix; only ``X.shape`` is used here.
        y: Unused; accepted to match the model's signature.
        l_dim: Unused; accepted to match the model's signature.
    """
    _, p_dim = X.shape
    # The keyword is ``constraint`` (singular). numpyro.param only reads
    # ``constraint``, so ``constraints=`` left these parameters unconstrained:
    # alpha could leave the simplex and scales could go negative.
    alpha = numpyro.param(
        "alpha", jnp.ones(p_dim) / float(p_dim), constraint=constraints.simplex
    )
    gamma = numpyro.sample(
        "gamma", dist.Categorical(alpha), infer={"enumerate": "parallel"}
    )
    post_sigma_b = numpyro.param(
        "post_sigma_b", jnp.ones(p_dim) * 1e-3, constraint=constraints.positive
    )
    post_mu_b = numpyro.param("post_mu_b", jnp.zeros(p_dim))
    # Vindex so that indexing by a batched (enumerated) gamma broadcasts.
    numpyro.sample(
        "b",
        dist.Normal(Vindex(post_mu_b)[..., gamma], Vindex(post_sigma_b)[..., gamma]),
    )
def main(args):
    """Simulate a fine-mapping dataset and compare NumPyro SVI with crude IBSS.

    Args:
        args: Command-line arguments, excluding the program name.

    Returns:
        Process exit status (0 on success).
    """
    argp = ap.ArgumentParser(description="")
    argp.add_argument("-e", "--epochs", type=int, default=10)
    argp.add_argument("--l-dim", default=3, type=int, help="size of latent dim")
    argp.add_argument("-s", "--seed", type=int, default=0)
    argp.add_argument("-d", "--debug", action="store_true", default=False)
    argp.add_argument("-r", "--render", action="store_true", default=False)
    argp.add_argument("-v", "--verbose", action="store_true", default=False)
    # NOTE(review): --device is parsed but not used anywhere below.
    argp.add_argument("--device", choices=["cpu", "gpu"], default="gpu")
    # Default equals the Adam step size that used to be hard-coded, so a run
    # without -l behaves as before; the flag is now actually honored.
    argp.add_argument(
        "-l", "--learning-rate", default=5e-3, type=float, help="Adam step size"
    )
    args = argp.parse_args(args)

    # set up logging
    log = get_logger(__name__)
    log.setLevel(logging.DEBUG if args.verbose else logging.INFO)

    # set up debugging info if needed
    if args.debug:
        numpyro.enable_validation()

    # ensure 64bit precision
    numpyro.enable_x64()

    # init key
    rng_key = PRNGKey(args.seed)
    rng_key, rng_key_init, rng_key_run = random.split(rng_key, 3)

    # load data
    log.info("Loading simulation data.")
    n_dim = 400
    p_dim = 20
    # The simulation and IBSS use one causal effect; --l-dim is only forwarded
    # to the NumPyro model/guide, which do not reference it.
    l_dim = 1
    y, X, S, beta = simulation(rng_key_init, n_dim, p_dim, l_dim)
    log.debug(f"True effect sizes = {beta}")

    # initialize model and state
    log.info("Constructing SVI model.")
    adam = optim.Adam(step_size=args.learning_rate)
    svi = SVI(model, guide, adam, TraceEnum_ELBO(max_plate_nesting=10))

    # construct illustration to check dependencies in implementation
    if args.render:
        numpyro.render_model(
            model,
            model_args=(X, y, l_dim),
            filename="susie.png",
            render_distributions=True,
        )

    # output shapes to check batch/event for variables
    rng_key, rng_key_trace = random.split(rng_key, 2)
    trace = numpyro.handlers.trace(
        numpyro.handlers.seed(model, rng_key_trace)
    ).get_trace(X=X, y=y, l_dim=args.l_dim)
    shapes = numpyro.util.format_shapes(trace, compute_log_prob=True)
    print(shapes)

    # run inference
    results = svi.run(
        rng_key_run,
        args.epochs,
        X=X,
        y=y,
        l_dim=args.l_dim,
        progress_bar=True,
        stable_update=True,
    )

    # The guide's "alpha" is a single probability vector of shape (p,).
    # Promote it to (1, p) so the PIP and credible-set code below see an
    # (effects, variants) matrix, as they do for the IBSS output.
    post_p = jnp.atleast_2d(results.params["alpha"])
    # convert to overall posterior inclusion probabilities
    PIP_np = 1.0 - jnp.prod(1 - post_p, axis=0)

    # perform inference using crude IBSS implementation
    alpha_l, b_l, s2e, s20 = susie(X, y, l_dim)
    # convert to overall posterior inclusion probabilities
    PIP = 1.0 - jnp.prod(1 - alpha_l.T, axis=0)
    log.debug(f"Numpyro PIPs = {PIP_np}")
    log.debug(f"IBSS PIPs = {PIP}")

    def get_credset(post_p, rho=0.9):
        """Greedy rho-level credible set for each row of an (effects, variants) matrix."""
        n_effects, n_vars = post_p.shape
        idxs = jnp.argsort(-post_p)  # flip for decreasing sort
        cs_s = []
        for ldx in range(n_effects):
            cs = []
            local = 0.0
            for pdx in range(n_vars):
                if local >= rho:
                    break
                # Plain ints keep the logged sets readable.
                idx = int(idxs[ldx, pdx])
                cs.append(idx)
                local += float(post_p[ldx, idx])
            cs_s.append(cs)
        return cs_s

    # get CS for true values, numpyro, and ibss
    cs_true = get_credset(S.T)
    cs_np = get_credset(post_p)
    cs_ibss = get_credset(alpha_l.T)
    log.info(f"True vars = {cs_true}")
    log.info(f"Numpyro 90% credible set = {cs_np}")
    log.info(f"IBSS 90% credible set = {cs_ibss}")
    log.info("Finished Inference")
    return 0
if __name__ == "__main__":
    # Run the CLI; main's return value becomes the process exit status.
    raise SystemExit(main(sys.argv[1:]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment