@DuaneNielsen
Created January 13, 2020 22:14
Example of plotting a multivariate normal distribution in PyTorch
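import torch
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions

# 1-D standard normal: evaluate the density as exp(log_prob) on a grid and plot it
d = Normal(0.0, 1.0)
x = torch.linspace(-4.0, 4.0, 50)
y = torch.exp(d.log_prob(x))
plt.plot(x, y)
plt.show()

# 2-D standard normal: zero mean, identity covariance
d = MultivariateNormal(torch.zeros(2), torch.eye(2))
x_, y_ = torch.linspace(-3, 3, 100), torch.linspace(-3, 3, 100)
# indexing='ij' (PyTorch >= 1.10) states the default layout explicitly and avoids the warning
x, y = torch.meshgrid(x_, y_, indexing='ij')
# Stack on the last dim so each grid cell holds one (x, y) point: shape (100, 100, 2)
z = d.log_prob(torch.stack((x, y), dim=-1))
z = torch.exp(z)  # density values, shape (100, 100), aligned element-wise with x and y
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(x.cpu().numpy(), y.cpu().numpy(), z.cpu().numpy())
plt.show()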
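Both plots rely on `exp(log_prob(...))` returning the probability density. As a quick sanity check, here is a minimal sketch, assuming PyTorch >= 1.10 for `meshgrid(..., indexing=...)`, that compares those values with the closed-form densities of the 1-D standard normal, exp(-x²/2) / sqrt(2π), and the 2-D standard normal, exp(-(x² + y²)/2) / (2π). The helpers `normal_pdf` and `std_normal_2d_pdf` are hypothetical names for illustration, not part of PyTorch.

```python
import math

import torch
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal


def normal_pdf(x, mu=0.0, sigma=1.0):
    """Hypothetical helper: closed-form 1-D Gaussian density."""
    return torch.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def std_normal_2d_pdf(x, y):
    """Hypothetical helper: closed-form density of a 2-D standard normal."""
    return torch.exp(-0.5 * (x ** 2 + y ** 2)) / (2 * math.pi)


# 1-D: exp(log_prob) should match the textbook formula on the same grid
x = torch.linspace(-4.0, 4.0, 50)
pdf_1d = Normal(0.0, 1.0).log_prob(x).exp()
ok_1d = torch.allclose(pdf_1d, normal_pdf(x), atol=1e-6)

# 2-D: evaluate on an (x, y) grid with points stacked on the last dimension
xs, ys = torch.meshgrid(torch.linspace(-3, 3, 100), torch.linspace(-3, 3, 100), indexing="ij")
mvn = MultivariateNormal(torch.zeros(2), torch.eye(2))
pdf_2d = mvn.log_prob(torch.stack((xs, ys), dim=-1)).exp()  # shape (100, 100)
ok_2d = torch.allclose(pdf_2d, std_normal_2d_pdf(xs, ys), atol=1e-6)

print(ok_1d, ok_2d)
```

The tolerance only absorbs float32 rounding; a real mismatch, such as a missing normalising constant, would be far larger.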
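The layout of the point tensor passed to `log_prob` decides whether `z` lines up with the `x`, `y` grid. For the zero-mean, identity-covariance distribution in this example, p(a, b) = p(b, a), so a transposed `z` would look identical on a square grid and the mistake would go unnoticed. The sketch below assumes a covariance with unequal variances and a non-square grid (both chosen only for illustration) so the orientation matters. It contrasts stacking on the last dimension with reversing all dimensions, which is what `.T` does on a 3-D tensor (written here as `permute(2, 1, 0)`, since `.T` on tensors that are not 2-D is deprecated in recent PyTorch releases).

```python
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

# Assumed example: unequal variances, so p(a, b) != p(b, a) in general
mean = torch.zeros(2)
cov = torch.tensor([[1.0, 0.5], [0.5, 2.0]])
mvn = MultivariateNormal(mean, cov)

# Non-square grid: any transposition shows up in the output shape
xs, ys = torch.meshgrid(torch.linspace(-3, 3, 100), torch.linspace(-2, 2, 80), indexing="ij")

# Stack on the last dim so points[i, j] == (xs[i, j], ys[i, j])
points = torch.stack((xs, ys), dim=-1)   # shape (100, 80, 2)
z = mvn.log_prob(points).exp()           # shape (100, 80), aligned with xs and ys

# Reversing all dims (what .T does on a 3-D tensor) puts the point at (i, j) into cell (j, i)
reversed_points = torch.stack((xs, ys)).permute(2, 1, 0)  # shape (80, 100, 2)
z_reversed = mvn.log_prob(reversed_points).exp()          # shape (80, 100)

# Spot-check one cell of z against a direct evaluation at that coordinate
i, j = 10, 70
direct = mvn.log_prob(torch.stack((xs[i, j], ys[i, j]))).exp()

print(z.shape, z_reversed.shape)
print(torch.allclose(z[i, j], direct))   # z[i, j] is the density at (xs[i, j], ys[i, j])
print(torch.allclose(z_reversed.T, z))   # same values, transposed layout
```

Passing `z_reversed` to `plot_surface` together with `xs` and `ys` should raise an error here because the shapes differ. On a square grid with identical x and y ranges it would plot without complaint, but the surface would be mirrored across the line x = y.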