Last active
March 1, 2024 13:57
-
-
Save b-fg/d637e6f3e7bb190dce5edcd20e57e7f0 to your computer and use it in GitHub Desktop.
Example of automatic differentiation (AD) using PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Testing PyTorch automatic differentiation (AD) tools for an analytical multivariate polynomial function: | |
u_i(x_i) = u_i(x,y,...) = (u(x,y,...), v(x,y,...)) | |
In particular we test the function u_i(x,y):=( | |
u(x,y) := x^5 + 4x^2y + 6y^2 + 2, | |
v(x,y) := xy^4 + 4xy^2 + 6x^2 - y | |
) | |
""" | |
import torch | |
# Number of random sample points (batch dimension).
n = 10
# (lower, upper) sampling bounds, one pair per input variable (x, then y).
domain_limits = ((-1, 1), (-1, 1))
# Draw n double-precision samples per variable and place the variables along
# the last axis, so x_i has shape (n, number of variables).
x_i = torch.stack(
    [
        torch.empty(n, dtype=torch.float64).uniform_(lower, upper)
        for lower, upper in domain_limits
    ],
    dim=-1,
)
def u_i(x_i):
    """
    Evaluate the analytical polynomial vector function (u, v) at the given points.

    x_i holds the variables (x, y) along its last axis. The result stacks
    u = x^5 + 4x^2y + 6y^2 + 2 and v = xy^4 + 4xy^2 + 6x^2 - y horizontally.
    """
    # Normalise the input to (batch dimension, variable dimension), also when a
    # single un-batched sample is passed in.
    if x_i.ndim < 2:
        x_i = x_i[None]
    elif x_i.ndim > 2:
        x_i = x_i.squeeze()[None]
    x, y = torch.hsplit(x_i, 2)
    u = x**5 + 4*x**2*y + 6*y**2 + 2
    v = x*y**4 + 4*x*y**2 + 6*x**2 - y
    return torch.hstack([u, v])
# Analytical derivatives of u and v (reference values for the AD checks)
def dudx(x_i):
    """Analytical du/dx = 5x^4 + 8xy; x_i holds the columns (x, y)."""
    x, y = x_i.hsplit(2)
    value = 5*x**4 + 8*x*y
    # Drop the size-1 column axis left over from the split.
    return torch.squeeze(value)
def dudy(x_i):
    """Analytical du/dy = 4x^2 + 12y; x_i holds the columns (x, y)."""
    x, y = x_i.hsplit(2)
    value = 4*x**2 + 12*y
    # Drop the size-1 column axis left over from the split.
    return torch.squeeze(value)
def dvdx(x_i):
    """Analytical dv/dx = y^4 + 4y^2 + 12x; x_i holds the columns (x, y)."""
    x, y = x_i.hsplit(2)
    value = y**4 + 4*y**2 + 12*x
    # Drop the size-1 column axis left over from the split.
    return torch.squeeze(value)
def dvdy(x_i):
    """Analytical dv/dy = 4xy^3 + 8xy - 1; x_i holds the columns (x, y)."""
    x, y = x_i.hsplit(2)
    value = 4*x*y**3 + 8*x*y - 1.0
    # Drop the size-1 column axis left over from the split.
    return torch.squeeze(value)
def d2udx2(x_i):
    """Analytical d2u/dx2 = 20x^3 + 8y; x_i holds the columns (x, y)."""
    x, y = x_i.hsplit(2)
    value = 20*x**3 + 8*y
    # Drop the size-1 column axis left over from the split.
    return torch.squeeze(value)
def d2vdy2(x_i):
    """Analytical d2v/dy2 = 12xy^2 + 8x; x_i holds the columns (x, y)."""
    x, y = x_i.hsplit(2)
    value = 12*x*y**2 + 8*x
    # Drop the size-1 column axis left over from the split.
    return torch.squeeze(value)
def d3udx3(x_i):
    """Analytical d3u/dx3 = 60x^2; x_i holds the columns (x, y)."""
    # Only the x column enters this derivative.
    x, _ = x_i.hsplit(2)
    value = 60*x**2
    # Drop the size-1 column axis left over from the split.
    return torch.squeeze(value)
def d3vdy3(x_i):
    """Analytical d3v/dy3 = 24xy; x_i holds the columns (x, y)."""
    x, y = x_i.hsplit(2)
    value = 24*x*y
    # Drop the size-1 column axis left over from the split.
    return torch.squeeze(value)
def d3udx2y(x_i):
    """
    Analytical d3u/dx2dy, which is the constant 8 for u = x^5 + 4x^2y + 6y^2 + 2.

    The derivative does not depend on the sample values, so a 0-d tensor is
    returned and left to broadcast against per-sample results.
    x_i is only consulted for its dtype and device, so that the constant can be
    compared (e.g. with torch.allclose) against tensors derived from x_i
    without a dtype or device mismatch.
    """
    return torch.full((), 8.0, dtype=x_i.dtype, device=x_i.device)
def d4udx4(x_i):
    """Analytical d4u/dx4 = 120x; x_i holds the columns (x, y)."""
    # Only the x column enters this derivative.
    x, _ = x_i.hsplit(2)
    value = 120*x
    # Drop the size-1 column axis left over from the split.
    return torch.squeeze(value)
def d4vdy4(x_i):
    """Analytical d4v/dy4 = 24x; x_i holds the columns (x, y)."""
    # Only the x column enters this derivative.
    x, _ = x_i.hsplit(2)
    value = 24*x
    # Drop the size-1 column axis left over from the split.
    return torch.squeeze(value)
# Automatic derivatives (computed with torch.func transforms)
def J(x_i):
    """
    1st-order derivatives of u_i via reverse-mode AD, vectorised over the batch.

    J(x_i).shape = (batch_size, len(Y), len(X)). J[...,0,0]==dudx, J[...,1,1]==dvdy
    The trailing squeeze() removes every size-1 axis, so a batch holding a
    single sample comes back without its batch axis.
    """
    # Bring the input to (batch dimension, variable dimension) before mapping.
    if x_i.ndim < 2:
        x_i = x_i[None]
    elif x_i.ndim > 2:
        x_i = x_i.squeeze()
    per_sample_jacobian = torch.func.jacrev(u_i)
    return torch.vmap(per_sample_jacobian)(x_i).squeeze()
def JJ(x_i):
    """
    2nd-order derivatives obtained by differentiating J once more (Jacobian of the Jacobian).

    JJ(x_i).shape = (batch_size, len(Y), len(X), len(X)). Expected to match H(x_i).
    The trailing squeeze() removes every size-1 axis of the result.
    """
    jacobian_of_jacobian = torch.func.jacrev(J)
    batched = torch.vmap(jacobian_of_jacobian)
    return batched(x_i).squeeze()
def H(x_i):
    """
    2nd-order derivatives of u_i via torch.func.hessian, vectorised over the batch.

    H(x_i).shape = (batch_size, len(Y), len(X), len(X)). H[...,0,0,0]==d2udx2, H[...,1,1,1]==d2vdy2
    The trailing squeeze() removes every size-1 axis, so a batch holding a
    single sample comes back without its batch axis.
    """
    # Bring the input to (batch dimension, variable dimension) before mapping.
    if x_i.ndim < 2:
        x_i = x_i[None]
    elif x_i.ndim > 2:
        x_i = x_i.squeeze()
    per_sample_hessian = torch.func.hessian(u_i)
    return torch.vmap(per_sample_hessian)(x_i).squeeze()
def JH(x_i):
    """
    3rd-order derivatives obtained as the Jacobian of the Hessian H.

    JH(x_i).shape = (batch_size, len(Y), len(X), len(X), len(X)). JH[...,0,0,0,0]==d3udx3, JH[...,1,1,1,1]==d3vdy3, JH[...,0,0,0,1]==d3udx2y
    The trailing squeeze() removes every size-1 axis of the result.
    """
    jacobian_of_hessian = torch.func.jacrev(H)
    batched = torch.vmap(jacobian_of_hessian)
    return batched(x_i).squeeze()
def HH(x_i):
    """
    4th-order derivatives obtained as the Hessian of the Hessian H.

    HH(x_i).shape = (batch_size, len(Y), len(X), len(X), len(X), len(X)). HH[...,0,0,0,0,0]==d4udx4, HH[...,1,1,1,1,1]==d4vdy4
    The trailing squeeze() removes every size-1 axis of the result.
    """
    hessian_of_hessian = torch.func.hessian(H)
    batched = torch.vmap(hessian_of_hessian)
    return batched(x_i).squeeze()
# Tests: compare the analytical derivatives against the AD results
# Evaluate each AD operator once; several checks below index the same tensor.
jac_ad, hess_ad, jac_hess_ad, hess_hess_ad = J(x_i), H(x_i), JH(x_i), HH(x_i)

# One entry per check: (label, reference value, AD value).
# The label is shared by the pass and the failure message, so a failing check
# is always reported under the name of the quantity that was actually compared.
# The tabs inside the labels keep the pass/error column aligned on screen.
checks = (
    ('JJ == H \t\t\t', JJ(x_i), hess_ad),
    ('dudx == J[...,0,0] \t\t', dudx(x_i), jac_ad[..., 0, 0]),
    ('dvdy == J[...,1,1] \t\t', dvdy(x_i), jac_ad[..., 1, 1]),
    ('d2udx2 == H[...,0,0,0] \t\t', d2udx2(x_i), hess_ad[..., 0, 0, 0]),
    ('d2vdy2 == H[...,1,1,1] \t\t', d2vdy2(x_i), hess_ad[..., 1, 1, 1]),
    ('d3udx3 == JH[...,0,0,0,0] \t', d3udx3(x_i), jac_hess_ad[..., 0, 0, 0, 0]),
    ('d3vdy3 == JH[...,1,1,1,1] \t', d3vdy3(x_i), jac_hess_ad[..., 1, 1, 1, 1]),
    ('d3udx2y == JH[...,0,0,0,1] \t', d3udx2y(x_i), jac_hess_ad[..., 0, 0, 0, 1]),
    ('d4udx4 == HH[...,0,0,0,0,0] \t', d4udx4(x_i), hess_hess_ad[..., 0, 0, 0, 0, 0]),
    ('d4vdy4 == HH[...,1,1,1,1,1] \t', d4vdy4(x_i), hess_hess_ad[..., 1, 1, 1, 1, 1]),
)
for label, reference, automatic in checks:
    # Argument order matters: allclose scales its relative tolerance by the
    # second argument, so the reference stays first as in the original checks.
    verdict = 'pass' if torch.allclose(reference, automatic) else 'error!'
    print(f'{label} {verdict}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment