
@UlisseMini
Last active January 9, 2023 20:37
How to obtain the Hessian of an MNIST neural net as a flat matrix
# fuller code: https://colab.research.google.com/drive/12zXLbykv537MrZr6WDCnRIqKQ5h8UjVw?usp=sharing
import torch
import torch.nn.functional as F
from torch.autograd.functional import hessian
from torch.nn.utils import stateless

# Assumes `model` is an MNIST classifier whose forward returns log-probabilities,
# and (x, y) is a batch of images and labels.
names = [n for n, _ in model.named_parameters()]
fn = lambda *params: F.nll_loss(
    stateless.functional_call(model, {n: p for n, p in zip(names, params)}, x), y)
H = hessian(fn, tuple(model.parameters()))
# H[i][j] holds the second derivatives of the loss with respect to every entry
# of parameter tensor i and parameter tensor j -- a nested tuple of tensors
# rather than one flat matrix.

# Flatten the nested tuple into a single (n_params, n_params) matrix:
# block (j, i) is reshaped to 2D, blocks are stacked down each column,
# then the columns are concatenated side by side.
shapes = [p.shape for p in model.parameters()]
cols = []
for i in range(len(H)):
    cols.append(torch.cat(
        [H[j][i].reshape(shapes[j].numel(), shapes[i].numel())
         for j in range(len(H))], dim=0))
full_hessian = torch.cat(cols, dim=1)
full_hessian.shape  # (n_params, n_params)
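The same flattening can be checked end to end on a model small enough to inspect by hand. A minimal sketch, assuming torch >= 2.0 (it uses `torch.func.functional_call` in place of the older `stateless.functional_call`); the tiny `Linear(2, 3)` model and the fake batch are stand-ins for the MNIST net:

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import hessian

torch.manual_seed(0)

# Tiny stand-in model: one linear layer (2 -> 3), so the full parameter
# Hessian is only 9 x 9 (weight has 6 entries, bias has 3).
model = torch.nn.Linear(2, 3)
x = torch.randn(4, 2)          # fake batch of inputs
y = torch.randint(0, 3, (4,))  # fake class labels

names = [n for n, _ in model.named_parameters()]

def fn(*params):
    # Run the model with the given tensors substituted for its parameters.
    out = torch.func.functional_call(model, dict(zip(names, params)), x)
    return F.nll_loss(F.log_softmax(out, dim=1), y)

# Nested tuple: H[i][j] has shape shapes[i] + shapes[j].
H = hessian(fn, tuple(model.parameters()))

# Flatten exactly as above: reshape each block to 2D, stack, concatenate.
shapes = [p.shape for p in model.parameters()]
cols = []
for i in range(len(H)):
    cols.append(torch.cat(
        [H[j][i].reshape(shapes[j].numel(), shapes[i].numel())
         for j in range(len(H))], dim=0))
full_hessian = torch.cat(cols, dim=1)

print(full_hessian.shape)  # (9, 9): one row/column per scalar parameter
```

A quick sanity check on the result is symmetry: the Hessian of a scalar loss is symmetric, so `torch.allclose(full_hessian, full_hessian.T, atol=1e-5)` should hold if the blocks were placed correctly.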