@gpantalos
Created February 4, 2021 13:48
Laplace Approximation in PyTorch
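The script fits a Gaussian to the Beta density using the standard Laplace construction. The derivation below is textbook material, not taken from the gist. It expands the log-density to second order around its mode $\hat{x}$, where the first-order term vanishes. For $\mathrm{Beta}(\alpha, \beta)$ with $\alpha, \beta > 1$, the mode and curvature have closed forms, which gives reference values for the autograd result:

$$
\log p(x) \approx \log p(\hat{x}) - \tfrac{1}{2}\,\Lambda\,(x - \hat{x})^2,
\qquad
\Lambda = -\frac{d^2}{dx^2}\log p(x)\Big|_{x=\hat{x}},
\qquad
p(x) \approx \mathcal{N}\!\left(x \mid \hat{x},\, \Lambda^{-1}\right)
$$

$$
\log p(x) = (\alpha - 1)\log x + (\beta - 1)\log(1 - x) + \text{const},
\qquad
\hat{x} = \frac{\alpha - 1}{\alpha + \beta - 2},
\qquad
\Lambda = \frac{\alpha - 1}{\hat{x}^2} + \frac{\beta - 1}{(1 - \hat{x})^2}
$$

$$
\alpha = 2,\ \beta = 5:\quad
\hat{x} = \tfrac{1}{5},\quad
\Lambda = 25 + 6.25 = 31.25,\quad
\sigma = \Lambda^{-1/2} \approx 0.179
$$

The standard deviation $\sigma = \Lambda^{-1/2}$, not the variance $\Lambda^{-1}$, is what `torch.distributions.Normal` expects as its `scale` argument.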
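```python
"""
Laplace approximation of a Beta distribution.
"""
import matplotlib.pyplot as plt
import torch

x = torch.linspace(0, 1, 200)
p = torch.distributions.Beta(2, 5)


def pdf(index_points, distribution):
    """Density of `distribution` evaluated at `index_points`."""
    return torch.exp(distribution.log_prob(index_points))


# Approximate the mode by the grid point with the highest density.
mode = x[torch.argmax(pdf(x, p))]
plt.plot(x, pdf(x, p), label="Beta(2, 5)")
plt.scatter(mode, pdf(mode, p))

# Precision of the Laplace approximation: the negative second derivative
# of the log-density at the mode.
# noinspection PyTypeChecker
hessian = -torch.autograd.functional.hessian(p.log_prob, mode)
# Normal is parameterised by the standard deviation (`scale`), so the
# variance 1 / hessian has to be square-rooted: hessian ** -0.5.
q = torch.distributions.Normal(mode, hessian ** -0.5)
plt.plot(x, pdf(x, q), label="Laplace approximation")
plt.title("Laplace Approximation")
plt.legend()
plt.show()
```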
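The script takes the mode from a 200-point grid, so the Gaussian is centred on the nearest grid point rather than the exact mode. One way to remove that error is to refine the mode with Newton's method on the log-density, reusing `torch.autograd.functional`. Below is a minimal sketch of that idea. The helper name `laplace_beta`, the starting point `x0`, the step count, and the clamping bounds are all hypothetical choices. The sketch also assumes `alpha > 1` and `beta > 1`, so the mode lies strictly inside (0, 1). It works in float64 to make the comparison with the closed form tight.

```python
import torch


def laplace_beta(alpha: float, beta: float, x0: float = 0.5, steps: int = 20):
    """Hypothetical helper: Laplace approximation of Beta(alpha, beta).

    Assumes alpha > 1 and beta > 1 so the log-density is concave with an
    interior mode. Returns Normal(mode, precision ** -0.5).
    """
    p = torch.distributions.Beta(
        torch.tensor(alpha, dtype=torch.float64),
        torch.tensor(beta, dtype=torch.float64),
    )
    x = torch.tensor(x0, dtype=torch.float64)
    for _ in range(steps):
        # First and second derivatives of log p at the current point.
        grad = torch.autograd.functional.jacobian(p.log_prob, x)
        hess = torch.autograd.functional.hessian(p.log_prob, x)
        # Newton step towards the stationary point of log p, kept inside
        # the open unit interval (bounds are an arbitrary choice).
        x = (x - grad / hess).clamp(1e-9, 1 - 1e-9)
    precision = -torch.autograd.functional.hessian(p.log_prob, x)
    # Normal takes the standard deviation, i.e. precision ** -0.5.
    return torch.distributions.Normal(x, precision.rsqrt())


q = laplace_beta(2.0, 5.0)
print(q.loc.item(), q.scale.item())
```

For Beta(2, 5) this should give a mean very close to 0.2 and a scale close to 0.179. Those values match the closed-form mode $(\alpha - 1)/(\alpha + \beta - 2)$ and curvature of the Beta log-density. Newton's method is only one choice here. A bounded optimiser or a finer grid would also work, provided the Hessian is evaluated at the refined mode.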