@eb8680 · Created May 17, 2021 04:04
"""
In Funsor it is possible to program entirely without batch dimensions
for a subset of PyTorch and JAX (including eliminating the batch/event
dimension distinction in the PyTorch and NumPyro/TFP distribution APIs).
Here is a simplified version of the mean computation in @murphyk's regression example.
"""
import torch
import funsor
funsor.set_backend("torch")
# example sizes; the original gist leaves these undefined, so these values are arbitrary
num_datapoints, in_features, out_features, num_particles = 100, 3, 2, 10
raw_data = torch.randn(num_datapoints, in_features)
# render the data batch dimension invisible
# to model code by giving it a name
data = funsor.Tensor(raw_data)["data"]
assert data.shape == (in_features,) # behaves like a single datapoint
assert set(data.inputs) == {"data"}
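# sanity check (not in the original gist): substituting a concrete index for
# the named input should recover an ordinary single datapoint; this assumes
# funsor's call-style substitution, e.g. data(data=0)
first_point = data(data=0)
assert first_point.shape == (in_features,)
assert not first_point.inputs  # the "data" input has been eliminated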
raw_weights = torch.randn(num_particles, out_features, in_features)
# render the elbo particle batch dimension invisible
# to model code by giving it a (different) name
weights = funsor.Tensor(raw_weights)["particles"]
assert weights.shape == (out_features, in_features) # behaves like a single sample
assert set(weights.inputs) == {"particles"}
raw_bias = torch.randn(num_particles, out_features)
bias = funsor.Tensor(raw_bias)["particles"]
assert bias.shape == (out_features,)
assert set(bias.inputs) == {"particles"}
# This line is completely unaware of the batch dimensions on
# raw_data, raw_weights, and raw_bias, which are automatically
# broadcast and propagated without user intervention
# (a manual PyTorch equivalent is sketched below for comparison)
mean = weights @ data + bias
assert mean.shape == (out_features,)
assert set(mean.inputs) == {"data", "particles"}
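# For comparison, a rough sketch (not from the original gist) of the manual
# PyTorch broadcasting that funsor performs automatically above. The layout of
# mean's underlying data depends on funsor's input ordering, so this is only
# illustrative:
#   raw_weights: (num_particles, out_features, in_features)
#   raw_data:    (num_datapoints, in_features)
#   raw_bias:    (num_particles, out_features)
raw_mean = torch.einsum("poi,di->pdo", raw_weights, raw_data) + raw_bias[:, None, :]
assert raw_mean.shape == (num_particles, num_datapoints, out_features)
# The named "particles" input can also be reduced away, e.g. averaging over
# ELBO particles. This assumes Funsor.reduce accepts a variable name together
# with funsor.ops.add; treat it as a sketch of the reduction API.
particle_mean = mean.reduce(funsor.ops.add, "particles") / num_particles
assert particle_mean.shape == (out_features,)
assert set(particle_mean.inputs) == {"data"}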