Skip to content

Instantly share code, notes, and snippets.

@universome
Created April 9, 2019 11:36
Show Gist options
  • Save universome/7316a0e456496e7be7cbd01110efc853 to your computer and use it in GitHub Desktop.
Save universome/7316a0e456496e7be7cbd01110efc853 to your computer and use it in GitHub Desktop.
import torch
from torch.distributions import MultivariateNormal
from torch.distributions import kl_divergence
# Sanity-check hand-written Gaussian helpers against torch.distributions:
# a diagonal Gaussian's log-likelihood, and its KL divergence to N(0, I).
# NOTE(review): `gaussian_log_likelihood` and `KL_divergence` are not defined
# in this snippet -- presumably they live in another file of this gist;
# confirm they are in scope before running.
torch.manual_seed(0)  # make the random comparison reproducible across runs

batch_size, dim = 10, 100
mean = torch.rand(batch_size, dim)
logstds = torch.rand(batch_size, dim)  # per-dimension log standard deviations
x = torch.rand(batch_size, dim)

# A diagonal covariance's Cholesky factor is diag(std); diag_embed builds the
# whole batch of (dim, dim) diagonal matrices in one vectorized call.
dist = MultivariateNormal(mean, scale_tril=torch.diag_embed(logstds.exp()))

print('Testing LL')
print(dist.log_prob(x).mean())
print(gaussian_log_likelihood(x, mean, logstds).mean())

print('Testing KL')
# Batch of standard normals N(0, I), same batch shape as `dist`; the
# identity scale_tril is diag_embed of a tensor of ones.
standard_normal = MultivariateNormal(
    torch.zeros_like(mean),
    scale_tril=torch.diag_embed(torch.ones_like(mean)),
)
print(kl_divergence(dist, standard_normal).mean())
print(KL_divergence(mean, logstds).mean())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment