@HenKlei
Last active November 17, 2022 09:44
Computation of derivatives of scalar-valued functions in PyTorch (gradient, Hessian, Laplacian, ...)
import torch
# Returns a tensor containing the gradient of `f` at the points `x`, i.e. ∇f(x[0]), ..., ∇f(x[n-1]).
# The function `f` is evaluated in a vectorized fashion, once per row of `x`.
# If provided, `y` has to contain the function evaluations of `f` at `x`.
# Shapes: x=(n, dim), y=(n), return=(n, dim)
def gradient(f, x, y=None):
    if y is None:
        y = f(x)
    grads = torch.autograd.grad(y, x, torch.ones_like(y), retain_graph=True, create_graph=True)[0]
    return grads

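A quick sanity check, not part of the original gist (the quadratic test function and the variable names are assumptions): for f(z) = z₀² + z₁², the gradient at each row is 2z.

x_test = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
# Gradient of the assumed test function f(z) = z_0^2 + z_1^2 is 2z.
print(gradient(lambda z: (z ** 2).sum(dim=-1), x_test))  # expected: tensor([[2., 4.], [6., 8.]])
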
# Returns a tensor containing ∇·(d(x[0])∇f(x[0])), ..., ∇·(d(x[n-1])∇f(x[n-1])).
# The functions `f` and `d` are evaluated in a vectorized fashion, once per row of `x`.
# If provided, `grads` has to contain the gradient evaluations of `f` at `x`.
# If provided, `y` has to contain the function evaluations of `f` at `x` (not used if `grads` is provided).
# If provided, `d_coeffs` has to contain the function evaluations of `d` at `x`.
# Shapes: x=(n, dim), grads=(n, dim), y=(n), d_coeffs=(n), return=(n)
def divergence_diffusion_gradient(f, x, d, grads=None, y=None, d_coeffs=None):
    if grads is None:
        grads = gradient(f, x, y=y)
    if d_coeffs is None:
        d_coeffs = d(x)
    diffusion_grads = d_coeffs[:, None] * grads
    div_diffusion_grads = torch.zeros_like(d_coeffs)
    for i in range(x.shape[1]):
        e = torch.zeros_like(diffusion_grads)
        e[..., i] = 1.
        # temp[..., i] is ∂(d(x)∇f(x))_i / ∂x_i; summing over i yields the divergence.
        temp = torch.autograd.grad(diffusion_grads, x, e, retain_graph=True, create_graph=True)[0]
        div_diffusion_grads += temp[..., i]
    return div_diffusion_grads

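The `__main__` example at the end of the gist does not exercise this function, so here is a hedged consistency check (the test function and coefficient are assumptions): with the constant coefficient d ≡ 1, ∇·(d∇f) reduces to the Laplacian Δf, which equals 2·dim = 4 everywhere for f(z) = z₀² + z₁².

x_test = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
f_test = lambda z: (z ** 2).sum(dim=-1)  # assumed test function, Δf = 4
d_test = lambda z: torch.ones(z.shape[0])  # constant diffusion coefficient d ≡ 1
print(divergence_diffusion_gradient(f_test, x_test, d_test))  # expected: tensor([4., 4.])
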
# Returns a tensor containing the Laplacian of `f` at the points `x`, i.e. Δf(x[0]), ..., Δf(x[n-1]).
# The function `f` is evaluated in a vectorized fashion, once per row of `x`.
# If provided, `grads` has to contain the gradient evaluations of `f` at `x`.
# If provided, `y` has to contain the function evaluations of `f` at `x` (not used if `grads` is provided).
# Shapes: x=(n, dim), grads=(n, dim), y=(n), return=(n)
def laplacian(f, x, grads=None, y=None):
    if grads is None:
        grads = gradient(f, x, y=y)
    laplacian = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    for i in range(x.shape[1]):
        e = torch.zeros_like(grads)
        e[..., i] = 1.
        # Differentiate the i-th gradient component w.r.t. x; its i-th entry is ∂²f/∂x_i².
        temp_hessian = torch.autograd.grad(grads, x, e, retain_graph=True, create_graph=True)[0]
        laplacian += temp_hessian[..., i]
    return laplacian

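A minimal sketch of a sanity check (assumed test function, not from the gist): Δf = 2·dim = 4 at every point for f(z) = z₀² + z₁².

x_test = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
print(laplacian(lambda z: (z ** 2).sum(dim=-1), x_test))  # expected: tensor([4., 4.])
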
# Returns a tensor containing the Hessian of `f` at the points `x`, i.e. H_f(x[0]), ..., H_f(x[n-1]).
# The function `f` is evaluated in a vectorized fashion, once per row of `x`.
# If provided, `grads` has to contain the gradient evaluations of `f` at `x`.
# If provided, `y` has to contain the function evaluations of `f` at `x` (not used if `grads` is provided).
# Shapes: x=(n, dim), grads=(n, dim), y=(n), return=(n, dim, dim)
def hessian(f, x, grads=None, y=None):
    if grads is None:
        grads = gradient(f, x, y=y)
    hessian = torch.zeros((x.shape[0], x.shape[1], x.shape[1]), dtype=x.dtype, device=x.device)
    for i in range(x.shape[1]):
        e = torch.zeros_like(grads)
        e[..., i] = 1.
        # Column i of each pointwise Hessian is the derivative of the i-th gradient component w.r.t. x.
        temp = torch.autograd.grad(grads, x, e, retain_graph=True, create_graph=True)[0]
        hessian[..., i] = temp
    return hessian

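Since the Laplacian is the trace of the Hessian, `laplacian` and `hessian` can be cross-checked against each other; a sketch under the same assumed quadratic test function (whose Hessian is 2·I at every point):

x_test = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
f_test = lambda z: (z ** 2).sum(dim=-1)
H = hessian(f_test, x_test)  # expected: 2·I for every row of x_test
trace_H = H.diagonal(dim1=-2, dim2=-1).sum(-1)  # trace of each pointwise Hessian
print(torch.allclose(laplacian(f_test, x_test), trace_H))  # expected: True
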
# Example:
if __name__ == '__main__':
    def f(x):
        return x[..., 0] * x[..., 1] * x[..., 1] + x[..., 0] * x[..., 0] + 3.

    n = 3
    dim = 2
    x = torch.arange(dim*n, requires_grad=True, dtype=torch.float).reshape((n, dim)) + 1.
    print(f"Gradients: {gradient(f, x)}")
    print(f"Hessians: {hessian(f, x)}")
    print(f"Laplacians: {laplacian(f, x)}")
    print()
    print("Precomputing function values:")
    y = f(x)
    print(f"Gradients: {gradient(f, x, y=y)}")
    print(f"Hessians: {hessian(f, x, y=y)}")
    print(f"Laplacians: {laplacian(f, x, y=y)}")
    print()
    print("Precomputing gradients:")
    gradients = gradient(f, x)
    hessians = hessian(f, x, grads=gradients)
    laplacians = laplacian(f, x, grads=gradients)
print(f"Gradients: {gradient(f, x)}")
print(f"Hessians: {hessian(f, x)}")
print(f"Laplacians: {laplacian(f, x)}")