@githubharald
The core of AdaHessian analyzed on a 2D function: compute the gradient and a diagonal Hessian approximation (see compute_grad_and_hessian(...)) and apply the update step.
"""
This script allows analyzing the update step taken by AdaHessian.
Both the gradient g and the diagonal Hessian approximation H for a 2D loss function is computed.
Then, the update step in direction -g/H is applied.
Momentum is not implemented in this script.
As a comparison also a gradient descent update step is executed.
"""
from typing import Callable, Tuple, List, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
def plot_function(f: Callable, window: List[float]) -> None:
    """Plot a 2D function by executing the PyTorch function for each grid point."""
    num_vals = 50
    x_vals = np.linspace(window[0], window[1], num_vals)
    y_vals = np.linspace(window[2], window[3], num_vals)
    X, Y = np.meshgrid(x_vals, y_vals)
    Z = np.empty([num_vals, num_vals], dtype=np.float64)
    for i, x in enumerate(x_vals):
        for j, y in enumerate(y_vals):
            v = torch.tensor((x, y), dtype=torch.float64)
            Z[j, i] = f(v)
    plt.pcolormesh(X, Y, Z, cmap='rainbow', shading='auto')
def plot_path(path: torch.Tensor) -> None:
    """Plot the path taken by the optimizer."""
    plt.plot(path[:, 0], path[:, 1], 'r-')
    plt.plot(path[:, 0], path[:, 1], 'r.')
    plt.plot(path[0, 0], path[0, 1], 'k^', label='start')
    plt.plot(path[-1, 0], path[-1, 1], 'k*', label='last')
def sample_rademacher() -> int:
    """Flip a coin and return 1 for heads and -1 otherwise."""
    return np.random.choice([1, -1])
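# Background note (added for clarity, not part of the original script): for Rademacher
# vectors z (entries +1 or -1 with equal probability), Hutchinson's estimator relies on
#     E[z * (H @ z)] = diag(H),
# so averaging z * (H @ z) over N samples approximates the Hessian diagonal. The raw
# AdaHessian step analyzed here is then w <- w - g / diag(H).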
def compute_grad_and_hessian(f: Callable,
                             w: torch.Tensor,
                             only_gradient: bool = False,
                             N: int = 10) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Take function f(w) with parameters w and compute the gradient and a diagonal approximation of the Hessian matrix.

    Args:
        f: Function to be minimized.
        w: The parameter vector.
        only_gradient: If True, the Hessian is not computed and None is returned instead.
        N: The number of samples used in Hutchinson's method.

    Returns:
        Tuple containing the gradient and the diagonal Hessian approximation (None if only_gradient is True).
    """
    # compute loss value
    loss = f(w)

    # compute gradient
    w.grad = None
    loss.backward(create_graph=True, retain_graph=True)
    g = w.grad  # gradient

    # early exit if only gradient is needed
    if only_gradient:
        return g, None

    # compute approximation of diagonal elements of Hessian with Hutchinson's method
    H_sum = torch.tensor([0, 0], dtype=torch.float)
    for _ in range(N):
        # random vector z containing only 1 and -1 elements
        z = torch.tensor([sample_rademacher(), sample_rademacher()], dtype=torch.float)
        # compute H*z using automatic differentiation
        H_z = torch.autograd.grad(g @ z, w, retain_graph=True)[0]
        H_sum += z * H_z  # sum up
    H = H_sum / N  # mean value is approximation for diagonal Hessian
    return g, H
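# Optional sanity-check sketch (added for illustration, not part of the original script;
# the helper name check_hutchinson_estimate is made up): for the quadratic used in main(),
# the exact Hessian is [[2, 1], [1, 6]], so the estimate should approach [2, 6] for large N.
def check_hutchinson_estimate(N: int = 1000) -> None:
    def f(v: torch.Tensor) -> torch.Tensor:
        return v[0] ** 2 + 3 * v[1] ** 2 + v[0] * v[1]

    w = torch.tensor([7, -5], dtype=torch.float, requires_grad=True)
    _, H_approx = compute_grad_and_hessian(f, w, N=N)
    # exact diagonal of the full Hessian, computed with torch.autograd.functional.hessian
    H_exact = torch.diagonal(torch.autograd.functional.hessian(f, w.detach()))
    print(f'Hutchinson estimate: {H_approx}, exact diagonal: {H_exact}')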
def sgd(f: Callable, w: torch.Tensor) -> torch.Tensor:
    """Apply SGD update step."""
    # compute gradient
    g, _ = compute_grad_and_hessian(f, w, only_gradient=True)
    print('Gradient:')
    print(g.data)

    # apply update step -lr*g
    with torch.no_grad():
        lr = 0.2
        w_new = w - lr * g
    return w_new.detach()
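# Note added for comparison (not in the original script): SGD scales every coordinate by the
# same learning rate, while the AdaHessian step below divides each gradient component by its
# estimated curvature, damping the step more strongly along the high-curvature w1 axis
# (curvature 6) than along the flatter w0 axis (curvature 2) of the quadratic used in main().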
def adahessian(f: Callable, w: torch.Tensor) -> torch.Tensor:
    """Apply AdaHessian update step (without momentum)."""
    # compute gradient and Hessian
    g, H = compute_grad_and_hessian(f, w)
    print('Gradient:')
    print(g.data)
    print('Hessian:')
    print(torch.diag(H).data)

    # compute update step -g/H
    with torch.no_grad():
        w_new = w - g / H
    return w_new.detach()
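# Hedged sketch (added; function name and default values are illustrative only): the full
# AdaHessian optimizer does not apply the raw -g/H step above but an Adam-style rule, keeping
# exponential moving averages of the gradient and of the squared Hessian diagonal. Roughly,
# per step t:
def adahessian_step_sketch(g: torch.Tensor, H: torch.Tensor, m: torch.Tensor, v: torch.Tensor,
                           t: int, lr: float = 0.1, beta1: float = 0.9, beta2: float = 0.999,
                           eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    m = beta1 * m + (1 - beta1) * g        # first moment: EMA of the gradient
    v = beta2 * v + (1 - beta2) * H * H    # second moment: EMA of the squared Hessian diagonal
    m_hat = m / (1 - beta1 ** t)           # bias correction, as in Adam
    v_hat = v / (1 - beta2 ** t)
    step = -lr * m_hat / (torch.sqrt(v_hat) + eps)
    return step, m, v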
def main():
    # f(w) is the loss function to be minimized, w the parameters of the model
    def f(w):
        return w[0] ** 2 + 3 * w[1] ** 2 + w[0] * w[1]
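    # Note added for clarity: the exact Hessian of f is [[2, 1], [1, 6]] and the minimum lies
    # at the origin. A full Newton step -H^{-1} g would jump straight to the minimum of this
    # quadratic; the diagonal step -g/diag(H) ignores the cross term and therefore does not,
    # as the plots below show.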
    # do a single step both with sgd and with adahessian
    plt.figure(figsize=[9, 4])
    for i, method in enumerate([sgd, adahessian]):
        # apply current method
        print(f'===Method: {method.__name__}===')
        w = torch.tensor([7, -5], dtype=torch.float, requires_grad=True)
        w_new = method(f, w)

        # compute old and new value to see if new function value decreased
        loss_old = f(w)
        loss_new = f(w_new)
        print(f'Took step from {w.data} to {w_new.data}')
        print(f'Improved loss value from {loss_old} to {loss_new}')
        print()

        # plot step of current method
        plt.subplot(1, 2, i + 1)
        plt.title(method.__name__)
        ext = 10
        plot_function(f, [-ext, ext, -ext, ext])
        plot_path(torch.stack([w.detach(), w_new]))
        plt.plot([0], [0], 'w*', label='minimum')
        plt.legend()
        plt.colorbar()
    plt.show()


if __name__ == '__main__':
    main()