Skip to content

Instantly share code, notes, and snippets.

View hesenp's full-sized avatar

Hesen Peng hesenp

View GitHub Profile
@hesenp
hesenp / time-to-event-pyro-example-7.py
Created February 13, 2019 02:27
Modeling Censored Time to Event data using Pyro - plotting stats
# Approximate the posterior density of every latent site from the guide's
# quantiles and draw one curve per site.
N = 1000
# Evenly spaced probability levels 0, 1/N, ..., (N-1)/N.
probs = torch.arange(0., N) / N
for name, site_quantiles in guide.quantiles(probs).items():
    site_quantiles = np.array(site_quantiles)
    # Neighbouring quantiles enclose 1/N of the probability mass, so the
    # density over that interval is roughly (1/N) / (interval width).
    widths = site_quantiles[1:] - site_quantiles[:-1]
    pdf = 1 / widths / N
    # Plot each density value at the centre of its interval. A dedicated name
    # is used here so the module-level `x` (the covariate data fed to the
    # model) is not overwritten by the plotting code.
    midpoints = (site_quantiles[1:] + site_quantiles[:-1]) / 2
    # NOTE(review): `sns.plt` is not a documented seaborn attribute -- confirm
    # the installed seaborn version still exposes it.
    sns.plt.plot(midpoints, pdf, label=name)
# Legend and axis label apply to the whole figure, so set them once after
# all curves have been drawn.
sns.plt.legend()
sns.plt.ylabel('density')
@hesenp
hesenp / time-to-event-pyro-example-6.py
Created February 13, 2019 02:26
Modeling Censored Time to Event data using Pyro - training with SVI
# Reset Pyro's parameter store before setting up inference.
pyro.clear_param_store()

# Adam settings: learning rate and the pair of beta coefficients.
adam_params = dict(lr=0.01, betas=(0.90, 0.999))
optimizer = optim.Adam(adam_params)

# Stochastic variational inference over `model` and `guide`, optimised with
# the Adam instance above and using Trace_ELBO as the loss.
svi = infer.SVI(model, guide, optimizer, loss=infer.Trace_ELBO())
@hesenp
hesenp / time-to-event-pyro-example-5.py
Created February 13, 2019 02:25
Modeling Censored Time to Event data using Pyro - modeling with plated data
def model(x, y, truncation_label):
a_model = pyro.sample("a_model", dist.Normal(0, 10))
b_model = pyro.sample("b_model", dist.Normal(0, 10))
link = torch.nn.functional.softplus(a_model * x + b_model)
with pyro.plate("data"):
y_hidden_dist = dist.Exponential(1 / link)
with pyro.poutine.mask(mask = (truncation_label == 0)):
@hesenp
hesenp / time-to-event-pyro-example-4.py
Created February 13, 2019 02:24
Modeling Censored Time to Event data using Pyro - HMC first run
# Reset Pyro's parameter store before building the sampler.
pyro.clear_param_store()
# note [1]
# HMC transition kernel over `model`: step size 0.1, 4 steps per proposal.
hmc_kernel = HMC(model, step_size=0.1, num_steps=4)
# Note [2]
mcmc_run = MCMC(hmc_kernel,
@hesenp
hesenp / time-to-event-pyro-example-3.py
Created February 13, 2019 02:23
Modeling Censored Time to Event data using Pyro - defining naive model
def model(x, y, truncation_label): ## Note [1]
a_model = pyro.sample("a_model", dist.Normal(0, 10)) ## Note [2]
b_model = pyro.sample("b_model", dist.Normal(0, 10))
link = torch.nn.functional.softplus(a_model * x + b_model) ## Note [3]
for i in range(len(x)):
y_hidden_dist = dist.Exponential(1 / link[i]) ## Note [4]
if truncation_label[i] == 0:
@hesenp
hesenp / time-to-event-pyro-example-2.py
Created February 13, 2019 02:21
Modeling Censored Time to Event data using Pyro - generating example data
# Generate a synthetic time-to-event data set.
n = 500  # number of observations to draw
a = 2    # slope of the linear predictor
b = 4    # intercept of the linear predictor
c = 8    # not referenced in the visible lines; presumably used further down -- confirm
x = dist.Normal(0, 0.34).sample((n,))  # Note [1]
# `a * x + b` is already a tensor, so it goes to softplus directly.
# Re-wrapping it in torch.tensor() would make a redundant copy and trigger
# PyTorch's copy-construct UserWarning.
link = torch.nn.functional.softplus(a * x + b)
# note below, param is rate, not mean
y = dist.Exponential(rate=1 / link).sample()
@hesenp
hesenp / time-to-event-pyro-example-1.py
Last active July 14, 2021 03:12
Modeling Censored Time to Event data using Pyro - importing packages
import pyro
import torch
import seaborn as sns
import pyro.distributions as dist
from pyro import infer, optim
from pyro.infer.mcmc import HMC, MCMC
from pyro.infer import EmpiricalMarginal
assert pyro.__version__.startswith('0.3')