Computation of the empirical NTK (neural tangent kernel) matrix for a PyTorch model with Linear and Conv2d layers
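The empirical NTK entry for two inputs is the inner product of the network's parameter Jacobians for a single scalar output, K[i, j] = J(x_i) . J(x_j) with J(x) = df(x)/dtheta taken over all trainable parameters. The code below accumulates this kernel layer by layer through forward/backward hooks (a closed-form factorisation for Linear layers, explicit per-sample weight gradients for Conv2d layers) and checks the result against a naive per-sample autograd implementation at the end.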
#### code from torch.nn.grad, modified so that gradients are not summed over the batch dimension ####
try:
    from torch._six import container_abcs  # older PyTorch versions
except ImportError:
    import collections.abc as container_abcs  # torch._six was removed in newer PyTorch releases
from itertools import repeat


def _ntuple(n):
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse

_pair = _ntuple(2)


def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""
    Computes the per-sample gradient of conv2d with respect to the weight of the convolution
    (the batch dimension is kept instead of being summed over).

    Args:
        input: input tensor of shape (minibatch x in_channels x iH x iW)
        weight_size: shape of the weight gradient tensor
        grad_output: output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::
        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv2d_weight(input, weight.shape, grad_output)
    """
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    in_channels = input.shape[1]
    out_channels = grad_output.shape[1]
    min_batch = input.shape[0]

    # One grouped convolution computes all per-sample weight gradients at once:
    # each (sample, input channel) pair becomes its own group.
    grad_output = grad_output.contiguous().repeat(1, in_channels // groups, 1, 1)
    grad_output = grad_output.contiguous().view(
        grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2],
        grad_output.shape[3])

    input = input.contiguous().view(1, input.shape[0] * input.shape[1],
                                    input.shape[2], input.shape[3])

    grad_weight = torch.conv2d(input, grad_output, None, dilation, padding,
                               stride, in_channels * min_batch)

    grad_weight = grad_weight.contiguous().view(
        min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2],
        grad_weight.shape[3])

    # Unlike torch.nn.grad.conv2d_weight, keep the batch dimension instead of summing over it.
    return grad_weight.view(min_batch,
                            in_channels // groups, out_channels,
                            grad_weight.shape[2], grad_weight.shape[3]).transpose(1, 2).narrow(
                                3, 0, weight_size[2]).narrow(4, 0, weight_size[3])
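
# --- Illustrative sanity check (added; not part of the original torch.nn.grad code) ---
# Summing the per-sample gradients over the batch dimension should reproduce the
# batch-summed gradient computed by torch.nn.grad.conv2d_weight.
import torch
import torch.nn.grad

_x = torch.randn(4, 3, 8, 8)   # small random input batch
_go = torch.randn(4, 5, 6, 6)  # gradient w.r.t. a 5-channel conv output (kernel 3, stride 1, no padding)
_per_sample = conv2d_weight(_x, (5, 3, 3, 3), _go)            # shape (4, 5, 3, 3, 3)
_summed = torch.nn.grad.conv2d_weight(_x, (5, 3, 3, 3), _go)  # shape (5, 3, 3, 3)
print('conv2d_weight per-sample check, max abs diff:',
      (_per_sample.sum(dim=0) - _summed).abs().max().item())  # expected to be ~0 (float32 round-off)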
###############


def forward_postprocess(module, input, output):
    data_input = input[0].detach()

    def backward_hook_linear(grad):
        # kernel += (x_i . x_j + bias term) * (g_i . g_j), accumulated elementwise
        grad = grad.view(grad.shape[0], -1)
        gg = grad @ grad.t()
        module.kernel.addcmul_(module.A, gg)
        del module.A

    def backward_hook_conv(grad):
        # weights kernel: per-sample weight Jacobians computed explicitly, then J J^T
        with torch.no_grad():
            J = conv2d_weight(data_input, module.weight.shape, grad, stride=module.stride,
                              padding=module.padding, dilation=module.dilation, groups=module.groups)
            J = J.contiguous().view(J.shape[0], -1)
            module.kernel.add_(J @ J.t())
            # biases kernel: the per-sample bias gradient is grad summed over the spatial dims
            if module.bias is not None and module.bias.requires_grad:
                J = grad.sum(dim=[-2, -1])
                module.kernel.add_(J @ J.t())

    if output.requires_grad and isinstance(module, nn.Linear):
        h = output.register_hook(backward_hook_linear)
        data_input = data_input.view(data_input.shape[0], -1)
        add = 1. if module.bias is not None and module.bias.requires_grad else 0.
        setattr(module, 'A', data_input @ data_input.t() + add)
    elif output.requires_grad and isinstance(module, nn.Conv2d):
        h = output.register_hook(backward_hook_conv)
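
# --- Illustrative check of the Linear-layer factorisation used above (added; not in the original gist) ---
# For y = W x + b the per-sample Jacobian w.r.t. (W, b) is (g x^T, g) with g = dL/dy,
# so the kernel block contributed by the layer is (x_i . x_j + 1) * (g_i . g_j),
# which is exactly what backward_hook_linear accumulates via module.A and gg.
import torch

_x = torch.randn(3, 4)  # 3 samples, 4 input features
_g = torch.randn(3, 2)  # gradients w.r.t. the 2 layer outputs
_J = torch.cat([torch.einsum('no,ni->noi', _g, _x).reshape(3, -1), _g], dim=1)
print('linear factorisation check:',
      torch.allclose(_J @ _J.t(), (_x @ _x.t() + 1.) * (_g @ _g.t()), atol=1e-5))  # True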
def compute_train_kernel(model, x, output):
    n = x.shape[0]
    kernel = torch.zeros(n, n, device=x.device)
    handles = []
    # Share one kernel buffer across all leaf modules with trainable parameters;
    # each layer's hook adds its own contribution to it during the backward pass.
    for module in model.modules():
        if len(list(module.children())) > 0:
            continue
        params = [p for p in module.parameters() if p.requires_grad]
        if len(params) == 0:
            continue
        setattr(module, 'kernel', kernel)
        handles.append(module.register_forward_hook(forward_postprocess))
    model.zero_grad()
    f = model(x)
    f[:, output].sum().backward()
    # clean up the attributes and hooks attached above
    for module in model.modules():
        if hasattr(module, 'kernel'):
            del module.kernel
    for handle in handles:
        handle.remove()
    return kernel
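
# --- Usage note (added) ---
# compute_train_kernel computes the n x n kernel for one scalar output head (f[:, output]);
# for a model with several outputs one would call it once per output index,
# e.g. (sketch, 'num_outputs' is a placeholder for the model's output dimension):
#
# kernels = [compute_train_kernel(model, x, output=i) for i in range(num_outputs)]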
import torch
import torch.nn as nn
import torch.nn.functional as F

num_base_chan = 32


class VGGish(nn.Module):
    def __init__(self):
        super(VGGish, self).__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(3, num_base_chan, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(num_base_chan, num_base_chan, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(num_base_chan, num_base_chan, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(num_base_chan, num_base_chan * 2, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(num_base_chan * 2, num_base_chan * 2, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(num_base_chan * 2, num_base_chan * 4, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(num_base_chan * 4, num_base_chan * 4, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(num_base_chan * 4, num_base_chan * 4, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(num_base_chan * 4, num_base_chan * 4, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(num_base_chan * 4, num_base_chan * 8, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(num_base_chan * 8, num_base_chan * 8, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(num_base_chan * 8, num_base_chan * 8, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(num_base_chan * 8, num_base_chan * 8, kernel_size=3, stride=1, padding=1), nn.ReLU(),
        )
        # concatenated avg- and max-pooled features: 2 * (num_base_chan * 8) = 512
        self.fc = nn.Sequential(
            nn.Linear(512, 1)
        )

    def forward(self, x):
        x = self.convs(x)
        x = torch.cat([
            F.adaptive_avg_pool2d(x, output_size=1).view(x.shape[0], -1),
            F.adaptive_max_pool2d(x, output_size=1).view(x.shape[0], -1)
        ], dim=1)
        x = self.fc(x)
        return x


d = 512
n = 80
inputs = torch.rand(n, 3, d, d).cuda()
model = VGGish().cuda()
kernel = compute_train_kernel(model, inputs, output=0)
# Naive implementation: Jacobians calculated one input at a time with autograd to ensure correctness
J = []
for x in inputs:
    model.zero_grad()
    f = model(x.unsqueeze(0))
    f.backward()
    J.append(torch.cat([p.grad.flatten() for p in model.parameters()]))
J = torch.stack(J)
kernel_naive = J @ J.t()

# compute relative error using Frobenius norms
print(((kernel - kernel_naive) ** 2).sum().sqrt() / (kernel_naive ** 2).sum().sqrt())
# should print a value < 1e-6