Created
December 13, 2018 16:20
-
-
Save fritzo/b401684361defc5e5eeec0c40248d886 to your computer and use it in GitHub Desktop.
Vectorized conditional in Pyro
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This faster version uses a vectorized pyro.plate and poutine.mask
# to perform vectorized inference.
def vectorized_model(x, y, truncated):
    """Condition an Exponential regression model on all data points at once.

    Args:
        x: Regression inputs, passed as a whole to ``my_regression_function``;
            ``len(x)`` sets the size of the ``"data"`` plate.
        y: Observed values, one per data point (assumed to be a tensor
            aligned with ``x`` -- confirm against the caller).
        truncated: Mask selecting which likelihood term applies to each data
            point. It is inverted with ``~``, so it is presumably a boolean
            tensor aligned with ``y`` -- confirm against the caller.

    Where ``truncated`` is false, ``y`` is observed directly under
    ``Exponential(rate)`` at site ``"obs"``. Where it is true, the data point
    instead contributes a Bernoulli success with probability
    ``y_dist.cdf(y)`` at site ``"truncated_obs"``.
    """
    rate = my_regression_function(x)
    y_dist = dist.Exponential(rate)
    with pyro.plate("data", len(x)):
        # The mask handler lives in pyro.poutine and takes the mask tensor as
        # the ``mask`` keyword (its first positional argument is the wrapped
        # callable, not the mask).
        with pyro.poutine.mask(mask=~truncated):
            pyro.sample("obs", y_dist, obs=y)
        with pyro.poutine.mask(mask=truncated):
            probs = y_dist.cdf(y)
            # Observe a tensor of ones rather than the Python float ``1.``:
            # with distribution validation enabled, log_prob only accepts
            # tensors.
            pyro.sample("truncated_obs",
                        dist.Bernoulli(probs),
                        obs=probs.new_ones(probs.shape))
# This slower version uses a sequential pyro.plate and
# an if statement to perform inference in a loop.
def sequential_model(x, y, truncated):
    """Condition an Exponential regression model one data point at a time.

    Args:
        x: Regression inputs; element ``x[i]`` is passed to
            ``my_regression_function`` and ``len(x)`` sets the plate size.
        y: Observed values, indexed in step with ``x``.
        truncated: Per-data-point flags, indexed in step with ``x``; the
            truthiness of ``truncated[i]`` selects the likelihood term.

    For each index ``i``: when ``truncated[i]`` is falsy, ``y[i]`` is observed
    directly under ``Exponential(rate)`` at site ``"obs_<i>"``; otherwise a
    Bernoulli success with probability ``y_dist.cdf(y[i])`` is observed at
    site ``"truncated_obs_<i>"``.
    """
    for i in pyro.plate("data", len(x)):
        rate = my_regression_function(x[i])
        y_dist = dist.Exponential(rate)
        if not truncated[i]:
            pyro.sample("obs_{}".format(i), y_dist, obs=y[i])
        else:
            probs = y_dist.cdf(y[i])
            # Observe a tensor of ones rather than the Python float ``1.``:
            # with distribution validation enabled, log_prob only accepts
            # tensors.
            pyro.sample("truncated_obs_{}".format(i),
                        dist.Bernoulli(probs),
                        obs=probs.new_ones(probs.shape))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment