Skip to content

Instantly share code, notes, and snippets.

@b-fg
Last active March 1, 2024 13:57
Show Gist options
  • Save b-fg/d637e6f3e7bb190dce5edcd20e57e7f0 to your computer and use it in GitHub Desktop.
Save b-fg/d637e6f3e7bb190dce5edcd20e57e7f0 to your computer and use it in GitHub Desktop.
Example of automatic differentiation (AD) using PyTorch
"""
Testing PyTorch automatic differentiation (AD) tools for an analytical multivariate polynomial function:
u_i(x_i) = u_i(x,y,...) = (u(x,y,...), v(x,y,...))
In particular we test the function u_i(x,y):=(
u(x,y) := x^5 + 4x^2y + 6y^2 + 2,
v(x,y) := xy^4 + 4xy^2 + 6x^2 - y
)
"""
import torch
n = 10  # number of samples (batch dimension)
domain_limits = ((-1, 1), (-1, 1))  # one (lower, upper) sampling interval per variable
# Draw every coordinate uniformly inside its own interval in double precision,
# then pair the coordinates along the last axis so that each row is one sample.
x_i = torch.stack(
    [torch.empty(n, dtype=torch.float64).uniform_(lo, hi) for lo, hi in domain_limits],
    dim=-1,
)
def u_i(x_i):
    """
    Evaluate the analytical polynomial vector function (u, v).

    u(x, y) = x^5 + 4 x^2 y + 6 y^2 + 2
    v(x, y) = x y^4 + 4 x y^2 + 6 x^2 - y

    The input is first normalised towards a (batch, variable) layout: an input
    with fewer than two dimensions gets a leading axis prepended, and an input
    with more than two dimensions is squeezed before a leading axis is
    prepended. The normalised tensor is split into two equal parts (x, y) with
    torch.hsplit, and the u and v values are joined with torch.hstack.
    """
    # A lone sample still needs an explicit leading (batch) axis.
    if x_i.dim() < 2:
        x_i = x_i[None]
    # Drop size-1 axes from over-ranked input, then restore the leading axis.
    if x_i.dim() > 2:
        x_i = x_i.squeeze()[None]
    x, y = torch.hsplit(x_i, 2)
    u = x**5 + 4*x**2*y + 6*y**2 + 2
    v = x*y**4 + 4*x*y**2 + 6*x**2 - y
    return torch.hstack([u, v])
# Analytical derivatives
def dudx(x_i):
    """
    Analytical du/dx = 5 x^4 + 8 x y.

    x_i is split into two equal parts (x, y) with torch.hsplit; size-1 axes
    are squeezed out of the result.
    """
    xs, ys = torch.hsplit(x_i, 2)
    value = 5*xs**4 + 8*xs*ys
    return torch.squeeze(value)
def dudy(x_i):
    """
    Analytical du/dy = 4 x^2 + 12 y.

    x_i is split into two equal parts (x, y) with torch.hsplit; size-1 axes
    are squeezed out of the result.
    """
    xs, ys = torch.hsplit(x_i, 2)
    value = 4*xs**2 + 12*ys
    return torch.squeeze(value)
def dvdx(x_i):
    """
    Analytical dv/dx = y^4 + 4 y^2 + 12 x.

    x_i is split into two equal parts (x, y) with torch.hsplit; size-1 axes
    are squeezed out of the result.
    """
    xs, ys = torch.hsplit(x_i, 2)
    value = ys**4 + 4*ys**2 + 12*xs
    return torch.squeeze(value)
def dvdy(x_i):
    """
    Analytical dv/dy = 4 x y^3 + 8 x y - 1.

    x_i is split into two equal parts (x, y) with torch.hsplit; size-1 axes
    are squeezed out of the result.
    """
    xs, ys = torch.hsplit(x_i, 2)
    value = 4*xs*ys**3 + 8*xs*ys - 1.0
    return torch.squeeze(value)
def d2udx2(x_i):
    """
    Analytical d2u/dx2 = 20 x^3 + 8 y.

    x_i is split into two equal parts (x, y) with torch.hsplit; size-1 axes
    are squeezed out of the result.
    """
    xs, ys = torch.hsplit(x_i, 2)
    value = 20*xs**3 + 8*ys
    return torch.squeeze(value)
def d2vdy2(x_i):
    """
    Analytical d2v/dy2 = 12 x y^2 + 8 x.

    x_i is split into two equal parts (x, y) with torch.hsplit; size-1 axes
    are squeezed out of the result.
    """
    xs, ys = torch.hsplit(x_i, 2)
    value = 12*xs*ys**2 + 8*xs
    return torch.squeeze(value)
def d3udx3(x_i):
    """
    Analytical d3u/dx3 = 60 x^2.

    x_i is split into two equal parts with torch.hsplit and only the first
    part (x) enters the formula; size-1 axes are squeezed out of the result.
    """
    xs, _ = torch.hsplit(x_i, 2)
    value = 60*xs**2
    return torch.squeeze(value)
def d3vdy3(x_i):
    """
    Analytical d3v/dy3 = 24 x y.

    x_i is split into two equal parts (x, y) with torch.hsplit; size-1 axes
    are squeezed out of the result.
    """
    xs, ys = torch.hsplit(x_i, 2)
    value = 24*xs*ys
    return torch.squeeze(value)
def d3udx2y(x_i):
    """
    Analytical d3u/dx2dy = 8, a constant independent of the sample point.

    Returns a 0-dimensional float64 tensor (not one value per sample).
    """
    # The split halves do not enter the result; the call is kept so that input
    # torch.hsplit cannot split in two still raises here.
    torch.hsplit(x_i, 2)
    return torch.tensor(8.0, dtype=torch.double)
def d4udx4(x_i):
    """
    Analytical d4u/dx4 = 120 x.

    x_i is split into two equal parts with torch.hsplit and only the first
    part (x) enters the formula; size-1 axes are squeezed out of the result.
    """
    xs, _ = torch.hsplit(x_i, 2)
    value = 120*xs
    return torch.squeeze(value)
def d4vdy4(x_i):
    """
    Analytical d4v/dy4 = 24 x.

    x_i is split into two equal parts with torch.hsplit and only the first
    part (x) enters the formula; size-1 axes are squeezed out of the result.
    """
    xs, _ = torch.hsplit(x_i, 2)
    value = 24*xs
    return torch.squeeze(value)
# Automatic derivatives
def J(x_i):
    """
    1st-order derivatives of u_i by automatic differentiation.

    Applies the reverse-mode Jacobian of u_i to every sample through vmap.
    Nominal layout is (batch_size, len(Y), len(X)), with
    J[...,0,0]==dudx and J[...,1,1]==dvdy.
    The trailing squeeze() removes every size-1 axis, so a batch holding a
    single sample loses its batch axis as well.
    """
    # Promote a lone sample to a batch of one before mapping over samples.
    if x_i.dim() < 2:
        x_i = x_i[None]
    # Collapse size-1 axes of over-ranked input.
    if x_i.dim() > 2:
        x_i = x_i.squeeze()
    per_sample_jacobian = torch.func.jacrev(u_i)
    batched_jacobian = torch.vmap(per_sample_jacobian)
    return batched_jacobian(x_i).squeeze()
def JJ(x_i):
    """
    2nd-order derivatives obtained by differentiating J once more.

    Applies the reverse-mode Jacobian of J to every sample through vmap.
    Nominal layout is (batch_size, len(Y), len(X), len(X)); the result is
    expected to match H(x_i). The trailing squeeze() removes every size-1
    axis.
    """
    jacobian_of_J = torch.func.jacrev(J)
    batched = torch.vmap(jacobian_of_J)
    return batched(x_i).squeeze()
