Skip to content

Instantly share code, notes, and snippets.

@astanziola
Created February 4, 2023 13:57
Show Gist options
  • Save astanziola/565681777ed2a6231bbd7184bfb4b80e to your computer and use it in GitHub Desktop.
Save astanziola/565681777ed2a6231bbd7184bfb4b80e to your computer and use it in GitHub Desktop.
Differentiable approximate plane wave Ultrasound PSF in JAX
# This code is a quick reproduction of eq. (9) of
# "Mathematical Analysis of Ultrafast Ultrasound Imaging" by Alberti et al. (2016).
# https://arxiv.org/pdf/1604.04604.pdf
#
# It represents an approximate point spread function for Plane Wave Imaging that can be
# used to write a simple, yet powerful, 2D Plane Wave ultrasound simulator using spatially-variant
# convolutions.
#
# The function is fully differentiable.
from functools import partial

import jax
from jax import numpy as jnp
def f(t, *, v0, tau):
    """Gaussian-windowed complex pulse.

    Computes ``exp(2*pi*i*v0*t) * exp(-(v0*t)**2 / tau**2)`` elementwise.

    Args:
        t: Time, scalar or array (real, floating point).
        v0: Center frequency of the complex carrier.
        tau: Envelope width. The envelope is a function of ``v0 * t``, so
            ``tau`` is measured in carrier periods.

    Returns:
        Complex pulse value(s), broadcast like NumPy over ``t``.
    """
    carrier = jnp.exp(2 * jnp.pi * 1j * v0 * t)
    envelope = jnp.exp(-((v0 * t) ** 2) / (tau**2))
    return carrier * envelope


def f_prime(t, v0, tau):
    """Return the time derivative ``df/dt`` of the pulse ``f`` at ``t``.

    Args:
        t: Real, floating-point time; scalar or array (differentiated
            elementwise, since ``f`` is elementwise in ``t``).
        v0: Center frequency, forwarded to ``f``.
        tau: Envelope width, forwarded to ``f``.

    Returns:
        Complex ``f'(t)`` with the same shape as ``t``.
    """
    pulse = partial(f, v0=v0, tau=tau)
    t = jnp.asarray(t)
    # ``f`` maps real t to complex values. A single forward-mode pass with a
    # unit real tangent therefore yields the complex derivative f'(t) directly.
    # Recombining reverse-mode VJPs as vjp(1) + 1j * vjp(1j) does NOT work:
    # for a real input JAX's VJP is Re(cotangent * f'(t)), so vjp(1j) equals
    # -Im f'(t) and that combination gives conj(f'(t)).
    _, dfdt = jax.jvp(pulse, (t,), (jnp.ones_like(t),))
    return dfdt
def plane_wave_psf(
    x,
    z,
    theta,  # Transmit (steering) angle, radians
    c0,  # Background sound speed
    F,  # Inverse f-number (aperture / depth), as defined in Alberti et al.
    v0,  # Pulse center frequency
    tau,  # Pulse envelope width (~ number of cycles)
):
    """Approximate plane-wave PSF, Eq. (9) of https://arxiv.org/pdf/1604.04604.pdf.

    With ``s = sqrt(1 + F**2)`` the returned value is::

        c0 / (4*pi*x) * [f'(t_minus) - f'(t_plus)]
        t_minus = ((1 + s*cos(theta))*z + (s*sin(theta) - F)*x) / (c0*s)
        t_plus  = ((1 + s*cos(theta))*z + (s*sin(theta) + F)*x) / (c0*s)

    where ``f'`` is ``f_prime``, the time derivative of the pulse ``f``.

    Args:
        x: Lateral coordinate of the evaluation point.
        z: Axial coordinate of the evaluation point.
        theta: Transmit (steering) angle in radians.
        c0: Background sound speed.
        F: Inverse f-number (aperture / depth).
        v0: Pulse center frequency, forwarded to ``f_prime``.
        tau: Pulse envelope width, forwarded to ``f_prime``.

    Returns:
        Complex PSF value at ``(x, z)``.

    Note:
        At ``x == 0`` both pulse arguments coincide, so this evaluates
        ``inf * 0`` and returns NaN even though the analytic limit is finite.
        Sample grids should avoid ``x == 0`` exactly.
    """
    s = jnp.sqrt(1 + F**2)  # shared by every term; compute once
    prefact = c0 / (4 * jnp.pi * x)
    aperture_prefact = 1 / (c0 * s)
    z_component = (1 + s * jnp.cos(theta)) * z
    lateral = s * jnp.sin(theta)
    # Arguments for the two aperture edges (-F and +F).
    t_minus = aperture_prefact * (z_component + (lateral - F) * x)
    t_plus = aperture_prefact * (z_component + (lateral + F) * x)
    square_bracket = f_prime(t_minus, v0, tau) - f_prime(t_plus, v0, tau)
    return prefact * square_bracket
# Vectorize over x (positional argument 0); the remaining six arguments are
# shared across the mapped axis.
psf_line = jax.vmap(plane_wave_psf, in_axes=(0,) + (None,) * 6)
# Then vectorize the line function over z (positional argument 1). For 1-D x
# and z, psf_fun(x, z, ...) returns an array of shape (len(z), len(x)).
psf_fun = jax.vmap(psf_line, in_axes=(None, 0) + (None,) * 5)
# Example usage.
# The following code computes a 2D PSF on an (x, z) grid centred on the origin.
# Each axis has 1000 points (an even count), so x == 0 is never sampled exactly;
# plane_wave_psf divides by x.
x = jnp.linspace(-0.003, 0.003, 1000)  # Lateral coordinates
z = jnp.linspace(-0.003, 0.003, 1000)  # Axial coordinates
aperture = 0.05  # Aperture size
# Keep the depth under its own name: reusing ``z`` would overwrite the axial
# grid with a scalar, and psf_fun cannot vmap over axis 0 of a scalar.
depth = 0.15  # Depth
F = aperture / depth  # 1/f-number (as defined in Alberti et al.)
theta = 10 * jnp.pi / 180  # Steering angle in radians
c0 = 1540  # Background sound speed
v0 = 5e6  # Center frequency
tau = 2  # ~ number of cycles
# NOTE(review): the steering angle is passed negated; confirm the sign
# convention against Eq. (9) of the paper.
result = psf_fun(x, z, -theta, c0, F, v0, tau)  # shape (len(z), len(x))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment