Hessian calculator
# %%
import torch
import torchvision
from torch import nn, optim, autograd
from torchvision import transforms
import numpy
import scipy.special
import matplotlib.pyplot
import math
# %%
batch_size = 64
num_classes = 10
# %%
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    torch.flatten,
])
# %%
mnist_train = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
print(mnist_train[0])
# %%
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28*28, num_classes)

    def forward(self, x):
        return self.l1(x).softmax(dim=-1)
# %%
net = Net()
optimizer = optim.Adam(net.parameters())
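# %%
# Quick shape check (illustrative addition, not part of the original gist): the
# model should map a (batch, 784) input to (batch, 10) probabilities that sum to 1.
with torch.no_grad():
    dummy_batch = torch.randn(2, 28*28)
    probs = net(dummy_batch)
    print(probs.shape, probs.sum(dim=-1))  # -> torch.Size([2, 10]), sums ~1.0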
# %%
def get_loss_params_pred(params, pred, labels, weight_decay_coefficient):
    labels_encoded = nn.functional.one_hot(labels, num_classes).float()
    l2_loss = nn.functional.mse_loss(pred, labels_encoded)
    l2_regularization_penalty = (weight_decay_coefficient * (params ** 2)).sum()
    return l2_loss + l2_regularization_penalty

def get_loss(model, images, labels, weight_decay_coefficient=1e-5):
    pred = model(images)
    params = nn.utils.parameters_to_vector(model.parameters())
    return get_loss_params_pred(params, pred, labels, weight_decay_coefficient)
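# %%
# Loss sanity check (illustrative addition, not in the original gist): a uniform
# softmax output (0.1 per class) scores ((0.1-1)**2 + 9*0.1**2)/10 = 0.09 against
# a one-hot target under mean MSE, so the untrained net should land near that.
images, labels = next(iter(train_dataloader))
print(get_loss(net, images, labels))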
# %%
num_epochs = 3
net.train()
for epoch in range(num_epochs):
    for images, labels in train_dataloader:
        optimizer.zero_grad()
        loss = get_loss(net, images, labels)
        loss.backward()
        optimizer.step()
    print(f"Finished epoch {epoch+1}.")
# %%
def evaluate_model(net):
    net.eval()
    with torch.inference_mode():
        running_loss = 0
        for img, label in mnist_test:
            loss = get_loss(net, img, torch.tensor(label), weight_decay_coefficient=0)
            running_loss += loss
        return running_loss / len(mnist_test)

evaluate_model(net)
# -> tensor(0.0125)
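# %%
# Accuracy check (illustrative addition, not part of the original gist), to
# complement the MSE-based loss above.
def evaluate_accuracy(net):
    net.eval()
    correct = 0
    with torch.inference_mode():
        for img, label in mnist_test:
            correct += (net(img).argmax().item() == label)
    return correct / len(mnist_test)

evaluate_accuracy(net)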
# %%
# CURRENT APPROACH
def loss_from_params(images, labels, model):
    param_names = [n for n, _ in model.named_parameters()]
    def loss(*params):
        param_dict = {name: p for name, p in zip(param_names, params)}
        pred = torch.func.functional_call(model, param_dict, images)
        params_vector = nn.utils.parameters_to_vector(params)
        return get_loss_params_pred(params_vector, pred, labels, 1e-5)
    return loss
def v_n_ball(n: int) -> float:
    # Volume of the unit n-ball, in closed form; equivalent to the recursion
    # V_n = (2*pi/n) * V_{n-2} with V_0 = 1, V_1 = 2 (see the old approach below).
    return numpy.pi**(n/2) / scipy.special.gamma(n/2 + 1)
def calc_basin(model: torch.nn.Module, loss_threshold: float, n: int) -> torch.Tensor:
    params = tuple(model.parameters())
    param_count = sum(p.numel() for p in params)
    assert n <= param_count  # We'll only be looking at the Hessian over the first n parameters
    v_n = v_n_ball(n)
    numerator = v_n * (2*loss_threshold)**n
    images = torch.unsqueeze(mnist_test[0][0], 0)
    labels = torch.unsqueeze(torch.tensor(mnist_test[0][1]), 0)
    loss_fn = loss_from_params(images, labels, model)
    hessian = autograd.functional.hessian(loss_fn, params)
    # Assemble the full (param_count, param_count) matrix from the nested tuple of
    # per-parameter blocks. Each block (i, j) is reshaped to 2-D so its rows/columns
    # line up with the flattened parameter vector; a plain flatten-and-reshape would
    # scramble entries when the model has more than one parameter tensor.
    hessian = torch.cat(
        [torch.cat([block.reshape(p_i.numel(), -1) for block in row], dim=1)
         for p_i, row in zip(params, hessian)],
        dim=0)
    print("Hessian: ", hessian)
    print("det(Hessian): ", torch.det(hessian))
    subhessian = hessian[:n, :n]
    print("Sub-Hessian: ", subhessian)
    print("det(Sub-Hessian): ", torch.det(subhessian))
    σ = 1 / math.sqrt(28*28)  # Standard deviation of the initialization Gaussian
    # See https://github.com/pytorch/pytorch/blob/6408b85d88cf2d3790ca8fbf8a73201fe0d24d3e/torch/nn/modules/linear.py#LL103C1-L103C1
    k = 1  # "For a crude model, k = 1 is probably good enough"
    c = k/(σ**2)  # TODO: Not super confident about these lines
    λ = 1e-5  # weight_decay_coefficient
    total = subhessian + (c + λ)*torch.eye(n)
    det = torch.det(total)
    print("det(total): ", det)
    denominator = det**(1/2)
    print(numerator, denominator)
    v_basin = numerator / denominator
    return v_basin
calc_basin(net, 0.1, 10)  # THIS WORKS
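# %%
# Sanity check (illustrative addition, not part of the original gist): the
# closed-form v_n_ball should reproduce the known unit-ball volumes
# V_1 = 2, V_2 = pi, V_3 = 4*pi/3.
for n, expected in [(1, 2.0), (2, numpy.pi), (3, 4*numpy.pi/3)]:
    print(n, v_n_ball(n), expected)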
# %%
# old approach (kept for reference; the hessian call below was never completed,
# so this cell does not run as-is)
def v_n_ball(n: int) -> float:
    if n == 0:
        return 1
    if n == 1:
        return 2
    return 2*numpy.pi/n * v_n_ball(n-2)
    # return numpy.pi**(n/2) / scipy.special.gamma(n/2 + 1)
def calc_basin(model: torch.nn.Module, loss_threshold: float) -> torch.Tensor:
    params = model.parameters()
    n = sum(numpy.prod(p.shape) for p in params)  # parameter count
    v_n = v_n_ball(n)
    numerator = v_n * (2*loss_threshold)**n
    loss_fn = get_loss
    hessian = autograd.functional.hessian(loss_fn, )  # incomplete: the inputs argument was never filled in
    σ = 1 / math.sqrt(28*28)  # Standard deviation of the initialization Gaussian
    # See https://github.com/pytorch/pytorch/blob/6408b85d88cf2d3790ca8fbf8a73201fe0d24d3e/torch/nn/modules/linear.py#LL103C1-L103C1
    k = 1  # "For a crude model, k = 1 is probably good enough"
    c = k/(σ**2)  # TODO: Not super confident about these lines
    λ = 1e-5  # weight_decay_coefficient
    total = hessian + (λ + c)*torch.eye(n)
    det = torch.det(total)
    denominator = det**(1/2)
    v_basin = numerator / denominator
    return v_basin

calc_basin(net, 0.01)