Skip to content

Instantly share code, notes, and snippets.

@fritzo
Created December 13, 2018 16:20
Show Gist options
  • Save fritzo/b401684361defc5e5eeec0c40248d886 to your computer and use it in GitHub Desktop.
Save fritzo/b401684361defc5e5eeec0c40248d886 to your computer and use it in GitHub Desktop.
Vectorized conditional in Pyro
# This faster version uses a vectorized pyro.plate and pyro.poutine.mask
# to perform vectorized inference.
def vectorized_model(x, y, truncated):
    """Score all observations at once inside a vectorized plate.

    Each datum gets an Exponential likelihood whose rate comes from
    ``my_regression_function(x)``. Two masked sample sites share the plate.
    Untruncated entries score ``y`` directly under the Exponential.
    Truncated entries score an observed value of 1 under
    ``Bernoulli(y_dist.cdf(y))``.

    Args:
        x: Regression inputs; ``len(x)`` sets the plate size.
        y: Observed values, one per datum.
        truncated: Per-datum mask. It is negated with ``~`` and used as a
            poutine mask, so presumably a boolean tensor aligned with ``y``
            (TODO confirm dtype/shape against callers).
    """
    rate = my_regression_function(x)
    y_dist = dist.Exponential(rate)
    with pyro.plate("data", len(x)):
        # The mask handler lives in pyro.poutine. Its first positional
        # parameter is the function to wrap, so the mask must be passed
        # by keyword.
        with pyro.poutine.mask(mask=~truncated):
            pyro.sample("obs", y_dist, obs=y)
        with pyro.poutine.mask(mask=truncated):
            # NOTE(review): observing 1 from Bernoulli(cdf(y)) scores
            # P(Y <= y). If "truncated" means right-censored (the true value
            # exceeds y), the probability should be 1 - cdf(y). Confirm
            # the intended semantics.
            pyro.sample("truncated_obs",
                        dist.Bernoulli(y_dist.cdf(y)),
                        obs=1.)
# Slower alternative: iterate a sequential pyro.plate and branch with a
# Python if statement, creating one sample site per datum.
def sequential_model(x, y, truncated):
    """Score observations one at a time inside a sequential plate.

    For each index, the Exponential rate comes from
    ``my_regression_function(x[idx])``. Untruncated data are observed
    directly at site ``obs_<idx>``. Truncated data instead observe the value
    1 at site ``truncated_obs_<idx>`` under ``Bernoulli(cdf(y[idx]))``.

    Args:
        x: Regression inputs; ``len(x)`` sets the number of iterations.
        y: Observed values, indexable by position.
        truncated: Per-datum flags, indexable by position; each element is
            evaluated for truthiness.
    """
    for idx in pyro.plate("data", len(x)):
        exp_dist = dist.Exponential(my_regression_function(x[idx]))
        if truncated[idx]:
            # NOTE(review): this scores P(Y <= y[idx]). Right-censored data
            # would need 1 - cdf instead. Confirm the intended semantics.
            pyro.sample("truncated_obs_{}".format(idx),
                        dist.Bernoulli(exp_dist.cdf(y[idx])),
                        obs=1.)
            continue
        pyro.sample("obs_{}".format(idx), exp_dist, obs=y[idx])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment