@rhaps0dy
Created July 9, 2020 20:12
Natural variational distribution + tests
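The gist has two parts: a module with the variational distributions (imported by the tests as `natural_variational_distribution`, so presumably saved under that file name) and a test file. Both distributions are parameterized by the natural parameters of a multivariate Gaussian, and their custom `torch.autograd.Function` backward passes return gradients with respect to the expectation parameters, so a plain SGD step on these parameters is a natural-gradient step on the variational distribution (see `test_optimization_zero_error` below, where a single SGD step with the right learning rate reaches the ELBO optimum).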
#!/usr/bin/env python3
import abc

import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import CholLazyTensor
from gpytorch.variational._variational_distribution import \
    _VariationalDistribution

__all__ = ['NaturalVariationalDistribution', 'TrilNaturalVariationalDistribution']


class _AbstractNVD(_VariationalDistribution, metaclass=abc.ABCMeta):
    def __init__(self, num_inducing_points, batch_shape=torch.Size([]),
                 mean_init_std=1e-3, use_natgrad=True, **kwargs):
        super().__init__(num_inducing_points=num_inducing_points,
                         batch_shape=batch_shape, mean_init_std=mean_init_std)
        self._use_natgrad = use_natgrad
        m = torch.zeros(num_inducing_points).repeat(*batch_shape, 1)
        self.register_parameter("nat_mean", torch.nn.Parameter(m))
        self._register_nat_covar(num_inducing_points, batch_shape)

    @property
    def use_natgrad(self):
        return self._use_natgrad

    def use_natgrad_(self, use_natgrad=True):
        self._use_natgrad = use_natgrad

    @abc.abstractmethod
    def _register_nat_covar(self, num_inducing_points, batch_shape):
        pass
class NaturalVariationalDistribution(_AbstractNVD):
    """A :obj:`~gpytorch.variational._VariationalDistribution` that is defined
    to be a multivariate normal distribution with a full covariance matrix.

    Parameterized in terms of its natural parameters: Σ⁻¹μ and -1/2 Σ⁻¹.
    (A numerical round-trip check of this parameterization is sketched after
    the class definition.)
    """
    def _register_nat_covar(self, num_inducing_points, batch_shape):
        cov = -.5 * torch.eye(num_inducing_points)
        cov = cov.repeat(*batch_shape, 1, 1)
        self.register_parameter("nat_covar", torch.nn.Parameter(cov))

    def forward(self):
        fun = (_NaturalToMuVarSqrt.apply if self.use_natgrad
               else _NaturalToMuVarSqrt._forward)
        mu, L = fun(self.nat_mean, self.nat_covar)
        return MultivariateNormal(mu, CholLazyTensor(L))

    def initialize_variational_distribution(self, prior_dist):
        chol = prior_dist.lazy_covariance_matrix.cholesky().evaluate()
        tril_nat_covar = _triangular_inverse(chol, upper=False)
        nat_covar = tril_nat_covar.transpose(-1, -2) @ tril_nat_covar
        nat_mean = (prior_dist.mean
                    .unsqueeze(-1)
                    .triangular_solve(chol, upper=False, transpose=False).solution
                    .triangular_solve(chol, upper=False, transpose=True).solution
                    .squeeze(-1))
        self.nat_mean.data.copy_(nat_mean)
        # -.5: because the parameter should hold -0.5 Σ⁻¹ (here `nat_covar` is Σ⁻¹).
        # .5: because we take the mean of Σ⁻¹ and its transpose to symmetrize it.
        self.nat_covar.data.copy_((nat_covar + nat_covar.transpose(-1, -2)) * (-.5 * .5))

    def reparameterise(self):
        self.initialize_variational_distribution(self.forward())
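
# Illustrative round-trip check of the natural parameterization above (a sketch,
# not called anywhere): starting from (mu, Sigma), the natural parameters are
# Sigma^-1 mu and -1/2 Sigma^-1, and `_NaturalToMuVarSqrt._forward` (defined
# further down) should map them back to mu and chol(Sigma). The helper name is
# hypothetical; float64 is used for numerical headroom.
def _example_natural_roundtrip(D=4):
    A = torch.randn(D, D, dtype=torch.float64)
    Sigma = A @ A.t() + D * torch.eye(D, dtype=torch.float64)   # well-conditioned PD matrix
    mu = torch.randn(D, dtype=torch.float64)
    Sigma_inv = torch.inverse(Sigma)
    nat_mean = Sigma_inv @ mu        # Σ⁻¹ μ
    nat_covar = -0.5 * Sigma_inv     # -1/2 Σ⁻¹
    mu_out, L_out = _NaturalToMuVarSqrt._forward(nat_mean, nat_covar)
    assert torch.allclose(mu_out, mu)
    assert torch.allclose(L_out @ L_out.t(), Sigma)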
class TrilNaturalVariationalDistribution(_AbstractNVD):
    """A :obj:`~gpytorch.variational._VariationalDistribution` that is defined
    to be a multivariate normal distribution with a full covariance matrix.

    Parameterized in terms of its natural mean and a decomposition of the
    natural covariance, so that the latter stays positive definite when using
    BFGS or SGD.

    The parameters are Σ⁻¹μ and L, where L is a lower-triangular matrix and
    Σ⁻¹ = LᵀL. (Note: this is different from the Cholesky factorization, which
    is LLᵀ.)

    Claim: any PD matrix Σ can be represented as LᵀL with L lower triangular.
    Proof by construction: calculate cholesky(Σ⁻¹) = Linv. Then
        Σ = (Σ⁻¹)⁻¹ = (Linv Linvᵀ)⁻¹ = (Linv⁻¹)ᵀ Linv⁻¹.
    Since Linv⁻¹ is also lower triangular, setting L = Linv⁻¹ gives such a
    representation for Σ. (A numerical check of this claim is sketched after
    the class definition.)
    """
    def _register_nat_covar(self, num_inducing_points, batch_shape):
        cov = torch.eye(num_inducing_points)
        cov = cov.repeat(*batch_shape, 1, 1)
        self.register_parameter("tril_nat_covar", torch.nn.Parameter(cov))

    def forward(self):
        fun = (_TrilNaturalToMuVarSqrt.apply if self.use_natgrad
               else _TrilNaturalToMuVarSqrt._forward)
        mu, L = fun(self.nat_mean, self.tril_nat_covar)
        return MultivariateNormal(mu, CholLazyTensor(L))

    def initialize_variational_distribution(self, prior_dist):
        chol = prior_dist.lazy_covariance_matrix.cholesky().evaluate()
        tril_nat_covar = _triangular_inverse(chol, upper=False)
        # nat_mean = prior_dist.mean
        nat_mean = (prior_dist.mean
                    .unsqueeze(-1)
                    .triangular_solve(chol, upper=False, transpose=False).solution
                    .triangular_solve(chol, upper=False, transpose=True).solution
                    .squeeze(-1))
        self.nat_mean.data.copy_(nat_mean)
        self.tril_nat_covar.data.copy_(tril_nat_covar)
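
# Illustrative numerical check of the claim in the docstring above (a sketch,
# not called anywhere; the helper name is hypothetical): for a PD matrix Sigma,
# L = inv(cholesky(inv(Sigma))) is lower triangular and satisfies Sigma = L^T L.
def _example_tril_claim(D=4):
    A = torch.randn(D, D, dtype=torch.float64)
    Sigma = A @ A.t() + D * torch.eye(D, dtype=torch.float64)   # well-conditioned PD matrix
    Linv = torch.cholesky(torch.inverse(Sigma), upper=False)    # chol(Σ⁻¹)
    L = torch.inverse(Linv)                                     # lower triangular
    assert torch.allclose(L.t() @ L, Sigma)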
def _triangular_inverse(A, upper=False):
    eye = torch.eye(A.size(-1), dtype=A.dtype, device=A.device)
    return eye.triangular_solve(A, upper=upper).solution


def _phi_for_cholesky_(A):
    A.tril_().diagonal(offset=0, dim1=-2, dim2=-1).mul_(0.5)
    return A


def _cholesky_backward(dout_dL, L, L_inverse):
    # c.f. https://github.com/pytorch/pytorch/blob/25ba802ce4cbdeaebcad4a03cec8502f0de9b7b3/tools/autograd/templates/Functions.cpp
    A = L.transpose(-1, -2) @ dout_dL
    phi = _phi_for_cholesky_(A)
    grad_input = (L_inverse.transpose(-1, -2) @ phi) @ L_inverse
    # Symmetrize the gradient
    return grad_input.add(grad_input.transpose(-1, -2)).mul_(0.5)
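
# Illustrative finite-difference check of `_cholesky_backward` (a sketch, not
# called anywhere; the helper name is hypothetical): for the scalar loss
# sum(G * chol(Sigma)) and a symmetric perturbation direction V, the directional
# derivative should match sum(_cholesky_backward(G, L, L^-1) * V).
def _example_cholesky_backward_check(D=4, eps=1e-6):
    A = torch.randn(D, D, dtype=torch.float64)
    Sigma = A @ A.t() + D * torch.eye(D, dtype=torch.float64)
    G = torch.randn(D, D, dtype=torch.float64).tril_()   # an arbitrary dout/dL
    V = torch.randn(D, D, dtype=torch.float64)
    V = V + V.t()                                        # symmetric direction

    L = torch.cholesky(Sigma, upper=False)
    dout_dSigma = _cholesky_backward(G, L, _triangular_inverse(L, upper=False))

    def loss(S):
        return (G * torch.cholesky(S, upper=False)).sum()

    # Central finite difference along V vs. the symmetrized gradient.
    fd = (loss(Sigma + eps * V) - loss(Sigma - eps * V)) / (2 * eps)
    assert torch.allclose((dout_dSigma * V).sum(), fd, atol=1e-6)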
class _NaturalToMuVarSqrt(torch.autograd.Function):
    @staticmethod
    def _forward(nat_mean, nat_covar):
        try:
            L_inv = torch.cholesky(-2.0 * nat_covar, upper=False)
        except RuntimeError as e:
            if str(e).startswith("cholesky"):
                raise RuntimeError(
                    "The natural covariance is not negative definite. You "
                    "probably updated it using an optimizer other than SGD "
                    "(such as Adam). This is not supported.")
            else:
                raise e
        L = _triangular_inverse(L_inv, upper=False)
        S = L.transpose(-1, -2) @ L
        mu = (S @ nat_mean.unsqueeze(-1)).squeeze(-1)
        # Two choleskys are annoying, but we don't have good support for a
        # LazyTensor of the form L.T @ L
        return mu, torch.cholesky(S, upper=False)

    @staticmethod
    def forward(ctx, nat_mean, nat_covar):
        mu, L = _NaturalToMuVarSqrt._forward(nat_mean, nat_covar)
        ctx.save_for_backward(mu, L)
        return mu, L

    @staticmethod
    def _backward(dout_dmu, dout_dL, mu, L, C):
        """Calculate dout/d(η1, η2), the gradient with respect to the
        expectation parameters
            η1 = μ
            η2 = μμᵀ + LLᵀ = μμᵀ + Σ,
        given dout/dμ and dout/dL, where L = chol(Σ) = chol(η2 - η1η1ᵀ).
        By the chain rule:
            dout/dΣ  = _cholesky_backward(dout/dL, L)
            dout/dη2 = dout/dΣ
            dout/dη1 = dout/dμ + (dout/dΣ)(dΣ/dη1) = dout/dμ - 2 (dout/dΣ) μ
        (An illustrative check of this property follows the class definition.)
        """
        dout_dSigma = _cholesky_backward(dout_dL, L, C)
        dout_deta1 = dout_dmu - 2*(dout_dSigma @ mu.unsqueeze(-1)).squeeze(-1)
        return dout_deta1, dout_dSigma

    @staticmethod
    def backward(ctx, dout_dmu, dout_dL):
        "Calculates the natural gradient with respect to nat_mean, nat_covar"
        mu, L = ctx.saved_tensors
        C = _triangular_inverse(L, upper=False)
        return _NaturalToMuVarSqrt._backward(dout_dmu, dout_dL, mu, L, C)
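
# Illustrative check of the custom backward above (a sketch mirroring
# `test_natgrad` in the test file below; the helper name is hypothetical and it
# is not called anywhere): the gradient returned for (nat_mean, nat_covar) is the
# ordinary gradient with respect to the expectation parameters
# eta1 = mu, eta2 = mu mu^T + Sigma, i.e. the natural gradient.
def _example_natgrad_check(D=3):
    mu = torch.randn(D, dtype=torch.float64)
    A = torch.randn(D, D, dtype=torch.float64)
    Sigma = A @ A.t() + D * torch.eye(D, dtype=torch.float64)
    sample = torch.randn(D, dtype=torch.float64)

    # Path 1: natural parameterization with the custom backward.
    Sigma_inv = torch.inverse(Sigma)
    nat_mean = (Sigma_inv @ mu).requires_grad_(True)
    nat_covar = (-0.5 * Sigma_inv).requires_grad_(True)
    mu_out, L_out = _NaturalToMuVarSqrt.apply(nat_mean, nat_covar)
    MultivariateNormal(mu_out, CholLazyTensor(L_out)).log_prob(sample).squeeze().backward()

    # Path 2: ordinary autograd with respect to the expectation parameters.
    eta1 = mu.clone().requires_grad_(True)
    eta2 = (mu[:, None] * mu + Sigma).requires_grad_(True)
    L = torch.cholesky(eta2 - eta1[:, None] * eta1)
    MultivariateNormal(eta1, CholLazyTensor(L)).log_prob(sample).squeeze().backward()

    assert torch.allclose(nat_mean.grad, eta1.grad)
    assert torch.allclose(nat_covar.grad, eta2.grad)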
class _TrilNaturalToMuVarSqrt(torch.autograd.Function):
    @staticmethod
    def _forward(nat_mean, tril_nat_covar):
        L = _triangular_inverse(tril_nat_covar, upper=False)
        mu = L @ (L.transpose(-1, -2) @ nat_mean.unsqueeze(-1))
        return mu.squeeze(-1), L
        # return nat_mean, L

    @staticmethod
    def forward(ctx, nat_mean, tril_nat_covar):
        mu, L = _TrilNaturalToMuVarSqrt._forward(nat_mean, tril_nat_covar)
        ctx.save_for_backward(mu, L, tril_nat_covar)
        return mu, L

    @staticmethod
    def backward(ctx, dout_dmu, dout_dL):
        mu, L, C = ctx.saved_tensors
        dout_dnat1, dout_dnat2 = _NaturalToMuVarSqrt._backward(
            dout_dmu, dout_dL, mu, L, C)
        """
        Now we need to do the Jacobian-vector product for the transformation
            L = inv(chol(inv(-2 θ_cov))),    CᵀC = -2 θ_cov,
        so we do forward-mode differentiation, starting with the sensitivity
            θ̇_cov = dout_dnat2
        and ending with the sensitivity Ċ.
        If B = inv(-2 θ_cov), then
            Ḃ = d inv(-2 θ_cov)/dθ_cov ⋅ θ̇_cov = -B (-2 θ̇_cov) B.
        If L = chol(B), i.e. B = LLᵀ, then
        (https://homepages.inf.ed.ac.uk/imurray2/pub/16choldiff/choldiff.pdf):
            L̇ = L ϕ(L⁻¹ Ḃ L⁻ᵀ) = L ϕ(2 Lᵀ θ̇_cov L).
        Then C = inv(L), so
            Ċ = -C L̇ C = ϕ(-2 Lᵀ θ̇_cov L) C.
        """
        A = L.transpose(-2, -1) @ dout_dnat2 @ L
        phi = _phi_for_cholesky_(-2*A)
        dout_dtril = phi @ C
        return dout_dnat1, dout_dtril
        # NOTE: the code below is unreachable; it is kept from an earlier version.
        dL = -L @ phi
        # Sigma = L @ L.transpose(-1, -2)
        # dSigma = dL @ L.transpose(-1, -2) + L @ dL.transpose(-1, -2)
        # nat_mean = C.transpose(-1, -2) @ C @ mu
        C_mu = C @ mu.unsqueeze(-1)
        dout_dmu = (L @ (L.transpose(-1, -2) @ dout_dnat1.unsqueeze(-1))
                    + dL @ C_mu
                    + L @ (dL.transpose(-1, -2) @ (C.transpose(-1, -2) @ C_mu)))
        return dout_dmu.squeeze(-1), dout_dtril
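
This is the end of the module. The tests, the gist's second file, follow after the sketch below. The sketch is a minimal usage illustration: it mirrors the SVGP setup of `test_optimization_zero_error` and assumes the module above is saved as natural_variational_distribution.py (the name the tests import). With `use_natgrad=True` (the default), plain SGD on the natural parameters is natural-gradient descent on the variational distribution.

# usage_sketch.py -- illustrative only; mirrors the setup used in the tests below.
import torch
import gpytorch
from gpytorch.distributions import MultivariateNormal
from natural_variational_distribution import NaturalVariationalDistribution

class SVGP(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        v_dist = NaturalVariationalDistribution(inducing_points.size(0))
        v_strat = gpytorch.variational.UnwhitenedVariationalStrategy(
            self, inducing_points, v_dist)
        super().__init__(v_strat)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.RBFKernel()

    def forward(self, x):
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))

X, y = torch.randn(32, 2), torch.randn(32)
model = SVGP(torch.randn(16, 2)).train()
likelihood = gpytorch.likelihoods.GaussianLikelihood().train()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=y.numel())

# Plain SGD on the natural parameters; only the variational parameters are
# optimized here, as in the test below.
optimizer = torch.optim.SGD(
    model.variational_strategy._variational_distribution.parameters(), lr=0.1)
for _ in range(100):
    optimizer.zero_grad()
    loss = -mll(model(X), y)
    loss.backward()
    optimizer.step()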
import unittest

import torch
import gpytorch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import CholLazyTensor

from natural_variational_distribution import (
    NaturalVariationalDistribution, TrilNaturalVariationalDistribution)

torch.set_default_dtype(torch.float64)
class TestNatVariational(unittest.TestCase):
    def test_invertible_init(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D).tril_()
        dist = MultivariateNormal(mu, CholLazyTensor(cov))

        v_dist = NaturalVariationalDistribution(D)
        v_dist.initialize_variational_distribution(dist)
        out_dist = v_dist()

        assert torch.allclose(out_dist.mean, dist.mean)
        assert torch.allclose(out_dist.covariance_matrix, dist.covariance_matrix)

    def test_natgrad(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D).tril_()
        dist = MultivariateNormal(mu, CholLazyTensor(cov))
        sample = dist.sample()

        v_dist = NaturalVariationalDistribution(D)
        v_dist.initialize_variational_distribution(dist)
        v_dist().log_prob(sample).squeeze().backward()

        eta1 = mu.clone().requires_grad_(True)
        eta2 = (mu[:, None]*mu + cov@cov.t()).requires_grad_(True)
        L = torch.cholesky(eta2 - eta1[:, None]*eta1)
        dist2 = MultivariateNormal(eta1, CholLazyTensor(L))
        dist2.log_prob(sample).squeeze().backward()

        assert torch.allclose(v_dist.nat_mean.grad, eta1.grad)
        assert torch.allclose(v_dist.nat_covar.grad, eta2.grad)

    def test_optimization_zero_error(self, num_inducing=16, num_data=32, D=2):
        inducing_points = torch.randn(num_inducing, D)

        class SVGP(gpytorch.models.ApproximateGP):
            def __init__(self):
                v_dist = NaturalVariationalDistribution(num_inducing)
                v_strat = gpytorch.variational.UnwhitenedVariationalStrategy(
                    self, inducing_points, v_dist)
                super().__init__(v_strat)
                self.mean_module = gpytorch.means.ZeroMean()
                self.covar_module = gpytorch.kernels.RBFKernel()

            def forward(self, x):
                return MultivariateNormal(self.mean_module(x), self.covar_module(x))

        model = SVGP().train()
        likelihood = gpytorch.likelihoods.GaussianLikelihood().train()
        mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data)
        X = torch.randn((num_data, D))
        y = torch.randn(num_data)

        def loss():
            return -mll(model(X), y)

        optimizer = torch.optim.SGD(
            model.variational_strategy._variational_distribution.parameters(),
            lr=float(num_data))

        # Ordinary gradient of the ELBO with respect to the natural parameters.
        model.variational_strategy._variational_distribution.use_natgrad_(False)
        optimizer.zero_grad()
        loss().backward()
        grad_nat_mean, grad_nat_covar = (
            model.variational_strategy._variational_distribution.nat_mean.grad.clone(),
            model.variational_strategy._variational_distribution.nat_covar.grad.clone())

        # Natural gradient: should differ from the ordinary gradient.
        model.variational_strategy._variational_distribution.use_natgrad_(True)
        optimizer.zero_grad()
        loss().backward()
        natgrad_nat_mean, natgrad_nat_covar = (
            model.variational_strategy._variational_distribution.nat_mean.grad.clone(),
            model.variational_strategy._variational_distribution.nat_covar.grad.clone())
        assert not torch.allclose(grad_nat_mean, natgrad_nat_mean)
        assert not torch.allclose(grad_nat_covar, natgrad_nat_covar)

        optimizer.step()  # Now we should be at the optimum

        model.variational_strategy._variational_distribution.use_natgrad_(True)
        optimizer.zero_grad()
        loss().backward()
        natgrad_nat_mean2, natgrad_nat_covar2 = (
            model.variational_strategy._variational_distribution.nat_mean.grad.clone(),
            model.variational_strategy._variational_distribution.nat_covar.grad.clone())
        assert torch.allclose(natgrad_nat_mean2, torch.zeros(()))
        assert torch.allclose(natgrad_nat_covar2, torch.zeros(()))

        model.variational_strategy._variational_distribution.use_natgrad_(False)
        optimizer.zero_grad()
        loss().backward()
        grad_nat_mean, grad_nat_covar = (
            model.variational_strategy._variational_distribution.nat_mean.grad.clone(),
            model.variational_strategy._variational_distribution.nat_covar.grad.clone())
        assert torch.allclose(grad_nat_mean, torch.zeros(()))
        assert torch.allclose(grad_nat_covar, torch.zeros(()))
class TestTrilNatVariational(unittest.TestCase):
    def test_invertible_init(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D).tril_()
        dist = MultivariateNormal(mu, CholLazyTensor(cov))

        v_dist = TrilNaturalVariationalDistribution(D)
        v_dist.initialize_variational_distribution(dist)
        out_dist = v_dist()

        assert torch.allclose(out_dist.mean, dist.mean)
        assert torch.allclose(out_dist.covariance_matrix, dist.covariance_matrix)

    def test_nat_jvp(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D)
        cov = cov @ cov.t()
        dist = MultivariateNormal(mu, CholLazyTensor(cov.cholesky()))
        sample = dist.sample()

        v_dist = TrilNaturalVariationalDistribution(D)
        v_dist.initialize_variational_distribution(dist)
        v_dist().log_prob(sample).squeeze().backward()
        dout_dnat1 = v_dist.nat_mean.grad
        dout_dnat2 = v_dist.tril_nat_covar.grad

        v_dist_ref = NaturalVariationalDistribution(D)
        v_dist_ref.initialize_variational_distribution(dist)
        v_dist_ref().log_prob(sample).squeeze().backward()
        dout_dnat1_noforward_ref = v_dist_ref.nat_mean.grad
        dout_dnat2_noforward_ref = v_dist_ref.nat_covar.grad

        # Use jax for forward-mode AD and JVPs.
        import jax.numpy as np
        from jax import jvp
        import os
        assert os.environ['JAX_ENABLE_X64'] != ""

        def f(nat_mean, nat_covar):
            "Transform nat_covar to L"
            Sigma = np.linalg.inv(-2*nat_covar)
            mu = nat_mean
            return mu, np.tril(np.linalg.inv(np.linalg.cholesky(Sigma)))

        (np_mu, np_tril_nat_covar), (np_dout_dmu_ref, np_dout_dnat2_ref) = jvp(
            f,
            (np.asarray(v_dist_ref.nat_mean.detach()), np.asarray(v_dist_ref.nat_covar.detach())),
            (np.asarray(dout_dnat1_noforward_ref), np.asarray(dout_dnat2_noforward_ref)))

        assert np.allclose(
            np_tril_nat_covar, v_dist.tril_nat_covar.detach().numpy()), "Sigma transformation"
        assert np.allclose(np_dout_dnat2_ref, dout_dnat2.numpy()), "Sigma gradient"
        assert np.allclose(np_mu, v_dist.nat_mean.detach().numpy()), "mu transformation"
        assert np.allclose(np_dout_dmu_ref, dout_dnat1.numpy()), "mu gradient"