@hesenp
Created February 13, 2019 02:25
Modeling Censored Time to Event data using Pyro - modeling with plated data
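import pyro
import pyro.distributions as dist
import torch


def model(x, y, truncation_label):
    # Priors on the slope and intercept of the linear predictor.
    a_model = pyro.sample("a_model", dist.Normal(0, 10))
    b_model = pyro.sample("b_model", dist.Normal(0, 10))
    # Softplus keeps the mean time to event positive.
    link = torch.nn.functional.softplus(a_model * x + b_model)
    with pyro.plate("data"):
        y_hidden_dist = dist.Exponential(1 / link)
        # Uncensored rows (label 0): y is the observed event time.
        with pyro.poutine.mask(mask=(truncation_label == 0)):
            pyro.sample("obs", y_hidden_dist, obs=y)
        # Censored rows (label 1): only "event time > y" is known, so
        # observe a Bernoulli success with probability 1 - CDF(y).
        with pyro.poutine.mask(mask=(truncation_label == 1)):
            truncation_prob = 1 - y_hidden_dist.cdf(y)
            pyro.sample("truncation_label",
                        dist.Bernoulli(truncation_prob),
                        obs=torch.tensor(1.))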
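
The two masked branches in the model correspond to the usual right-censored likelihood: an uncensored row contributes the exponential log density, and a censored row contributes the log survival probability, since the log probability of a Bernoulli success with probability p is log p. The following is a minimal standard-library sketch of that sum for fixed coefficients. The function names `softplus` and `censored_exponential_log_lik` are hypothetical and not part of the gist, and the Normal priors on `a_model` and `b_model` are deliberately left out.

```python
import math


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def censored_exponential_log_lik(a, b, xs, ys, truncation_labels):
    """Data log-likelihood for fixed a and b (priors not included)."""
    total = 0.0
    for x, y, label in zip(xs, ys, truncation_labels):
        # Same parameterization as the model: the rate is 1 / softplus(a*x + b).
        rate = 1.0 / softplus(a * x + b)
        if label == 0:
            # Uncensored: exponential log density, log(rate) - rate * y.
            total += math.log(rate) - rate * y
        else:
            # Censored: log of 1 - CDF(y), which is what observing a
            # Bernoulli success with probability 1 - CDF(y) contributes.
            survival = 1.0 - (1.0 - math.exp(-rate * y))
            total += math.log(survival)
    return total
```

For the exponential distribution the censored term simplifies to `-rate * y`, so a censored row differs from an uncensored one only by the missing `log(rate)` term.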