Skip to content

Instantly share code, notes, and snippets.

@maciejkorzepa
Last active March 23, 2024 09:16
Show Gist options
  • Save maciejkorzepa/af3b5ef3676d982f8c4298f224960c53 to your computer and use it in GitHub Desktop.
Save maciejkorzepa/af3b5ef3676d982f8c4298f224960c53 to your computer and use it in GitHub Desktop.
Computation of the empirical neural tangent kernel (NTK) for a PyTorch model with Linear and Conv2d layers
#### Code from torch.nn.grad, modified so that the weight gradients are not summed over the batch dimension ######
from torch._six import container_abcs
from itertools import repeat
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
return tuple(repeat(x, n))
return parse
_pair = _ntuple(2)
def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""
    Computes the per-sample gradient of conv2d with respect to the weight of the convolution.

    Unlike ``torch.nn.grad.conv2d_weight``, the result is NOT summed over the
    batch dimension: entry ``[b]`` is the weight gradient contributed by sample ``b``.

    Args:
        input: input tensor of shape (minibatch x in_channels x iH x iW)
        weight_size : Shape of the weight tensor; only ``weight_size[2]`` and
            ``weight_size[3]`` (kernel height/width) are read, to crop the result
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Returns:
        Tensor of shape (minibatch x out_channels x in_channels // groups x
        weight_size[2] x weight_size[3]).

    Examples::

        >>> input = torch.randn(2, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> per_sample = conv2d_weight(input, weight.shape, grad_output)
        >>> grad_weight, = torch.autograd.grad(output, weight, grad_output)
        >>> torch.allclose(per_sample.sum(0), grad_weight)
        True

    NOTE(review): the groups > 1 path is inherited unchanged from torch.nn.grad
    and is not exercised by the models in this file — verify before relying on it.
    """
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    in_channels = input.shape[1]
    out_channels = grad_output.shape[1]
    min_batch = input.shape[0]

    # Tile the output gradients along the channel dim (in_channels // groups
    # copies), then turn every gradient map into its own single-channel filter.
    grad_output = grad_output.contiguous().repeat(1, in_channels // groups, 1,
                                                  1)
    grad_output = grad_output.contiguous().view(
        grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2],
        grad_output.shape[3])

    # Fold the batch into the channel dim so a single grouped convolution
    # (groups = in_channels * min_batch below) keeps every sample separate.
    input = input.contiguous().view(1, input.shape[0] * input.shape[1],
                                    input.shape[2], input.shape[3])

    # torch.conv2d's positional order is (input, weight, bias, stride, padding,
    # dilation, groups): the forward conv's `dilation` is passed as the stride
    # here and its `stride` as the dilation.
    grad_weight = torch.conv2d(input, grad_output, None, dilation, padding,
                               stride, in_channels * min_batch)

    # Unfold the batch again: (min_batch, in_channels/groups * out_channels, kH', kW').
    grad_weight = grad_weight.contiguous().view(
        min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2],
        grad_weight.shape[3])

    # Reorder to (batch, out, in/groups, kH', kW') and crop the spatial dims to
    # the kernel size given by weight_size. No sum over the batch (the change
    # relative to torch.nn.grad).
    return grad_weight.view(min_batch,
                            in_channels // groups, out_channels,
                            grad_weight.shape[2], grad_weight.shape[3]).transpose(1, 2).narrow(
        3, 0, weight_size[2]).narrow(4, 0, weight_size[3])
###############
def _has_trainable_bias(module):
    """Return True if ``module`` has a bias tensor that requires grad.

    Layers built with ``bias=False`` still have a ``bias`` attribute, set to
    ``None``, so ``hasattr`` alone is not a sufficient check.
    """
    bias = getattr(module, 'bias', None)
    return bias is not None and bias.requires_grad


def forward_postprocess(module, input, output):
    """Forward hook that schedules this layer's contribution to the empirical NTK.

    Meant to be registered with ``module.register_forward_hook``. The caller must
    have set ``module.kernel`` to an ``(n, n)`` tensor, where ``n`` is the batch
    size. For ``nn.Linear`` and ``nn.Conv2d`` outputs that require grad, a
    tensor hook is attached to ``output``. During backward, that hook adds the
    Gram matrix ``J @ J.T`` of this layer's per-sample parameter gradients to
    ``module.kernel`` in place. All other module types are ignored.

    Args:
        module: the layer whose forward pass just ran.
        input: tuple of positional inputs to ``module``; only ``input[0]`` is used.
        output: the layer's output.
    """
    data_input = input[0].detach()

    def backward_hook_linear(grad):
        # Per-sample weight gradient is g_i x_i^T, so
        # <J_i, J_j> = (g_i . g_j) * (x_i . x_j [+ 1 for a trainable bias]).
        # The second factor was precomputed in the forward pass as module.A.
        grad = grad.reshape(grad.shape[0], -1)
        gg = grad @ grad.t()
        module.kernel.addcmul_(module.A, gg)
        del module.A

    def backward_hook_conv(grad):
        # weights kernel: explicit per-sample weight gradients, flattened
        with torch.no_grad():
            J = conv2d_weight(data_input, module.weight.shape, grad, stride=module.stride,
                              padding=module.padding, dilation=module.dilation,
                              groups=module.groups)
            J = J.contiguous().view(J.shape[0], -1)
            module.kernel.add_(J @ J.t())
        # biases kernel: per-sample bias gradient is grad summed over spatial dims
        if _has_trainable_bias(module):
            J = grad.sum(dim=[-2, -1])
            module.kernel.add_(J @ J.t())

    # Test the layer type before touching ``output``: other parametrized layers
    # may return non-tensor outputs (e.g. tuples) that have no ``requires_grad``.
    if isinstance(module, nn.Linear) and output.requires_grad:
        output.register_hook(backward_hook_linear)
        # NOTE(review): flattening to (batch, -1) is exact only for 2-D inputs;
        # inputs with extra leading dims would need a per-position sum.
        flat_input = data_input.reshape(data_input.shape[0], -1)
        add = 1. if _has_trainable_bias(module) else 0.
        module.A = flat_input @ flat_input.t() + add
    elif isinstance(module, nn.Conv2d) and output.requires_grad:
        output.register_hook(backward_hook_conv)
def compute_train_kernel(model, x, output):
    """Compute the empirical NTK Gram matrix of ``model`` on the batch ``x``.

    Returns ``K`` of shape ``(n, n)`` with ``K[i, j] = <J_i, J_j>``, where ``J_i``
    is the gradient of ``model(x)[i, output]`` w.r.t. the trainable parameters.
    It uses a single forward and backward pass over the whole batch. Each leaf
    module with trainable parameters gets the shared ``kernel`` tensor and a
    ``forward_postprocess`` forward hook, and that hook accumulates the layer's
    contribution into ``kernel``.

    Args:
        model: network to differentiate. Only leaf modules (no children) are hooked.
        x: input batch; ``n = x.shape[0]``.
        output: index of the output column to differentiate; ``model(x)`` is
            indexed as ``[:, output]``, so it is assumed to be 2-D.

    Returns:
        The ``(n, n)`` kernel, allocated on ``x.device``.

    Side effects:
        Calls ``model.zero_grad()`` and leaves parameter ``.grad`` populated by
        the backward pass. All hooks and temporary attributes are removed before
        returning, even if the forward or backward pass raises.
    """
    n = x.shape[0]
    kernel = torch.zeros(n, n, device=x.device)
    handles = []
    hooked = []
    try:
        for module in model.modules():
            # Only leaf modules own parameters directly.
            if next(module.children(), None) is not None:
                continue
            if not any(p.requires_grad for p in module.parameters()):
                continue
            # Every hooked layer accumulates into the same tensor in place.
            module.kernel = kernel
            hooked.append(module)
            handles.append(module.register_forward_hook(forward_postprocess))
        model.zero_grad()
        f = model(x)
        f[:, output].sum().backward()
    finally:
        for handle in handles:
            handle.remove()
        # Drop the shared kernel reference and any scratch tensor a hook left
        # behind (e.g. if backward never reached that layer).
        for module in hooked:
            for attr in ('kernel', 'A'):
                if hasattr(module, attr):
                    delattr(module, attr)
    return kernel
import torch
import torch.nn as nn
import torch.nn.functional as F
num_base_chan = 32


class VGGish(nn.Module):
    """Small VGG-style CNN producing a single score per input image.

    The width of every conv layer is a multiple of the module-level
    ``num_base_chan``. The classifier's input size is derived from that
    constant, so changing it keeps the model consistent.
    """

    def __init__(self):
        super().__init__()
        c = num_base_chan

        def conv_relu(in_ch, out_ch, stride=1):
            # 3x3 conv with padding 1 followed by ReLU, as a list for splatting.
            return [nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1), nn.ReLU()]

        def pool():
            return nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Layer order (and hence Sequential indices / parameter order) is kept
        # identical to the original hand-written list.
        self.convs = nn.Sequential(
            *conv_relu(3, c, stride=2), pool(),
            *conv_relu(c, c), *conv_relu(c, c), pool(),
            *conv_relu(c, c * 2), *conv_relu(c * 2, c * 2), pool(),
            *conv_relu(c * 2, c * 4), *conv_relu(c * 4, c * 4),
            *conv_relu(c * 4, c * 4), *conv_relu(c * 4, c * 4), pool(),
            *conv_relu(c * 4, c * 8), *conv_relu(c * 8, c * 8),
            *conv_relu(c * 8, c * 8), *conv_relu(c * 8, c * 8),
        )
        # forward() concatenates global avg- and max-pooled features of the last
        # conv (c * 8 channels each), hence 2 * c * 8 inputs (= 512 for c = 32).
        self.fc = nn.Sequential(
            nn.Linear(2 * c * 8, 1)
        )

    def forward(self, x):
        """Map a ``(B, 3, H, W)`` image batch to ``(B, 1)`` scores."""
        feats = self.convs(x)
        batch = feats.shape[0]
        pooled = torch.cat([
            F.adaptive_avg_pool2d(feats, output_size=1).view(batch, -1),
            F.adaptive_max_pool2d(feats, output_size=1).view(batch, -1),
        ], dim=1)
        return self.fc(pooled)
d = 512
n = 80
output_idx = 0
# Fall back to CPU so the demo also runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inputs = torch.rand(n, 3, d, d).to(device)
model = VGGish().to(device)
kernel = compute_train_kernel(model, inputs, output=output_idx)

# Naive implementation, Jacobians calculated one at a time using autograd to ensure correctness
from torch.nn.utils import parameters_to_vector
J = []
for x in inputs:
    model.zero_grad()
    f = model(x.unsqueeze(0))
    # Differentiate the same output column as compute_train_kernel above.
    f[0, output_idx].backward()
    J.append(parameters_to_vector(p.grad for p in model.parameters()))
J = torch.stack(J)
kernel_naive = J @ J.t()
# compute relative error using Frobenius norms
print(((kernel - kernel_naive) ** 2).sum().sqrt() / (kernel_naive ** 2).sum().sqrt())
# should print a value <1e-6
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment