Skip to content

Instantly share code, notes, and snippets.

@tomelse
Last active September 29, 2023 12:00
Show Gist options
  • Save tomelse/f9ba7508b75f44f34ebdbf25d5f5b0a3 to your computer and use it in GitHub Desktop.
Save tomelse/f9ba7508b75f44f34ebdbf25d5f5b0a3 to your computer and use it in GitHub Desktop.
j-Wave band-limited interpolant for off-grid sensors (http://github.com/ucl-bug/jwave). It is based on the paper http://bug.medphys.ucl.ac.uk/papers/2019-Wise-JASA.pdf and is similar to what has been implemented in k-Wave (www.k-wave.org/forum/topic/alpha-version-of-kwavearray-off-grid-sources).
import jax.numpy as jnp
import numpy.typing
from jax.tree_util import register_pytree_node_class
import numpy as np
from typing import Tuple
from jwave.geometry import Field
def _bli_weights(positions, n):
    r"""Band-limited interpolation weights along a single grid axis.

    For every sensor coordinate ``s`` and grid index ``j = 0, ..., n - 1``
    this evaluates::

        w[s, j] = (1 / n) * (sin(pi d) / tan(pi d / n) - sin(pi s) * sin(pi j))
        with d = s - j

    Reference: http://bug.medphys.ucl.ac.uk/papers/2019-Wise-JASA.pdf

    Args:
        positions: 1-D array-like of sensor coordinates along this axis, in
            the same units as the grid indices ``0..n-1``.
        n: Number of grid points along this axis.

    Returns:
        A ``(len(positions), n)`` float64 NumPy array of weights.
    """
    pos = np.asarray(positions, dtype=float)[:, None]
    grid = np.arange(n)[None]
    delta = pos - grid
    # sin(pi d) / tan(pi d / n) is 0/0 when a sensor sits exactly on a grid
    # node (d == 0), which used to yield NaN. Its limit there is n, so divide
    # with a dummy non-zero argument and then patch in the limit.
    on_node = delta == 0
    safe = np.where(on_node, 1.0, delta)
    ratio = np.sin(np.pi * safe) / np.tan(np.pi * safe / n)
    ratio = np.where(on_node, float(n), ratio)
    return (1 / n) * (ratio - np.sin(np.pi * pos) * np.sin(np.pi * grid))


@register_pytree_node_class
class BLISensors:
    r"""Off-grid point sensors sampled with a band-limited interpolant.

    Each sensor value is the separable weighted sum
    ``sum_{x,y,z} bx[m, x] * by[m, y] * bz[m, z] * p[x, y, z, ...]``, with the
    per-axis weights computed by ``_bli_weights``.

    The weights are built with NumPy at construction time. They depend only on
    the sensor positions and the grid size, so they are constants with respect
    to the sampled field, and gradients w.r.t. the field still flow through the
    ``jnp`` operations in ``__call__``. Gradients w.r.t. the sensor positions
    are not available, because positions are static pytree aux data.

    Attributes:
        positions: Tuple ``(xs, ys, zs)`` of 1-D arrays with one coordinate
            per sensor, in grid-index units.
        N: Grid size ``(Nx, Ny, Nz)``.
        bx: x-weights, shape ``(M, Nx, 1, 1, 1)``.
        by: y-weights, shape ``(M, Ny, 1, 1)``.
        bz: z-weights, shape ``(M, Nz, 1)``.
    """

    positions: Tuple[np.typing.ArrayLike, ...]
    N: Tuple[int, ...]

    def __init__(self, positions, N):
        self.positions = positions
        self.N = N
        # The broadcast-ready layout (trailing singleton axes) is kept for
        # code that reads bx/by/bz directly; __call__ flattens each back to
        # (M, N_axis) before contracting.
        self.bx = jnp.array(_bli_weights(positions[0], N[0])[:, :, None, None, None])
        self.by = jnp.array(_bli_weights(positions[1], N[1])[:, :, None, None])
        self.bz = jnp.array(_bli_weights(positions[2], N[2])[:, :, None])

    def tree_flatten(self):
        # No array leaves: the weights are rebuilt from the static aux data.
        # JAX documents the children as an iterable, so use an empty tuple.
        # NOTE(review): JAX expects aux data to be hashable for jit caching;
        # numpy arrays in `positions` are not -- confirm how this is used.
        children = ()
        aux = (self.positions, self.N)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, _):
        # Recomputes the interpolation weights from (positions, N).
        return cls(*aux)

    def __call__(self, p: "Field", u=None, v=None):
        r"""Return the values of the field ``p`` at the sensor positions.

        Args:
            p: Field to sample. ``p.on_grid`` must have shape
                ``(Nx, Ny, Nz, ...)``; any trailing axes are carried through.
            u: Unused. Presumably accepted to match a ``(p, u, v)`` sensor
                call convention used by the caller (confirm).
            v: Unused; see ``u``.

        Returns:
            Array of shape ``(M, ...)`` with one entry per sensor.
        """
        field = p.on_grid
        wx = self.bx.reshape(self.bx.shape[:2])  # (M, Nx)
        wy = self.by.reshape(self.by.shape[:2])  # (M, Ny)
        wz = self.bz.reshape(self.bz.shape[:2])  # (M, Nz)
        # Contract one axis at a time, so the full (M, Nx, Ny, Nz, ...)
        # product of field and weights is never materialised.
        out = jnp.einsum("mx,xyz...->myz...", wx, field)
        out = jnp.einsum("my,myz...->mz...", wy, out)
        return jnp.einsum("mz,mz...->m...", wz, out)
@jgroehl
Copy link

jgroehl commented Jun 9, 2023

Should maybe all np calls be replaced with jnp calls to support the auto differentiation capabilities all the way through?
Or is that not necessary here?

@tomelse
Copy link
Author

tomelse commented Jun 9, 2023

Should maybe all np calls be replaced with jnp calls to support the auto differentiation capabilities all the way through? Or is that not necessary here?

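Not really sure; I don't know enough about how JAX handles that kind of thing internally. It can't hurt to change it anyway, so I will! (My guess is that it doesn't matter: bx/by/bz depend only on the fixed sensor positions, so they are constants with respect to the field, and gradients still flow through the jnp operations in `__call__`. It would only matter if you wanted to differentiate with respect to the sensor positions.)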

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment