@yaroslavvb
Created August 25, 2019 12:52
Example of computing Hessian of linear layer
import random
from typing import List

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import datasets

# NOTE (assumption): `u` is the author's util module (the assert messages below
# reference util.zero_grad); the helpers under "-- utils" further down, plus
# check_close, check_equal, zero_grad and a global `device`, are assumed to live there.
import util as u

device = torch.device('cpu')  # assumption: the original presumably defines this in util


def test():
    u.seed_random(1)

    data_width = 3
    targets_width = 2
    batch_size = 3
    dataset = TinyMNIST('/tmp', download=True, data_width=data_width, targets_width=targets_width, dataset_size=batch_size)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    d1 = data_width ** 2     # input size (flattened data image)
    d2 = targets_width ** 2  # output size (flattened target image)
    n = batch_size
    model = Net([d1, d2])
    layer = model.layers[0]
    W = model.layers[0].weight
    skip_hooks = False

    def capture_activations(module, input, _output):
        if skip_hooks:
            return
        assert not hasattr(module, 'activations'), "Seeing results of previous autograd, call util.zero_grad to clear"
        assert len(input) == 1, "this works for single input layers only"
        setattr(module, "activations", input[0].detach())

    def capture_backprops(module: nn.Module, _input, output):
        if skip_hooks:
            return
        assert not hasattr(module, 'backprops'), "Seeing results of previous autograd, call util.zero_grad to clear"
        assert len(output) == 1, "this works for single variable layers only"
        setattr(module, "backprops", output[0])

    layer.register_forward_hook(capture_activations)
    layer.register_backward_hook(capture_backprops)

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == batch_size
        return torch.sum(err * err) / 2 / len(data)

    # def unvec(x): return u.unvec(x, d)

    # Gradient
    data, targets = next(iter(trainloader))
    loss = loss_fn(model(data), targets)
    loss.backward()

    A = layer.activations.t()
    assert A.shape == (d1, n)

    # add factor of n here because backprop computes loss averaged over batch, while we need per-example loss backprop
    B = layer.backprops.t() * n
    assert B.shape == (d2, n)

    u.check_close(B @ A.t() / n, W.grad)

    # Hessian
    skip_hooks = True
    loss = loss_fn(model(data), targets)
    H = u.hessian(loss, W)
    H = H.transpose(0, 1).transpose(2, 3).reshape(d1 * d2, d1 * d2)
    print(H)
    # compute B matrices by backpropping each output unit vector through the network
    # (see the standalone sketch after this function)
    Bs_t = []  # one matrix per output coordinate, storing backprops for current layer
    skip_hooks = False
    id_mat = torch.eye(d2)
    for out_idx in range(d2):
        u.zero_grad(model)
        output = model(data)
        _loss = loss_fn(output, targets)  # not used; we backprop `output` directly
        ei = id_mat[out_idx]
        bval = torch.stack([ei] * batch_size)
        output.backward(bval)
        Bs_t.append(layer.backprops)
    A_t = layer.activations

    # batch output Jacobian, each row corresponds to an (example, output) pair
    Amat = torch.cat([A_t] * d2, dim=0)
    Bmat = torch.cat(Bs_t, dim=0)
    Jb = u.khatri_rao_t(Amat, Bmat)

    H2 = Jb.t() @ Jb / n
    u.check_close(H, H2)
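
# ---------------------------------------------------------------------------
# Standalone sketch (not part of the original gist): verifies the identity the
# test above relies on -- for MSE loss over a single bias-free linear layer,
# the Hessian w.r.t. the weights equals Jb^T Jb / n, where each row of Jb is
# the Kronecker product kron(activation_k, e_j) for an (example k, output j)
# pair.  It uses only torch.autograd.functional.hessian (assumes PyTorch >= 1.5)
# and avoids the `u` helpers; the function name and sizes are illustrative.
# ---------------------------------------------------------------------------
def _hessian_identity_sketch():
    torch.manual_seed(0)
    n, d1, d2 = 3, 4, 2
    X = torch.randn(n, d1)   # activations feeding the layer
    Y = torch.randn(n, d2)   # regression targets
    W0 = torch.randn(d2, d1)

    def loss_fn(w):
        err = X @ w.t() - Y
        return torch.sum(err * err) / 2 / n

    # Autograd Hessian, reordered to (input, output) x (input, output) indexing
    # exactly as in test() above.
    H = torch.autograd.functional.hessian(loss_fn, W0)   # shape (d2, d1, d2, d1)
    H = H.transpose(0, 1).transpose(2, 3).reshape(d1 * d2, d1 * d2)

    # Explicit per-(example, output) Jacobian rows: kron(x_k, e_j).
    eye = torch.eye(d2)
    rows = []
    for j in range(d2):
        for k in range(n):
            rows.append((X[k].unsqueeze(1) * eye[j].unsqueeze(0)).reshape(-1))
    Jb = torch.stack(rows)   # shape (n * d2, d1 * d2)
    H2 = Jb.t() @ Jb / n

    # Closed form in the same flattening: kron(X^T X / n, I_{d2}).
    XtX = X.t() @ X / n
    H3 = torch.einsum('ab,cd->acbd', XtX, eye).reshape(d1 * d2, d1 * d2)

    assert torch.allclose(H, H2, atol=1e-6)
    assert torch.allclose(H, H3, atol=1e-6)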
# -- utils --
def khatri_rao(A: torch.Tensor, B: torch.Tensor):
    """Khatri-Rao product.

    The i'th column of the result is the Kronecker product of A_i and B_i.
    Section 2.6 of Kolda, Tamara G., and Brett W. Bader. "Tensor decompositions
    and applications." SIAM Review 51.3 (2009): 455-500.
    """
    assert A.shape[1] == B.shape[1]
    # noinspection PyTypeChecker
    return torch.einsum("ik,jk->ijk", A, B).reshape(A.shape[0] * B.shape[0], A.shape[1])
def test_khatri_rao():
    A = torch.tensor([[1, 2], [3, 4]])
    B = torch.tensor([[5, 6], [7, 8]])
    C = torch.tensor([[5, 12], [7, 16],
                      [15, 24], [21, 32]])
    check_equal(khatri_rao(A, B), C)
def khatri_rao_t(A: torch.Tensor, B: torch.Tensor):
    """Transposed Khatri-Rao: inputs and outputs are transposed (rows instead of columns)."""
    assert A.shape[0] == B.shape[0]
    # noinspection PyTypeChecker
    return torch.einsum("ki,kj->kij", A, B).reshape(A.shape[0], A.shape[1] * B.shape[1])
def jacobian(y: torch.Tensor, x: torch.Tensor, create_graph=False):
    """Jacobian of y with respect to x, with shape y.shape + x.shape."""
    jac = []
    flat_y = y.reshape(-1)
    grad_y = torch.zeros_like(flat_y)
    for i in range(len(flat_y)):
        grad_y[i] = 1.
        grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
        jac.append(grad_x.reshape(x.shape))
        grad_y[i] = 0.
    return torch.stack(jac).reshape(y.shape + x.shape)
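
# Usage sketch for jacobian (not in the original gist): the Jacobian of the
# linear map y = M x is M itself.
def test_jacobian_linear():
    M = torch.randn(3, 2)
    x = torch.randn(2, requires_grad=True)
    assert torch.allclose(jacobian(M @ x, x), M, atol=1e-6)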
def hessian(y: torch.Tensor, x: torch.Tensor):
    """Hessian of y with respect to x, with shape y.shape + x.shape + x.shape."""
    return jacobian(jacobian(y, x, create_graph=True), x)
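
# Usage sketch for hessian (not in the original gist): for the quadratic
# f(x) = 0.5 * x^T M x the Hessian is the symmetrized matrix (M + M^T) / 2.
def test_hessian_quadratic():
    M = torch.randn(4, 4)
    x = torch.randn(4, requires_grad=True)
    f = 0.5 * x @ M @ x
    assert torch.allclose(hessian(f, x), (M + M.t()) / 2, atol=1e-6)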
class TinyMNIST(datasets.MNIST):
    """Dataset for autoencoder task."""
    # resulting data shape: (dataset_size, 1, data_width, data_width)

    def __init__(self, root, data_width=4, targets_width=4, dataset_size=60000, download=True):
        # pass download by keyword so it isn't consumed by MNIST's `train` argument
        super().__init__(root, download=download)

        # Put both data and targets on the device in advance
        self.data = self.data[:dataset_size, :, :]
        new_data = np.zeros((self.data.shape[0], data_width, data_width))
        new_targets = np.zeros((self.data.shape[0], targets_width, targets_width))
        for i in range(self.data.shape[0]):
            arr = self.data[i, :].numpy().astype(np.uint8)
            im = Image.fromarray(arr)
            im.thumbnail((data_width, data_width), Image.ANTIALIAS)  # Image.LANCZOS in newer Pillow
            new_data[i, :, :] = np.array(im) / 255
            im = Image.fromarray(arr)
            im.thumbnail((targets_width, targets_width), Image.ANTIALIAS)
            new_targets[i, :, :] = np.array(im) / 255
        self.data = torch.from_numpy(new_data).float()
        self.data = self.data.unsqueeze(1)
        self.targets = torch.from_numpy(new_targets).float()
        self.targets = self.targets.unsqueeze(1)
        self.data, self.targets = self.data.to(device), self.targets.to(device)
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the downsampled target image, not a class index.
        """
        img, target = self.data[index], self.targets[index]
        return img, target
class Net(nn.Module):
    def __init__(self, d: List[int], nonlin=False):
        super().__init__()
        self.layers: List[nn.Module] = []
        self.all_layers: List[nn.Module] = []
        self.d: List[int] = d
        for i in range(len(d) - 1):
            linear = nn.Linear(d[i], d[i + 1], bias=False)
            setattr(linear, 'name', f'{i:02d}-linear')
            self.layers.append(linear)
            self.all_layers.append(linear)
            if nonlin:
                self.all_layers.append(nn.ReLU())
        self.predict = torch.nn.Sequential(*self.all_layers)

    def forward(self, x: torch.Tensor):
        x = x.reshape((-1, self.d[0]))
        return self.predict(x)
def seed_random(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
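
# Usage sketch (not in the original gist): TinyMNIST yields downsampled
# (image, target) pairs and Net flattens the image before the linear layers.
# Assumes the MNIST download succeeds and that `device` is the CPU; the sizes
# and function name are arbitrary.
def _tiny_mnist_shapes_sketch():
    ds = TinyMNIST('/tmp', data_width=3, targets_width=2, dataset_size=2)
    x, y = next(iter(torch.utils.data.DataLoader(ds, batch_size=2)))
    assert x.shape == (2, 1, 3, 3) and y.shape == (2, 1, 2, 2)
    model = Net([3 * 3, 2 * 2])
    assert model(x).shape == (2, 4)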