@xmodar, last active April 7, 2021
from functools import partial

import torch
from torch import nn
from torch.autograd import grad
from torch.autograd.functional import jacobian, jvp


def rand_cov(vector):
    """Create a covariance matrix from the specs of a given vector."""
    batch_size = vector.shape[0]
    dimensions = vector.numel() // batch_size
    matrix = vector.new_empty(batch_size, dimensions, dimensions)
    return matrix.normal_() @ matrix.transpose(-1, -2)


def reduce_rank(matrix):
    """Randomly reduce the rank of a matrix."""
    u, s, v = matrix.svd()  # pylint: disable=invalid-name
    chop = (..., slice(None, torch.randint(1, s.shape[-1] + 1, ())))
    return u[chop] @ s[chop].diag_embed() @ v[chop].transpose(-1, -2)


@torch.enable_grad()
def jvp_trick(function, point):
    """Get a jvp callable with default `create_graph=False`."""
    point = point.view_as(point) if point.requires_grad else point.detach()
    forward = function(point.requires_grad_())
    zeros = torch.zeros_like(forward, requires_grad=True)
    outputs = grad(forward, point, zeros, create_graph=True)[0]
    return partial(grad, outputs, zeros)
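
# How the double-backward trick above works: `grad(forward, point, zeros)` builds
# the vector-Jacobian product `zeros @ J` as a differentiable function of `zeros`;
# differentiating it again with respect to `zeros` using cotangent `v` yields the
# Jacobian-vector product `J @ v` without materializing `J`. A minimal sanity
# check, written as an illustrative doctest-style comment (not part of the
# original gist):
#
#     >>> x = torch.randn(4, dtype=torch.double)
#     >>> square = lambda t: t * t          # Jacobian is diag(2 * x)
#     >>> v = torch.randn_like(x)
#     >>> torch.allclose(jvp_trick(square, x)(v)[0], 2 * x * v)
#     True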


def decompose_psd(psd_matrix):
    """Decompose a positive semi-definite matrix `A = B @ B.T`."""
    eigenvalues, eigenvectors = psd_matrix.symeig(eigenvectors=True)
    # remove the unnecessary common zero eigenvalues
    chop = (eigenvalues < 2 * torch.finfo(psd_matrix.dtype).eps).sum(-1).min()
    if chop > 0:
        eigenvalues = eigenvalues[..., chop:]
        eigenvectors = eigenvectors[..., chop:]
    return eigenvectors * eigenvalues.relu().sqrt_().unsqueeze(-2)
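
# Illustrative sanity check for `decompose_psd` (comment only, not executed):
#
#     >>> s = rand_cov(torch.randn(2, 5).double())   # batch of two 5x5 PSD matrices
#     >>> factor = decompose_psd(s)
#     >>> torch.allclose(factor @ factor.transpose(-1, -2), s)
#     True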


def diag_quadratic(b_matrix, func_matrix, point=None):
    """Compute `diag(A @ B @ A.T)` for a positive semi-definite matrix B."""
    d = decompose_psd(b_matrix)
    if point is None:  # func_matrix is a matrix not a function at point
        return (func_matrix @ d).square().sum(-1)
    # jvp_square = lambda x: jvp(func_matrix, point, x)[1].square()
    _jvp = jvp_trick(func_matrix, point)

    def jvp_square(v):
        return _jvp(v, create_graph=True)[0].square()

    # return sum(jvp_square(v.squeeze(-1)) for v in d.split(1, -1))
    ret = 0
    for v in d.split(1, -1):
        ret += jvp_square(v.squeeze(-1))
    return ret
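
# Why `diag_quadratic` works: writing B = D @ D.T (columns d_i from `decompose_psd`),
# each entry of diag(A @ B @ A.T) equals sum_i (A @ d_i)^2, so the diagonal is
# accumulated from one JVP per column of D without ever forming A explicitly.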


def expensive_diag_quadratic(matrix, function, point):
    """Compute `diag(A @ M @ A.T)` where A is the jacobian at the point."""
    index = (range(point.shape[0]),) + (slice(None),) * (point.dim() - 1)
    jac = jacobian(function, point)[index * 2]
    return (jac @ matrix @ jac.transpose(-1, -2)).diagonal(0, -1, -2)
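
# `expensive_diag_quadratic` is the brute-force baseline: it materializes the full
# Jacobian via `torch.autograd.functional.jacobian` (a batch x out x batch x in
# tensor before the diagonal batch indexing), so it only serves to validate the
# cheaper JVP-based versions on small problems.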


def my_diag_quadratic(function, inputs, covariance, upper=True):
    """Compute `diag(A @ S @ A.T)` where A is the jacobian at the point."""
    eigenvalues, eigenvectors = covariance.symeig(eigenvectors=True, upper=upper)
    sqrt_cov = eigenvectors[..., -1:] * eigenvalues[..., -1:].sqrt_().unsqueeze(-2)
    # inputs = inputs.view_as(inputs) if inputs.requires_grad else inputs.detach()
    # forward = function(inputs.requires_grad_())
    # zeros = torch.zeros_like(forward, requires_grad=True)
    # outputs = grad(forward, inputs, zeros, create_graph=True)[0]
    # return grad(outputs, zeros, sqrt_cov.squeeze(-1), create_graph=True)[0].square()
    return jvp(function, inputs, sqrt_cov.squeeze(-1))[1].square()
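
# Note: `my_diag_quadratic` keeps only the largest eigenpair of the covariance
# (a rank-1 square root), so it approximates diag(A @ S @ A.T) and is exact only
# when the covariance itself has rank 1. The commented-out lines above are the
# equivalent double-backward formulation of the same single JVP.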


if __name__ == "__main__":
    net = nn.Sequential(
        nn.Conv2d(1, 3, 3),
        nn.ReLU(inplace=True),
        nn.Conv2d(3, 3, 3),
        nn.ReLU(inplace=True),
        nn.Conv2d(3, 3, 3),
    ).double()
    a, b = net[:2], net[2:]
    mean = a(torch.randn(3, 1, 7, 7).double())
    covariance = rand_cov(mean)
    covariance = reduce_rank(covariance)

    def f(x):
        return b(x.view(mean.shape)).flatten(1)

    out = expensive_diag_quadratic(covariance, f, mean.flatten(1))
    print(torch.allclose(out, diag_quadratic(covariance, f, mean.flatten(1))))
    if hasattr(f, 'weight'):
        # pylint: disable=no-member
        print(torch.allclose(out, diag_quadratic(covariance, f.weight)))
    our = my_diag_quadratic(f, mean.flatten(1), covariance)
    print('relative error =', ((our - out).abs() / out).mean().item())
    ##########################
    A = torch.randn(3, 10)  # .double()
    b = torch.randn_like(A[:, 0])

    def f(x):
        return x @ A.t() + b

    def jvp_fd(model, point, vectors, eps=1e-4):
        """Compute the jacobian vector product at point using finite differences."""
        h = eps / vectors.norm(dim=-1, keepdim=True)
        return (model(point + h * vectors) - model(point)) / h

    x = torch.randn_like(A[0])
    v = torch.randn_like(x)
    print(jvp_fd(f, x, v))
    print(f(v) - b)
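
    # For this affine map f(x) = x @ A.t() + b the Jacobian is exactly A, so the
    # finite-difference JVP and the analytic value A @ v = f(v) - b printed above
    # should agree up to the O(eps) truncation error of the forward difference.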


##########################
import torch


@torch.jit.script
def decompose_psd(psd_matrix, eps: float = 1e-6, upper: bool = True):
    """Decompose a positive semi-definite matrix `A = B @ B.T`."""
    # eps = 2 * torch.finfo(psd_matrix.dtype).eps
    val, vec = psd_matrix.symeig(eigenvectors=True, upper=upper)
    chop = (val < eps).sum(-1).min()
    return vec[..., chop:] * val[..., chop:].relu().sqrt_().unsqueeze(-2)


def jvp_fd(model, point, vectors, eps=1e-4, forward=None):
    """Compute the jacobian vector product at point using finite differences."""
    if forward is None:
        forward = model(point)
    delta = eps / vectors.norm(p=2, dim=-1, keepdim=True).clamp_min(eps)
    return (model(point + delta * vectors) - forward) / delta
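
# A central-difference variant is sketched below for reference; it improves the
# truncation error from O(eps) to O(eps**2) at the cost of a second forward pass.
# The helper name `jvp_fd_central` is an illustrative addition, not part of the
# original gist, and it is not used by `diag_quadratic` below.
def jvp_fd_central(model, point, vectors, eps=1e-4):
    """Central-difference JVP: (f(x + h * v) - f(x - h * v)) / (2 * h)."""
    delta = eps / vectors.norm(p=2, dim=-1, keepdim=True).clamp_min(eps)
    return (model(point + delta * vectors) - model(point - delta * vectors)) / (2 * delta)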


def diag_quadratic(model, inputs, covariance, eps=1e-4, upper=True):
    """Compute `diag(A @ covariance @ A.T)` where A is the jacobian at inputs."""
    # equivalent to (A @ decompose_psd(covariance)).square().sum(-1)
    ret = 0
    training = model.training
    model.eval()
    forward = model(inputs)
    for column in decompose_psd(covariance, eps / 100, upper).split(1, -1):
        ret += jvp_fd(model, inputs, column.squeeze(-1), eps, forward).square()
    model.train(training)
    return ret


if __name__ == "__main__":
    net = torch.nn.Linear(10, 3)
    x = torch.randn_like(net.weight[0])
    cov = torch.randn(net.in_features, net.in_features)
    cov = cov @ cov.t()
    print(diag_quadratic(net, x, cov))
    print((net.weight @ cov @ net.weight.t()).diag())
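
    # The two printed vectors should match closely: the model is linear, so its
    # Jacobian is exactly `net.weight`, and diag(W @ cov @ W.T) is what both the
    # finite-difference routine and the closed form compute.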