@DuaneNielsen
Created January 16, 2020 21:09
Constructing batched distributions from data in PyTorch and sampling from them
"""
Example of building a batch of c Normal and MultivariateNormal distributions from data
and drawing batches of samples from them.
"""
import torch
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal
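# Univariate case: one independent Normal per column of x
c = 3  # number of distributions in the batch
n = 10  # number of observations per distribution
x = torch.randn(n, c)
loc = x.mean(dim=0)  # shape (c,)
scale = x.std(dim=0)  # shape (c,), unbiased (n - 1) estimate by default
dist = Normal(loc, scale)  # batch_shape (c,), event_shape ()
samples = dist.sample((n,))
print(dist.mean)
print(dist.stddev)
print(samples.shape)  # torch.Size([10, 3])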
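# Multivariate case: one d-dimensional MultivariateNormal per slice x[:, i, :]
c = 4  # number of distributions in the batch
n = 3  # number of observations per distribution; needs n > d for a full-rank covariance
d = 2  # dimensionality of each observation
x = torch.randn(n, c, d)
loc = x.mean(0)  # shape (c, d)
delta = x - loc.unsqueeze(0)  # centered data, shape (n, c, d)
# Batched (d, n) @ (n, d) product gives the unbiased sample covariance, shape (c, d, d)
covariance_matrix = torch.matmul(delta.permute(1, 2, 0), delta.permute(1, 0, 2)) / (n - 1)
dist = MultivariateNormal(loc, covariance_matrix)  # batch_shape (c,), event_shape (d,)
samples = dist.sample((n,))
print(samples.shape)  # torch.Size([3, 4, 2])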
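The batched covariance above depends on getting the two `permute` calls right, so a quick cross-check may be useful. The sketch below is not part of the original gist. It recomputes the covariance one component at a time with `torch.cov` (assumed available, i.e. PyTorch 1.10 or later) and compares it with the batched result. The sizes, the seed, and the names `batched_cov`, `reference_cov`, and `covariances_match` are illustrative choices.

```python
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

torch.manual_seed(0)  # illustrative seed, only for repeatability

# Assumed sizes: n is taken well above d so each covariance is comfortably full rank.
n, c, d = 50, 4, 2
x = torch.randn(n, c, d)

# Batched estimate, same construction as in the gist.
loc = x.mean(0)                                   # (c, d)
delta = x - loc.unsqueeze(0)                      # (n, c, d)
batched_cov = torch.matmul(
    delta.permute(1, 2, 0), delta.permute(1, 0, 2)
) / (n - 1)                                       # (c, d, d)

# Reference estimate: torch.cov treats rows as variables and columns as
# observations, so each (n, d) slice is transposed to (d, n) first.
reference_cov = torch.stack([torch.cov(x[:, i, :].T) for i in range(c)])

covariances_match = torch.allclose(batched_cov, reference_cov, atol=1e-5)
print(covariances_match)

dist = MultivariateNormal(loc, covariance_matrix=batched_cov)
samples = dist.sample((n,))                       # (n, c, d)
log_probs = dist.log_prob(samples)                # (n, c): one value per sample per component
print(dist.batch_shape, dist.event_shape, log_probs.shape)
```

The `log_prob` call returns one value per sample and per batch member, because the trailing `d` axis is consumed as the event dimension. This is a convenient way to confirm that `c` is being treated as a batch of distributions rather than as part of the event.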