def H(x_i):
    """
    2nd-order derivatives of u_i by automatic differentiation.

    Applies torch.func.hessian of u_i to every sample through vmap.
    Nominal layout is (batch_size, len(Y), len(X), len(X)), with
    H[...,0,0,0]==d2udx2 and H[...,1,1,1]==d2vdy2.
    The trailing squeeze() removes every size-1 axis, so a batch holding a
    single sample loses its batch axis as well.
    """
    # Promote a lone sample to a batch of one before mapping over samples.
    if x_i.dim() < 2:
        x_i = x_i[None]
    # Collapse size-1 axes of over-ranked input.
    if x_i.dim() > 2:
        x_i = x_i.squeeze()
    per_sample_hessian = torch.func.hessian(u_i)
    batched_hessian = torch.vmap(per_sample_hessian)
    return batched_hessian(x_i).squeeze()
def JH(x_i):
    """
    3rd-order derivatives obtained as the Jacobian of H.

    Applies the reverse-mode Jacobian of H to every sample through vmap.
    Nominal layout is (batch_size, len(Y), len(X), len(X), len(X)), with
    JH[...,0,0,0,0]==d3udx3, JH[...,1,1,1,1]==d3vdy3 and
    JH[...,0,0,0,1]==d3udx2y.
    The trailing squeeze() removes every size-1 axis.
    """
    jacobian_of_H = torch.func.jacrev(H)
    batched = torch.vmap(jacobian_of_H)
    return batched(x_i).squeeze()
def HH(x_i):
    """
    4th-order derivatives obtained as the Hessian of H.

    Applies torch.func.hessian of H to every sample through vmap.
    Nominal layout is
    (batch_size, len(Y), len(X), len(X), len(X), len(X)), with
    HH[...,0,0,0,0,0]==d4udx4 and HH[...,1,1,1,1,1]==d4vdy4.
    The trailing squeeze() removes every size-1 axis.
    """
    hessian_of_H = torch.func.hessian(H)
    batched = torch.vmap(hessian_of_H)
    return batched(x_i).squeeze()
# Tests
def _report(label, expected, actual):
    """
    Print one check line: `label` followed by 'pass' when `expected` and
    `actual` agree under torch.allclose, or by 'error!' otherwise.

    The same label is used for both outcomes, so a failure line always names
    the check that actually failed.
    """
    outcome = 'pass' if torch.allclose(expected, actual) else 'error!'
    print(f'{label} {outcome}')


# Evaluate every automatic-derivative tensor once and reuse it across checks
# instead of re-running the nested vmap/jacrev/hessian transforms per check.
_jacobian = J(x_i)
_hessian = H(x_i)
_third_order = JH(x_i)
_fourth_order = HH(x_i)

# Labels keep their trailing tabs so the pass/error column stays aligned.
_report('JJ == H \t\t\t', JJ(x_i), _hessian)
_report('dudx == J[...,0,0] \t\t', dudx(x_i), _jacobian[..., 0, 0])
_report('dvdy == J[...,1,1] \t\t', dvdy(x_i), _jacobian[..., 1, 1])
_report('d2udx2 == H[...,0,0,0] \t\t', d2udx2(x_i), _hessian[..., 0, 0, 0])
_report('d2vdy2 == H[...,1,1,1] \t\t', d2vdy2(x_i), _hessian[..., 1, 1, 1])
_report('d3udx3 == JH[...,0,0,0,0] \t', d3udx3(x_i), _third_order[..., 0, 0, 0, 0])
_report('d3vdy3 == JH[...,1,1,1,1] \t', d3vdy3(x_i), _third_order[..., 1, 1, 1, 1])
_report('d3udx2y == JH[...,0,0,0,1] \t', d3udx2y(x_i), _third_order[..., 0, 0, 0, 1])
_report('d4udx4 == HH[...,0,0,0,0,0] \t', d4udx4(x_i), _fourth_order[..., 0, 0, 0, 0, 0])
_report('d4vdy4 == HH[...,1,1,1,1,1] \t', d4vdy4(x_i), _fourth_order[..., 1, 1, 1, 1, 1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment