Last active
September 29, 2023 12:00
-
-
Save tomelse/f9ba7508b75f44f34ebdbf25d5f5b0a3 to your computer and use it in GitHub Desktop.
J-Wave band-limited interpolant for off-grid sensors (http://github.com/ucl-bug/jwave). It is based on this paper: http://bug.medphys.ucl.ac.uk/papers/2019-Wise-JASA.pdf, and it is similar to what has been implemented in k-Wave (www.k-wave.org/forum/topic/alpha-version-of-kwavearray-off-grid-sources).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax.numpy as jnp | |
import numpy.typing | |
from jax.tree_util import register_pytree_node_class | |
import numpy as np | |
from typing import Tuple | |
from jwave.geometry import Field | |
@register_pytree_node_class
class BLISensors:
    r"""Off-grid point sensors sampled with a band-limited interpolant (BLI).

    Sensor coordinates are given in (possibly fractional) grid-index units
    along each axis. At construction, separable 1-D interpolation weights are
    computed for each axis (after Wise et al., JASA 2019). Calling the
    instance contracts a field against those weights, one axis at a time,
    giving one value per sensor.

    Attributes:
        positions (Tuple[np.ndarray, np.ndarray, np.ndarray]): sensor
            coordinates along x, y and z in grid-index units. Each entry is
            assumed to be 1-D with one element per sensor.
        N (Tuple[int, int, int]): number of grid points along x, y and z.
        bx, by, bz (jnp.ndarray): per-axis interpolation weights with shapes
            ``(M, N[0], 1, 1, 1)``, ``(M, N[1], 1, 1)`` and ``(M, N[2], 1)``,
            where ``M`` is the number of sensors.
    """
    positions: Tuple[np.typing.ArrayLike, ...]

    def __init__(self, positions, N):
        self.positions = positions
        self.N = N
        # Trailing singleton axes let each kernel broadcast against the field
        # in __call__, where x, y and z are contracted one after another.
        self.bx = jnp.array(self._bli_weights(positions[0], N[0])[:, :, None, None, None])
        self.by = jnp.array(self._bli_weights(positions[1], N[1])[:, :, None, None])
        self.bz = jnp.array(self._bli_weights(positions[2], N[2])[:, :, None])

    @staticmethod
    def _bli_weights(pos, n):
        """Return 1-D band-limited interpolation weights of shape ``(M, n)``.

        Row ``m`` holds the weights that interpolate a length-``n`` grid signal
        at fractional index ``pos[m]``, with ``d = pos[m] - x`` for grid index
        ``x``:

        * even ``n``: ``sin(pi d) / (n tan(pi d / n))``, minus the Nyquist
          correction ``sin(pi pos) sin(pi x) / n``;
        * odd ``n``: ``sin(pi d) / (n sin(pi d / n))`` (Dirichlet kernel). The
          ``tan`` form is not n-periodic for odd ``n``.

        Wherever the denominator is exactly zero (for example a sensor placed
        exactly on a grid node), the kernel's limiting value is used instead of
        ``0/0 = nan``.
        """
        pos = np.asarray(pos, dtype=float)[:, None]
        grid = np.arange(n)[None]
        phase = np.pi * (pos - grid)
        num = np.sin(phase)
        den = np.tan(phase / n) if n % 2 == 0 else np.sin(phase / n)
        singular = den == 0
        # For the parity-matched kernel, num / den -> n wherever den -> 0.
        ratio = np.where(singular, float(n), num / np.where(singular, 1.0, den))
        weights = ratio / n
        if n % 2 == 0:
            weights = weights - np.sin(np.pi * pos) * np.sin(np.pi * grid) / n
        return weights

    def tree_flatten(self):
        """Flatten for JAX: there are no array leaves; everything goes in aux data.

        JAX iterates ``children`` and hashes/compares the aux data, so children
        is an empty tuple and the positions are stored as nested tuples of
        Python scalars rather than NumPy arrays.
        """
        children = ()
        aux = (
            tuple(tuple(np.asarray(p).tolist()) for p in self.positions),
            tuple(int(n) for n in self.N),
        )
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, _):
        """Rebuild the sensors (recomputing the weights) from ``tree_flatten`` aux data."""
        positions, N = aux
        return cls(tuple(np.asarray(p) for p in positions), N)

    def __call__(self, p: "Field", u, v):
        r"""Return the values of the field ``p`` at the sensor positions.

        Args:
            p (Field): the field to be sampled. ``p.on_grid`` is expected to
                have shape ``(N[0], N[1], N[2], C)``; its x, y and z axes are
                contracted against ``bx``, ``by`` and ``bz`` in turn.
            u: unused here. Presumably accepted so the object matches the
                sensor call signature used by the solver (TODO: confirm).
            v: unused here (see ``u``).

        Returns:
            jnp.ndarray: interpolated values of shape ``(M, C)``.
        """
        pw = jnp.sum(p.on_grid[None] * self.bx, axis=1)  # -> (M, Ny, Nz, C)
        pw = jnp.sum(pw * self.by, axis=1)  # -> (M, Nz, C)
        return jnp.sum(pw * self.bz, axis=1)  # -> (M, C)
Author
tomelse
commented
Jun 9, 2023
Should all the np calls perhaps be replaced with jnp calls, so that JAX's automatic differentiation is supported all the way through?
Or is that not necessary here?
Should maybe all np calls be replaced with jnp calls to support the auto differentiation capabilities all the way through? Or is that not necessary here?
Not really sure; I don't know enough about how JAX handles that kind of thing internally. It can't hurt to change it anyway, so I will! (I guess that, because bx/by/bz are constants, it maybe doesn't affect the differentiation?)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment