@madagra
Last active December 25, 2022 11:14
import torch
from functorch import make_functional, grad, vmap
# create the PINN model and make it functional using functorch utilities
# (`NNApproximator` is an `nn.Module` defined elsewhere, not shown in this snippet)
model = NNApproximator()
fmodel, params = make_functional(model)
def f(x: torch.Tensor, params: tuple) -> torch.Tensor:
    # `params` is the tuple of parameter tensors returned by `make_functional`
    # `f` handles a single input element, so `unsqueeze` must be applied;
    # to batch multiple inputs, `vmap` must be used as below
    x_ = x.unsqueeze(0)
    res = fmodel(params, x_).squeeze(0)
    return res
# use `vmap` primitive to allow efficient batching of the input
f_vmap = vmap(f, in_dims=(0, None))
# build functions that compute higher-order derivatives with respect to
# the input by composing `grad` calls, then wrap them in `vmap` again
# to batch over the input efficiently
dfdx = vmap(grad(f), in_dims=(0, None))
d2fdx2 = vmap(grad(grad(f)), in_dims=(0, None))
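
For readers on PyTorch 2.0 or later, where functorch was merged into `torch.func` and `make_functional` is deprecated in favor of `functional_call`, the same pattern can be sketched as below. This is a minimal sketch and not the author's code: the `NNApproximator` defined here is a hypothetical stand-in (a small tanh MLP with one scalar input and one scalar output), and the batch `x` is assumed to be a 1-D tensor of sample points.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap


class NNApproximator(nn.Module):
    """Hypothetical stand-in for the model: scalar input, scalar output."""

    def __init__(self, hidden: int = 8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden), nn.Tanh(), nn.Linear(hidden, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


torch.manual_seed(0)
model = NNApproximator()
# detached copies of the parameters, keyed by name, as `functional_call` expects
params = {name: p.detach() for name, p in model.named_parameters()}


def f(x: torch.Tensor, params: dict) -> torch.Tensor:
    # `x` is a single 0-dim sample: add the feature dimension for the linear
    # layer, then remove it so that `grad` receives a scalar output
    return functional_call(model, params, (x.unsqueeze(0),)).squeeze(0)


# batch over the first axis of `x`; the parameters are shared across samples
f_vmap = vmap(f, in_dims=(0, None))
dfdx = vmap(grad(f), in_dims=(0, None))
d2fdx2 = vmap(grad(grad(f)), in_dims=(0, None))

x = torch.linspace(0.0, 1.0, 5)   # five sample points, shape (5,)
y = f_vmap(x, params)             # model values, shape (5,)
dy = dfdx(x, params)              # first derivatives, shape (5,)
d2y = d2fdx2(x, params)           # second derivatives, shape (5,)
```

Each result has one entry per sample point. `grad` differentiates with respect to the first argument by default, which is why `x` comes before `params` in the signature of `f`.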