from functools import partial

import torch
from torch import nn
from torch.autograd import grad
from torch.autograd.functional import jacobian, jvp


def rand_cov(vector):
    """Create a random covariance matrix matching the batch size and dimensionality of `vector`."""
    batch_size = vector.shape[0]
    dimensions = vector.numel() // batch_size
    matrix = vector.new_empty(batch_size, dimensions, dimensions)
    return matrix.normal_() @ matrix.transpose(-1, -2)


def reduce_rank(matrix):
    """Randomly reduce the rank of a matrix."""
    u, s, v = matrix.svd()  # pylint: disable=invalid-name
    chop = (..., slice(None, torch.randint(1, s.shape[-1] + 1, ())))
    return u[chop] @ s[chop].diag_embed() @ v[chop].transpose(-1, -2)
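

# Added usage sketch (not in the original gist; call manually to verify):
# `rand_cov` builds one random PSD matrix per batch item and `reduce_rank`
# keeps a random number of its leading singular directions.
def _demo_random_covariance():
    cov = reduce_rank(rand_cov(torch.empty(2, 5).double()))
    assert torch.allclose(cov, cov.transpose(-1, -2))  # still symmetric
    assert (cov.symeig().eigenvalues > -1e-8).all()  # still PSD up to rounding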


@torch.enable_grad()
def jvp_trick(function, point):
    """Get a jvp callable with default `create_graph=False`.

    Uses the double-backward trick: a VJP with a dummy cotangent `zeros` is
    itself differentiated with respect to `zeros`, which yields a JVP.
    """
    point = point.view_as(point) if point.requires_grad else point.detach()
    forward = function(point.requires_grad_())
    zeros = torch.zeros_like(forward, requires_grad=True)
    outputs = grad(forward, point, zeros, create_graph=True)[0]
    return partial(grad, outputs, zeros)
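

# Added sanity check (a sketch, not in the original gist): the JVP produced by
# `jvp_trick` should agree with an explicit Jacobian-vector product.
def _check_jvp_trick():
    def toy(x):  # hypothetical smooth test function
        return (x * x).sum(-1, keepdim=True)

    x, v = torch.randn(5).double(), torch.randn(5).double()
    assert torch.allclose(jvp_trick(toy, x)(v)[0], jacobian(toy, x) @ v)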


def decompose_psd(psd_matrix):
    """Decompose a positive semi-definite matrix `A = B @ B.T`."""
    eigenvalues, eigenvectors = psd_matrix.symeig(eigenvectors=True)
    # drop the near-zero eigenvalues shared by every matrix in the batch
    chop = (eigenvalues < 2 * torch.finfo(psd_matrix.dtype).eps).sum(-1).min()
    if chop > 0:
        eigenvalues = eigenvalues[..., chop:]
        eigenvectors = eigenvectors[..., chop:]
    return eigenvectors * eigenvalues.relu().sqrt_().unsqueeze(-2)
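

# Added sanity check (a sketch): the factor from `decompose_psd` must rebuild
# the original PSD matrix even after its rank has been reduced.
def _check_decompose_psd():
    a = reduce_rank(rand_cov(torch.empty(4, 6).double()))
    b = decompose_psd(a)
    assert torch.allclose(b @ b.transpose(-1, -2), a)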


def diag_quadratic(b_matrix, func_matrix, point=None):
    """Compute `diag(A @ B @ A.T)` for a positive semi-definite matrix `B`.

    `A` is `func_matrix` itself, or the Jacobian of `func_matrix` at `point`
    when `point` is given.
    """
    d = decompose_psd(b_matrix)
    if point is None:  # func_matrix is a plain matrix, not a function
        return (func_matrix @ d).square().sum(-1)
    # jvp_square = lambda x: jvp(func_matrix, point, x)[1].square()
    _jvp = jvp_trick(func_matrix, point)

    def jvp_square(v):
        return _jvp(v, create_graph=True)[0].square()

    # return sum(jvp_square(v.squeeze(-1)) for v in d.split(1, -1))
    ret = 0
    for v in d.split(1, -1):
        ret += jvp_square(v.squeeze(-1))
    return ret
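

# Added illustration (a sketch): with `B = D @ D.T` from `decompose_psd`,
# `diag(A @ B @ A.T)` equals the row-wise squared norms of `A @ D`, which is
# exactly what the loop of squared JVPs above accumulates, column by column.
def _check_diag_quadratic():
    a = torch.randn(4, 6).double()  # explicit matrix standing in for a Jacobian
    b = rand_cov(torch.empty(1, 6).double())[0]
    expected = (a @ b @ a.t()).diagonal()
    assert torch.allclose(diag_quadratic(b, a), expected)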


def expensive_diag_quadratic(matrix, function, point):
    """Compute `diag(A @ M @ A.T)` where `A` is the Jacobian at `point`."""
    # `jacobian` returns all cross-sample blocks; indexing with `index * 2`
    # keeps only the per-sample (block-diagonal) Jacobians
    index = (range(point.shape[0]),) + (slice(None),) * (point.dim() - 1)
    jac = jacobian(function, point)[index * 2]
    return (jac @ matrix @ jac.transpose(-1, -2)).diagonal(0, -1, -2)


def my_diag_quadratic(function, inputs, covariance, upper=True):
    """Approximate `diag(A @ S @ A.T)` with the top eigenpair of `S`.

    `A` is the Jacobian of `function` at `inputs`; keeping only the largest
    eigenvalue of the covariance `S` makes this a rank-1 approximation.
    """
    # symeig returns eigenvalues in ascending order, so [..., -1:] is the top pair
    eigenvalues, eigenvectors = covariance.symeig(eigenvectors=True, upper=upper)
    sqrt_cov = eigenvectors[..., -1:] * eigenvalues[..., -1:].sqrt_().unsqueeze(-2)
    # inputs = inputs.view_as(inputs) if inputs.requires_grad else inputs.detach()
    # forward = function(inputs.requires_grad_())
    # zeros = torch.zeros_like(forward, requires_grad=True)
    # outputs = grad(forward, inputs, zeros, create_graph=True)[0]
    # return grad(outputs, zeros, sqrt_cov.squeeze(-1), create_graph=True)[0].square()
    return jvp(function, inputs, sqrt_cov.squeeze(-1))[1].square()
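

# Added sketch: for a rank-one covariance the top-eigenpair truncation above
# loses nothing, so `my_diag_quadratic` should match the exact quadratic form.
def _check_my_diag_quadratic():
    a = torch.randn(3, 6).double()
    u = torch.randn(6, 1).double()
    rank1 = u @ u.t()  # rank-one PSD covariance
    exact = (a @ rank1 @ a.t()).diagonal()
    assert torch.allclose(my_diag_quadratic(lambda z: z @ a.t(), u.squeeze(-1), rank1), exact)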


if __name__ == "__main__":
    net = nn.Sequential(
        nn.Conv2d(1, 3, 3),
        nn.ReLU(inplace=True),
        nn.Conv2d(3, 3, 3),
        nn.ReLU(inplace=True),
        nn.Conv2d(3, 3, 3),
    ).double()
    a, b = net[:2], net[2:]
    mean = a(torch.randn(3, 1, 7, 7).double())
    covariance = rand_cov(mean)
    covariance = reduce_rank(covariance)

    def f(x):
        return b(x.view(mean.shape)).flatten(1)

    out = expensive_diag_quadratic(covariance, f, mean.flatten(1))
    print(torch.allclose(out, diag_quadratic(covariance, f, mean.flatten(1))))
    if hasattr(f, 'weight'):  # only applies when f is a linear module
        # pylint: disable=no-member
        print(torch.allclose(out, diag_quadratic(covariance, f.weight)))
    our = my_diag_quadratic(f, mean.flatten(1), covariance)
    print('relative error =', ((our - out).abs() / out).mean().item())

    ##########################
    A = torch.randn(3, 10)  # .double()
    b = torch.randn_like(A[:, 0])

    def f(x):
        return x @ A.t() + b

    def jvp_fd(model, point, vectors, eps=1e-4):
        """Compute the Jacobian-vector product at `point` using finite differences."""
        h = eps / vectors.norm(dim=-1, keepdim=True)
        return (model(point + h * vectors) - model(point)) / h

    x = torch.randn_like(A[0])
    v = torch.randn_like(x)
    print(jvp_fd(f, x, v))
    print(f(v) - b)  # exact JVP of an affine map: A @ v


##############################################################################
# Second file of the gist: a standalone, TorchScript-friendly variant that
# uses finite differences instead of autograd for the same computation.
##############################################################################

import torch


@torch.jit.script
def decompose_psd(psd_matrix, eps: float = 1e-6, upper: bool = True):
    """Decompose a positive semi-definite matrix `A = B @ B.T`."""
    # eps = 2 * torch.finfo(psd_matrix.dtype).eps
    val, vec = psd_matrix.symeig(eigenvectors=True, upper=upper)
    # drop the near-zero eigenvalues shared by every matrix in the batch
    chop = (val < eps).sum(-1).min()
    return vec[..., chop:] * val[..., chop:].relu().sqrt_().unsqueeze(-2)


def jvp_fd(model, point, vectors, eps=1e-4, forward=None):
    """Compute the Jacobian-vector product at `point` using finite differences."""
    if forward is None:
        forward = model(point)
    delta = eps / vectors.norm(p=2, dim=-1, keepdim=True).clamp_min(eps)
    return (model(point + delta * vectors) - forward) / delta
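

# Added sketch: for a smooth model the finite-difference JVP should agree with
# autograd's JVP (assuming torch.autograd.functional.jvp is available).
def _check_jvp_fd():
    from torch.autograd.functional import jvp as autograd_jvp

    model = torch.nn.Linear(4, 2).double()
    x, v = torch.randn(4).double(), torch.randn(4).double()
    assert torch.allclose(jvp_fd(model, x, v), autograd_jvp(model, x, v)[1], atol=1e-6)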


def diag_quadratic(model, inputs, covariance, eps=1e-4, upper=True):
    """Compute `diag(A @ covariance @ A.T)` where `A` is the Jacobian at `inputs`."""
    # equivalent to (A @ decompose_psd(covariance)).norm(p=2, dim=-1) ** 2
    ret = 0
    training = model.training
    model.eval()  # freeze dropout/batch-norm so every finite difference sees the same model
    forward = model(inputs)
    for column in decompose_psd(covariance, eps / 100, upper).split(1, -1):
        ret += jvp_fd(model, inputs, column.squeeze(-1), eps, forward).square()
    model.train(training)
    return ret


if __name__ == "__main__":
    net = torch.nn.Linear(10, 3)
    x = torch.randn_like(net.weight[0])
    cov = torch.randn(net.in_features, net.in_features)
    cov = cov @ cov.t()
    print(diag_quadratic(net, x, cov))
    print((net.weight @ cov @ net.weight.t()).diag())
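    # Added check (a sketch): the two prints above should agree up to
    # finite-difference rounding error for this linear model.
    error = diag_quadratic(net, x, cov) - (net.weight @ cov @ net.weight.t()).diag()
    print('max abs error =', error.abs().max().item())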