Skip to content

Instantly share code, notes, and snippets.

View danielkelshaw's full-sized avatar
:octocat:
Approximating the Posterior

Daniel Kelshaw danielkelshaw

:octocat:
Approximating the Posterior
View GitHub Profile
@RicardoDominguez
RicardoDominguez / jax_bvp_solver.py
Last active June 11, 2024 17:41
BVP solver in JAX based on scipy.integrate.solve_bvp
"""Boundary value problem solver."""
import jax
import jax.numpy as jnp
# ------------------------------------------------------------------------------------------
# Linear solver for bordered almost block diagonal (BABD) systems
# ------------------------------------------------------------------------------------------
# Implementation as described in [1] Section 2.1 (structural orthogonal factorization).
import torch
def jacobian(y, x, create_graph=False):
jac = []
flat_y = y.reshape(-1)
grad_y = torch.zeros_like(flat_y)
for i in range(len(flat_y)):
grad_y[i] = 1.
grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
jac.append(grad_x.reshape(x.shape))