Created May 17, 2021 04:04
-
-
Save eb8680/de83ab8aa535c38b932bc98303f72a5b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
In Funsor it is possible to program entirely without batch dimensions
for a subset of PyTorch and JAX (including eliminating the batch/event
dimension distinction in the PyTorch and NumPyro/TFP distribution APIs).
Here is a simplified version of the mean computation in @murphyk's
regression example.
"""
import torch
import funsor

funsor.set_backend("torch")

# Example sizes. The original snippet never defined these, so it raised
# NameError at the first torch.randn call. The values are deliberately all
# different so that the shape assertions below would catch a transposed or
# misplaced axis instead of passing by coincidence.
num_datapoints = 5
num_particles = 4
in_features = 3
out_features = 2

raw_data = torch.randn(num_datapoints, in_features)
# Render the data batch dimension invisible to model code by giving it a
# name: indexing with the string "data" moves the leading axis out of
# .shape and into the named .inputs, as the assertions below check.
data = funsor.Tensor(raw_data)["data"]
assert data.shape == (in_features,)  # behaves like a single datapoint
assert set(data.inputs) == {"data"}

raw_weights = torch.randn(num_particles, out_features, in_features)
# Render the ELBO particle batch dimension invisible to model code by
# giving it a (different) name.
weights = funsor.Tensor(raw_weights)["particles"]
assert weights.shape == (out_features, in_features)  # behaves like a single sample
assert set(weights.inputs) == {"particles"}

raw_bias = torch.randn(num_particles, out_features)
# Reusing the name "particles" ties this axis to the weights' particle axis.
bias = funsor.Tensor(raw_bias)["particles"]
assert bias.shape == (out_features,)
assert set(bias.inputs) == {"particles"}

# This line is completely unaware of the batch dimensions on raw_data,
# raw_weights and raw_bias; they are broadcast and propagated by name
# without user intervention.
mean = weights @ data + bias
assert mean.shape == (out_features,)
assert set(mean.inputs) == {"data", "particles"}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.