analyze_adahessian_update.py
The core of AdaHessian analyzed on a 2D function: compute the gradient and a diagonal Hessian approximation (see compute_grad_and_hessian(...)) and apply the update step.
""" | |
This script allows analyzing the update step taken by AdaHessian. | |
Both the gradient g and the diagonal Hessian approximation H for a 2D loss function is computed. | |
Then, the update step in direction -g/H is applied. | |
Momentum is not implemented in this script. | |
As a comparison also a gradient descent update step is executed. | |
""" | |
from typing import Callable, Tuple, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch

def plot_function(f: Callable, window: List[float]) -> None:
    """Plot a 2D function by evaluating the PyTorch function at each grid point."""
    num_vals = 50
    x_vals = np.linspace(window[0], window[1], num_vals)
    y_vals = np.linspace(window[2], window[3], num_vals)
    X, Y = np.meshgrid(x_vals, y_vals)
    Z = np.empty([num_vals, num_vals], dtype=np.float64)
    for i, x in enumerate(x_vals):
        for j, y in enumerate(y_vals):
            v = torch.tensor((x, y), dtype=torch.float64)
            Z[j, i] = f(v)
    plt.pcolormesh(X, Y, Z, cmap='rainbow', shading='auto')

def plot_path(path: torch.Tensor) -> None:
    """Plot the path taken by the optimizer."""
    plt.plot(path[:, 0], path[:, 1], 'r-')
    plt.plot(path[:, 0], path[:, 1], 'r.')
    plt.plot(path[0, 0], path[0, 1], 'k^', label='start')
    plt.plot(path[-1, 0], path[-1, 1], 'k*', label='last')

def sample_rademacher() -> int:
    """Flip a coin and return 1 for heads and -1 otherwise."""
    return np.random.choice([1, -1])

def compute_grad_and_hessian(f: Callable,
                             w: torch.Tensor,
                             only_gradient: bool = False,
                             N: int = 10) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Take function f(w) with parameters w and compute the gradient and a diagonal approximation of the Hessian.

    Args:
        f: Function to be minimized.
        w: The parameter vector.
        only_gradient: If True, the Hessian is not computed and None is returned instead.
        N: The number of samples used in Hutchinson's method.

    Returns:
        Tuple containing the gradient and the diagonal Hessian approximation.
    """
    # compute loss value
    loss = f(w)
    # compute gradient
    w.grad = None
    loss.backward(create_graph=True, retain_graph=True)
    g = w.grad  # gradient
    # early exit if only the gradient is needed
    if only_gradient:
        return g, None
    # approximate the diagonal elements of the Hessian with Hutchinson's method
    H_sum = torch.tensor([0, 0], dtype=torch.float)
    for _ in range(N):
        # random vector z containing only 1 and -1 elements
        z = torch.tensor([sample_rademacher(), sample_rademacher()], dtype=torch.float)
        # compute H @ z using automatic differentiation
        H_z = torch.autograd.grad(g @ z, w, retain_graph=True)[0]
        H_sum += z * H_z  # sum up
    H = H_sum / N  # the mean value is the approximation of the diagonal Hessian
    return g, H
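
# Optional, illustrative sanity check (not called from main()): Hutchinson's method relies on
# E[z * (H @ z)] = diag(H) for Rademacher vectors z, so with enough samples the estimate should
# match the exact diagonal obtained from torch.autograd.functional.hessian.
def check_hutchinson_estimate(f: Callable, w: torch.Tensor, N: int = 1000) -> None:
    """Compare the Hutchinson estimate of the diagonal Hessian with the exact diagonal."""
    _, H_estimate = compute_grad_and_hessian(f, w, N=N)
    H_exact = torch.autograd.functional.hessian(f, w.detach())
    print('Hutchinson estimate of diag(H):', H_estimate.data)
    print('Exact diag(H):                 ', torch.diagonal(H_exact))
# Example usage:
#   check_hutchinson_estimate(lambda v: v[0] ** 2 + 3 * v[1] ** 2 + v[0] * v[1],
#                             torch.tensor([7., -5.], requires_grad=True))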

def sgd(f: Callable, w: torch.Tensor) -> torch.Tensor:
    """Apply an SGD update step."""
    # compute gradient
    g, _ = compute_grad_and_hessian(f, w, only_gradient=True)
    print('Gradient:')
    print(g.data)
    # apply update step -lr*g
    with torch.no_grad():
        lr = 0.2
        w_new = w - lr * g
    return w_new.detach()

def adahessian(f: Callable, w: torch.Tensor) -> torch.Tensor:
    """Apply an AdaHessian update step."""
    # compute gradient and diagonal Hessian approximation
    g, H = compute_grad_and_hessian(f, w)
    print('Gradient:')
    print(g.data)
    print('Hessian:')
    print(torch.diag(H).data)  # display the estimated diagonal as a matrix
    # compute update step -g/H
    with torch.no_grad():
        w_new = w - g / H
    return w_new.detach()
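
# Note: the full AdaHessian optimizer additionally uses momentum, an Adam-style moving average
# of the (spatially averaged) squared Hessian diagonal, and bias correction; this script keeps
# only the raw Newton-like step -g/H to make the effect of the Hessian scaling visible.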

def main():
    # f(w) is the loss function to be minimized, w the parameters of the model
    def f(w):
        return w[0] ** 2 + 3 * w[1] ** 2 + w[0] * w[1]
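    # For this quadratic the gradient is (2*w0 + w1, w0 + 6*w1) and the exact Hessian is
    # [[2, 1], [1, 6]], so the Hutchinson estimate should be close to diag(H) = (2, 6).
    # Starting from w = (7, -5) with g = (9, -23), the AdaHessian step -g/H lands near
    # (2.5, -1.17), while the SGD step with lr = 0.2 only reaches (5.2, -0.4).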
    # do a single step with both SGD and AdaHessian
    plt.figure(figsize=[9, 4])
    for i, method in enumerate([sgd, adahessian]):
        # apply the current method
        print(f'===Method: {method.__name__}===')
        w = torch.tensor([7, -5], dtype=torch.float, requires_grad=True)
        w_new = method(f, w)
        # compare old and new loss values to see whether the step decreased the loss
        loss_old = f(w)
        loss_new = f(w_new)
        print(f'Took step from {w.data} to {w_new.data}')
        print(f'Improved loss value from {loss_old} to {loss_new}')
        print()
        # plot the step of the current method
        plt.subplot(1, 2, i + 1)
        plt.title(method.__name__)
        ext = 10
        plot_function(f, [-ext, ext, -ext, ext])
        plot_path(torch.stack([w.detach(), w_new]))
        plt.plot([0], [0], 'w*', label='minimum')
        plt.legend()
        plt.colorbar()
    plt.show()


if __name__ == '__main__':
    main()