
@astanziola
Created November 7, 2022 09:32
jwave_intro_clean.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMmzMBg+u90WMfXpILs0D5Q",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/astanziola/b8409842e6d6bb35e4a06173e8d60977/jwave_intro_clean.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# j-Wave Crash Course\n",
"\n",
"*Goal*: To understand the basics of j-Wave.\n",
"\n",
"### Before starting\n",
"\n",
"Make sure you are using a GPU runtime!"
],
"metadata": {
"id": "dNQHhelkJXtQ"
}
},
{
"cell_type": "code",
"source": [
"!pip install git+https://github.com/ucl-bug/jwave.git"
],
"metadata": {
"id": "oLo567vuJJyV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Part 1 - Prerequisites\n",
"\n",
"To use `jwave`, you need to be familiar with `jax` and the `jaxdf` package. This notebook provides a quick introduction to both, so that you can get started with running simulations."
],
"metadata": {
"id": "RaE34X9-KAsL"
}
},
{
"cell_type": "markdown",
"source": [
"## JAX\n",
"\n",
"[jax](https://github.com/google/jax) is a Python library for machine learning and scientific computing, on which `jwave` is based. To learn how to use `jax`, please refer to [this guide](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html). For our purposes, there are two main features of `jax` to keep in mind when using `jwave`:\n",
"\n",
"#### 1. `jax` is a drop-in replacement for `numpy`\n",
"\n",
"Writing functions that operate on arrays in `jax` is very similar to doing it in NumPy. For example, we can write a function that evaluates a polynomial of its inputs and apply it to some values."
],
"metadata": {
"id": "jd31ydwUKGOv"
}
},
{
"cell_type": "code",
"source": [
"from jax import numpy as jnp\n",
"\n",
"def f(...):\n",
" pass\n",
"\n",
"# Testing function\n",
"x = 2.\n",
"y = 3.\n",
"\n",
"z = f(x, y, 3.,2.)\n",
"\n",
"print(f\"(x,y) = {[x,y]}\\tf(x,y) = {z}\")"
],
"metadata": {
"id": "f4ArGHmQJJvR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"However, some operations, like in-place updates of arrays, are not permitted. For more details on this, see the [Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) section of the `jax` documentation."
],
"metadata": {
"id": "y5SpTn7uKNep"
}
},
{
"cell_type": "markdown",
"source": [
"#### 2. The power of `jax` comes from function transformations\n",
"\n",
"One of the main features of `jax` is function transformations. This concept may be unfamiliar, especially to users coming from MATLAB (a related concept is [function handles](https://uk.mathworks.com/help/matlab/matlab_prog/call-local-functions-using-function-handles.html)).\n",
"\n",
"The fundamental idea is to have special functions (from now on called *function transformations*, or [higher-order functions](https://en.wikipedia.org/wiki/Higher-order_function)) that take a function as input and return a function as output. They are also closely related to [Python decorators](https://peps.python.org/pep-0318/).\n",
"\n",
"To make things concrete, an example of a function transformation (say $\\mathcal{T}$) is one that transforms a generic $f(x,y)$ by swapping $x$ and $y$, yielding the new function $\\mathcal{T}(f)(x,y) = g(x,y) = f(y,x)$."
],
"metadata": {
"id": "IEfcmy5uKUI3"
}
},
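{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of the idea. The hypothetical transformation `scale_output` below is not part of `jax` or `jwave`. It takes a function and returns a new function whose output is multiplied by a constant. The same transformation can also be applied with Python's decorator syntax. It is deliberately different from the exercises that follow.\n",
"\n",
"```python\n",
"# Hypothetical function transformation: returns a new function whose\n",
"# output is the output of f multiplied by `factor`\n",
"def scale_output(f, factor):\n",
"    def g(*args):\n",
"        return factor * f(*args)\n",
"    return g\n",
"\n",
"def h(x, y):\n",
"    return x + 2.0 * y\n",
"\n",
"h_scaled = scale_output(h, 10.0)\n",
"print(h(1.0, 2.0))         # 5.0\n",
"print(h_scaled(1.0, 2.0))  # 50.0\n",
"\n",
"# Decorator syntax: `@double_output` is shorthand for k = double_output(k)\n",
"def double_output(f):\n",
"    return scale_output(f, 2.0)\n",
"\n",
"@double_output\n",
"def k(x):\n",
"    return x ** 2\n",
"\n",
"print(k(3.0))  # 18.0\n",
"```"
]
},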
{
"cell_type": "code",
"source": [
"def swap_coordinates(f):\n",
" pass\n",
" \n",
"# Gets the new function\n",
"f_swapped = swap_coordinates(f)\n",
"\n",
"# Evaluates the new function\n",
"w = f_swapped(x, y, 3., 2.)\n",
"\n",
"print(f\"(x,y) = {[x,y]}\\t f(x,y) = {z}\")\n",
"print(f\"(x,y) = {[x,y]}\\tg(x,y) = f(y,x) = {w}\")"
],
"metadata": {
"id": "9ncN8RIQJJrr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"- Write a function transformation (decorator) that squares the inputs before feeding them to the function `f`"
],
"metadata": {
"id": "qFckHMM4KquK"
}
},
{
"cell_type": "code",
"source": [
"def square_coordinates(f):\n",
" pass\n",
" \n",
"# Gets the new function\n",
"f_squared = square_coordinates(f)\n",
"\n",
"# Evaluates the new function\n",
"w = f_squared(x, y, 3., 2.)\n",
"\n",
"print(f\"(x,y) = {[x,y]}\\t f(x,y) = {z}\")\n",
"print(f\"(x,y) = {[x,y]}\\tg(x,y) = f(x^2,y^2) = {w}\")"
],
"metadata": {
"id": "TAndqJ8QKesV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"You could also define a decorator that changes the output of the function, e.g. by returning multiple outputs, or one that combines both kinds of changes.\n",
"\n",
"## Taking gradients\n",
"\n",
"`jax` comes equipped with many transformations that are useful for machine learning research. One of those transformations is `jax.grad`, which [applies the gradient operator](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) to a function with a scalar output.\n",
"\n",
"Recall the definition of the gradient:\n",
"\n",
"> In vector calculus, the gradient of a scalar-valued differentiable function $f$ of several variables is the vector field (or vector-valued function) $\\nabla f$ whose value at a point $p$ is the vector whose components are the partial derivatives of $f$ at $p$. \n",
"\n",
"> That is, for ${\\displaystyle f\\colon \\mathbb {R} ^{n}\\to \\mathbb {R} }$, its gradient is the function ${\\displaystyle \\nabla f\\colon \\mathbb {R} ^{n}\\to \\mathbb {R} ^{n}}$ that is defined at the point ${\\displaystyle p=(x_{1},\\ldots ,x_{n})}$ and returns the vector ${\\displaystyle \\nabla f(p)={\\begin{bmatrix}{\\frac {\\partial f}{\\partial x_{1}}}(p)\\\\\\vdots \\\\{\\frac {\\partial f}{\\partial x_{n}}}(p)\\end{bmatrix}}.}$\n",
"\n",
"It is useful to visually compare this against the approach taken by other libraries:\n",
"\n",
"![](https://sjmielke.com/images/blog/jax-purify/comparison_big.png)"
],
"metadata": {
"id": "-wIKMvg9LDRb"
}
},
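{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of `jax.grad` on an illustrative function $g(x, y) = x^2 y$ (not the notebook's `f`). Its partial derivatives are $2xy$ and $x^2$. By default `jax.grad` differentiates with respect to the first argument only. Passing `argnums=(0, 1)` returns the derivatives with respect to both inputs.\n",
"\n",
"```python\n",
"import jax\n",
"\n",
"# Illustrative function: g(x, y) = x^2 * y\n",
"def g(x, y):\n",
"    return x ** 2 * y\n",
"\n",
"# Gradient with respect to both positional arguments\n",
"grad_g = jax.grad(g, argnums=(0, 1))\n",
"\n",
"dg_dx, dg_dy = grad_g(2.0, 3.0)\n",
"print(float(dg_dx), float(dg_dy))  # 12.0 4.0\n",
"```"
]
},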
{
"cell_type": "code",
"source": [
"import jax\n",
"\n",
"# Gets the gradient function\n",
"...\n",
"\n",
"# Evaluates the new function\n",
"...\n",
"\n",
"print(f\"(x,y) = {[x,y]}\\tf(x,y) = {z} \\tf'(x,y) = {[str(x_prime), str(y_prime)]}\")"
],
"metadata": {
"id": "ygxznlQUJJpQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Another important function transformation is [`jax.jit`](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#using-jit-to-speed-up-functions), which returns a version of the function that is compiled and optimized for the user's hardware."
],
"metadata": {
"id": "_QT2WrEtLNrM"
}
},
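{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch using a hypothetical function `heavy`. The jitted and non-jitted versions return the same value. The first call to the jitted version triggers tracing and compilation for the given input shape and dtype. Because `jax` dispatches work asynchronously, `.block_until_ready()` is used whenever the result must actually be computed, for example when timing.\n",
"\n",
"```python\n",
"import jax\n",
"from jax import numpy as jnp\n",
"\n",
"# Hypothetical function used only to illustrate jax.jit\n",
"def heavy(x):\n",
"    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)\n",
"\n",
"heavy_jit = jax.jit(heavy)\n",
"\n",
"x = jnp.linspace(0.0, 1.0, 1000)\n",
"y_ref = heavy(x)\n",
"# First call: traced and compiled for this shape/dtype; later calls reuse it\n",
"y_jit = heavy_jit(x).block_until_ready()\n",
"\n",
"print(bool(jnp.allclose(y_ref, y_jit)))  # True\n",
"```"
]
},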
{
"cell_type": "code",
"source": [
"def complex_fun(x):\n",
" pass\n",
"\n",
"# Gets the compiled function\n",
"f_jit = jax.jit(complex_fun)\n",
"\n",
"# Evaluates the new function\n",
"x = jnp.ones((10000,10000))\n",
"z = complex_fun(x)\n",
"z_jit = f_jit(x) # The function is compiled at its first call"
],
"metadata": {
"id": "5fjHkhFqLNAQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%timeit complex_fun(x).block_until_ready()"
],
"metadata": {
"id": "7Y9Q-zl2JJmu"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%timeit f_jit(x).block_until_ready()"
],
"metadata": {
"id": "B1mvUBxYJJkR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Fields\n",
"\n",
"**Objects** are variables that contain the numerical data used during the simulations. They are often defined as classes registered with JAX as custom PyTree nodes. This means they can be passed to functions in the same way as `jax.numpy` arrays, and they can be initialized within a function.\n",
"\n",
"One example is the `Domain` class from `jwave.geometry`, which defines the domain where the simulation takes place."
],
"metadata": {
"id": "ey1LWY21LX6Z"
}
},
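{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate what registering a class as a PyTree node means, here is a minimal sketch using `jax.tree_util.register_pytree_node_class` on a hypothetical `Scaled` class (not a `jwave` class). The array is a leaf that `jax` traces, while `scale` is stored as static auxiliary data. This split is an assumption made for the example, not a description of how `Domain` is implemented.\n",
"\n",
"```python\n",
"import jax\n",
"from jax import numpy as jnp\n",
"from jax.tree_util import register_pytree_node_class\n",
"\n",
"@register_pytree_node_class\n",
"class Scaled:\n",
"    def __init__(self, values, scale):\n",
"        self.values = values  # array leaf, traced by jax\n",
"        self.scale = scale    # static (hashable) metadata\n",
"\n",
"    def tree_flatten(self):\n",
"        return (self.values,), self.scale\n",
"\n",
"    @classmethod\n",
"    def tree_unflatten(cls, aux_data, children):\n",
"        return cls(children[0], aux_data)\n",
"\n",
"@jax.jit\n",
"def total(obj):\n",
"    # The object is passed to a jitted function like an array\n",
"    return obj.scale * jnp.sum(obj.values)\n",
"\n",
"@jax.jit\n",
"def make_and_total(v):\n",
"    # Objects can also be created inside a transformed function\n",
"    return total(Scaled(v, 3.0))\n",
"\n",
"print(float(total(Scaled(jnp.ones(4), 2.0))))  # 8.0\n",
"print(float(make_and_total(jnp.ones(4))))      # 12.0\n",
"```"
]
},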
{
"cell_type": "code",
"source": [
"..."
],
"metadata": {
"id": "3CSNbWPXJJhe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Other objects are defined as `Field` subclasses from [`jaxdf`](https://github.com/ucl-bug/jaxdf). In a nutshell, a `Field` consists of parameters and a representation (or discretization), associated with a domain. For example, the following defines a truncated Fourier series whose value at $(0,0)$ is 1."
],
"metadata": {
"id": "5Hy-yIarLcK8"
}
},
{
"cell_type": "code",
"source": [
"..."
],
"metadata": {
"id": "l5cxlZFCJJel"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We can visualize this field on the grid"
],
"metadata": {
"id": "PRvzCkcMLeWt"
}
},
{
"cell_type": "code",
"source": [
"from jwave.utils import show_field\n",
"\n",
"show_field(u.on_grid[:,:,0])"
],
"metadata": {
"id": "pDlx1vOdJIf9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Or query the field at a specific point in space"
],
"metadata": {
"id": "krlUCgTeLudV"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "25g2tmbhJHdz"
},
"outputs": [],
"source": [
"x = jnp.asarray([0., 0.]) # Point to query\n",
"print(u(x))"
]
},
{
"cell_type": "markdown",
"source": [
"We can also use this callable to visualize the underlying discretization, i.e. the band-limited interpolant, by querying the field at points between the grid nodes. Here we are using the [vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) function transformation to efficiently query the field at multiple points."
],
"metadata": {
"id": "GrnOtej0Lxji"
}
},
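{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of `jax.vmap`, using a hypothetical point-wise function `gaussian_bump` in place of a field. A single `vmap` evaluates it on a batch of points. Two nested `vmap`s evaluate it on a 2D grid of coordinates with shape `(H, W, 2)`, the same pattern used later in this notebook with `jax.vmap(jax.vmap(v))`.\n",
"\n",
"```python\n",
"import jax\n",
"from jax import numpy as jnp\n",
"\n",
"# Point-wise function of a 2D coordinate (stand-in for a Field callable)\n",
"def gaussian_bump(x):\n",
"    return jnp.exp(-jnp.sum(x ** 2))\n",
"\n",
"on_points = jax.vmap(gaussian_bump)           # (P, 2) -> (P,)\n",
"on_plane = jax.vmap(jax.vmap(gaussian_bump))  # (H, W, 2) -> (H, W)\n",
"\n",
"pts = jnp.array([[0.0, 0.0], [1.0, 0.0]])\n",
"vals = on_points(pts)  # approximately [1.0, exp(-1)]\n",
"\n",
"s = jnp.linspace(-1.0, 1.0, 5)\n",
"X, Y = jnp.meshgrid(s, s)\n",
"coords = jnp.stack([X, Y], -1)  # shape (5, 5, 2)\n",
"print(on_plane(coords).shape)   # (5, 5)\n",
"```"
]
},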
{
"cell_type": "code",
"source": [
"...\n",
"\n",
"show_field(z)"
],
"metadata": {
"id": "gVzuHUJMLv5q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Lastly, there are *operators* that can be applied to `Field` objects. Operators represent mathematical operators acting on functions, such as the gradient or the Helmholtz operator, and their numerical implementation depends on the discretization of the input. They return new fields, and they can be used inside functions that are transformed with `jax`."
],
"metadata": {
"id": "u9GSahksL7sI"
}
},
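{
"cell_type": "markdown",
"metadata": {},
"source": [
"To give an idea of what a discretization-dependent operator does, here is a minimal sketch in plain NumPy (not the `jaxdf` implementation) of a spectral Laplacian on a periodic 1D grid. This is the kind of operation one would expect for a Fourier-series representation: the signal is transformed with the FFT, multiplied by $-k^2$, and transformed back. For the band-limited test function $\\sin(3x)$, the result matches the exact Laplacian $-9\\sin(3x)$.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Periodic grid on [0, 2*pi) with N points\n",
"N, L = 64, 2 * np.pi\n",
"dx = L / N\n",
"x = np.arange(N) * dx\n",
"u = np.sin(3 * x)\n",
"\n",
"# Angular wavenumbers in numpy's FFT ordering\n",
"k = 2 * np.pi * np.fft.fftfreq(N, d=dx)\n",
"\n",
"# Laplacian in Fourier space: multiply by -k^2, then transform back\n",
"lap_u = np.fft.ifft(-(k ** 2) * np.fft.fft(u)).real\n",
"\n",
"print(np.allclose(lap_u, -9 * u))  # True\n",
"```"
]
},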
{
"cell_type": "code",
"source": [
"from jaxdf.operators import laplacian\n",
"\n",
"# Sample field\n",
"params = jnp.zeros(domain.N)\n",
"params = params.at[32:46,25:58].set(1.0)\n",
"u = FourierSeries(params, domain)\n",
"\n",
"show_field(u.on_grid[:,:,0])"
],
"metadata": {
"id": "cmyW3oriL2so"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"v = ...\n",
"show_field(v.on_grid[:,:,0])"
],
"metadata": {
"id": "y6VgO2FCL9UI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The result is still a `Field`, so it can be visualized at higher resolution using the same method as before."
],
"metadata": {
"id": "y3fKXD5-MWEO"
}
},
{
"cell_type": "code",
"source": [
"field_on_plane = jax.vmap(jax.vmap(v))\n",
"\n",
"x = jnp.linspace(-32, 32, 200)\n",
"X, Y = jnp.meshgrid(x,x)\n",
"coords = jnp.stack([X,Y], -1)\n",
"z = field_on_plane(coords).real[...,0]\n",
"z = jnp.fliplr(jnp.fliplr(z).T)\n",
"\n",
"show_field(z)"
],
"metadata": {
"id": "FD2kA9EXMCwK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Part 2 - Full Waveform Inversion\n",
"\n",
"This part shows how to develop a (very idealized) Full Waveform Inversion (FWI) algorithm with `jwave`. To highlight how the inversion algorithm can be customized, we'll add a smoothing step to the gradient computation before performing the gradient descent step with the Adam optimizer.\n",
"\n",
"## Setup simulation\n",
"Let's start by importing the required modules"
],
"metadata": {
"id": "odjGsvTPMH4B"
}
},
{
"cell_type": "code",
"source": [
"!pip install tqdm"
],
"metadata": {
"id": "gEATXHuXMGK5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from functools import partial\n",
"\n",
"import numpy as np\n",
"from jax import grad, jit, lax, nn\n",
"from jax import numpy as jnp\n",
"from jax import random, value_and_grad, vmap\n",
"from jax.example_libraries import optimizers\n",
"from matplotlib import pyplot as plt\n",
"from tqdm import tqdm\n",
"\n",
"from jwave import FourierSeries\n",
"from jwave.acoustics import simulate_wave_propagation\n",
"from jwave.geometry import (\n",
" Domain,\n",
" Medium,\n",
" Sensors,\n",
" Sources,\n",
" TimeAxis,\n",
" _circ_mask,\n",
" _points_on_circle,\n",
")\n",
"from jwave.signal_processing import apply_ramp, gaussian_window, smooth"
],
"metadata": {
"id": "IeGba0clMxnx"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Settings\n",
"N = (256, 256)\n",
"dx = (0.1e-3, 0.1e-3)\n",
"cfl = 0.25\n",
"num_sources = 64\n",
"source_freq = 1e6\n",
"source_mag = 1.3e-5\n",
"random_seed = random.PRNGKey(42)\n",
"\n",
"# Define domain\n",
"...\n",
"\n",
"# Define medium\n",
"...\n",
"\n",
"# Time axis\n",
"...\n",
"\n",
"# Sources\n",
"...\n",
"\n",
"# Sensors\n",
"..."
],
"metadata": {
"id": "rPm_dvtyMy8-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Let's visualize the simulation settings to get a better understanding of the setup:"
],
"metadata": {
"id": "tAqZBNdkM495"
}
},
{
"cell_type": "code",
"source": [
"# Show simulation setup\n",
"fig, ax = plt.subplots(1, 2, figsize=(15, 4), gridspec_kw={\"width_ratios\": [1, 2]})\n",
"\n",
"ax[0].imshow(medium.sound_speed.on_grid[:,:,0], cmap=\"gray\")\n",
"ax[0].scatter(\n",
" source_positions[1], source_positions[0], c=\"r\", marker=\"x\", label=\"sources\"\n",
")\n",
"ax[0].scatter(\n",
" sensors_positions[1], sensors_positions[0], c=\"g\", marker=\".\", label=\"sensors\"\n",
")\n",
"ax[0].legend(loc=\"lower right\")\n",
"ax[0].set_title(\"Sound speed\")\n",
"ax[0].axis(\"off\")\n",
"\n",
"ax[1].plot(signal, label=\"Source 1\", c=\"k\")\n",
"ax[1].set_title(\"Source signals\")\n",
"#ax[1].get_yaxis().set_visible(False)"
],
"metadata": {
"id": "a-oqMKVGM4er"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"plt.imshow(medium.sound_speed.on_grid[:,:,0])\n",
"plt.colorbar()"
],
"metadata": {
"id": "YW8-bUhtM1Sn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Run the simulation\n",
"\n",
"At this point, we can finally use the `simulate_wave_propagation` function to compute the wave propagation, given a sound speed map and a pulse transmitted from the requested source."
],
"metadata": {
"id": "4XWytlWoNGrR"
}
},
{
"cell_type": "code",
"source": [
"src_signal = jnp.stack([signal])\n",
"\n",
"# We can compile the entire function! All the constructors\n",
"# that don't depend on the inputs will be statically compiled\n",
"# and run only once.\n",
"@jit\n",
"def single_source_simulation(sound_speed, source_num):\n",
" # Setting source\n",
" # Equivalent to x = source_positions[0][source_num]\n",
" ...\n",
"\n",
" # Updating medium with the input speed of sound map\n",
" ...\n",
"\n",
" # Run simulations\n",
" ..."
],
"metadata": {
"id": "IuSFYVNPM_0G"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"p = single_source_simulation(medium.sound_speed, 20)\n",
"\n",
"# Visualize the acoustic traces\n",
"plt.figure(figsize=(6, 4.5))\n",
"maxval = jnp.amax(jnp.abs(p))\n",
"plt.imshow(\n",
" p, cmap=\"RdBu_r\", vmin=-1, vmax=1, interpolation=\"nearest\", aspect=\"auto\"\n",
")\n",
"plt.colorbar()\n",
"plt.title(\"Acoustic traces\")\n",
"plt.xlabel(\"Sensor index\")\n",
"plt.ylabel(\"Time\")\n",
"plt.show()"
],
"metadata": {
"id": "V5BEcpAYNJHT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%%timeit\n",
"single_source_simulation(medium.sound_speed, 1).block_until_ready()"
],
"metadata": {
"id": "xfbTUNEhNMSn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"batch_simulations = vmap(single_source_simulation, in_axes=(None, 0))\n",
"p_data = batch_simulations(medium.sound_speed, jnp.arange(num_sources))\n",
"print(f\"Size of data [Source idx, Time, Sensor idx]: {p_data.shape}\")"
],
"metadata": {
"id": "bLYUseGhNN6_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%%timeit\n",
"batch_simulations(medium.sound_speed, jnp.arange(num_sources)).block_until_ready()"
],
"metadata": {
"id": "IQlfefMFNPfH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"noise = np.random.normal(size=p_data.shape) * 0.2\n",
"p_data = p_data + noise\n",
"\n",
"plt.plot(p_data[12, :, 0])\n",
"plt.title(\"Example of noisy traces\")\n",
"plt.show()"
],
"metadata": {
"id": "gYXg37enNRUW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Define the optimization problem\n",
"\n",
"We'll let automatic differentiation generate the full waveform inversion algorithm. To do so, we need to specify a forward model that maps the speed of sound map to the acoustic data.\n",
"\n",
"While this has already been done in the previous cells of this notebook, we'll wrap it in a new function used to reparametrize the speed of sound map. This is done because the forward simulation is unstable for speed of sound values below a certain threshold $T$. To make sure that such maps are not in the range of the inversion algorithm, the speed of sound map is parametrized as\n",
"$$\n",
"c = T + \\text{sigmoid}(c')\n",
"$$\n",
"\n",
"Also, we'll mask the pixels outside the circle defined by the sensors, since we know that the only unknowns are the pixel values inside of it."
],
"metadata": {
"id": "wgaPswBlNYCn"
}
},
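{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of this reparametrization on a plain array. It assumes a hypothetical threshold `T`, a hypothetical amplitude `A` that scales the sigmoid (with `A = 1` this is exactly the formula above), and a hypothetical circular mask. Since the sigmoid maps any real number into $(0, 1)$, the resulting sound speed can never fall below $T$, whatever values the optimizer assigns to $c'$. Outside the mask, the sound speed is fixed to $T$. The values used here are illustrative and are not the ones used in this notebook.\n",
"\n",
"```python\n",
"from jax import nn, random\n",
"from jax import numpy as jnp\n",
"\n",
"T, A = 1400.0, 200.0  # hypothetical threshold and range (m/s)\n",
"N = 64                # hypothetical grid size\n",
"\n",
"# Hypothetical circular mask of radius 20 pixels\n",
"ii, jj = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing='ij')\n",
"mask = ((ii - N / 2) ** 2 + (jj - N / 2) ** 2 < 20.0 ** 2).astype(jnp.float32)\n",
"\n",
"def sound_speed_from_params(c_prime):\n",
"    # sigmoid maps R -> (0, 1), so c stays in [T, T + A] for any c_prime;\n",
"    # pixels outside the mask are fixed to the background value T\n",
"    return T + A * nn.sigmoid(c_prime) * mask\n",
"\n",
"c_prime = 5.0 * random.normal(random.PRNGKey(0), (N, N))\n",
"c = sound_speed_from_params(c_prime)\n",
"print(float(c.min()) >= T, float(c.max()) <= T + A)  # True True\n",
"```"
]
},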
{
"cell_type": "code",
"source": [
"from jaxdf.operators import compose\n",
"\n",
"..."
],
"metadata": {
"id": "LJEAfUUkNVi-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The goal now is to define the function to differentiate, that is, some scalar function $L$ such that\n",
"\n",
"$L(c) \\to \\text{loss}$\n",
"\n",
"in order to compute its gradient\n",
"\n",
"$\\frac{\\partial L}{\\partial c}(c)$, which has the same shape as $c$."
],
"metadata": {
"id": "qwx4jSggZUh0"
}
},
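{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a self-contained sketch of this setup. The toy example replaces the wave simulator with a hypothetical fixed linear map `G` and uses a least-squares misfit as the loss. `jax.value_and_grad` returns both the scalar loss and its gradient. The gradient has the same shape as the parameters and, for this linear model, should match the analytical expression $G^T (G c - d)$.\n",
"\n",
"```python\n",
"import jax\n",
"from jax import numpy as jnp\n",
"from jax import random\n",
"\n",
"# Hypothetical linear forward model standing in for the wave simulation\n",
"key_g, key_c = random.split(random.PRNGKey(0))\n",
"G = random.normal(key_g, (30, 10))\n",
"c_true = random.normal(key_c, (10,))\n",
"d_obs = G @ c_true  # synthetic observed data\n",
"\n",
"def loss(c):\n",
"    # Least-squares misfit L(c) = 0.5 * ||G c - d||^2, a scalar\n",
"    r = G @ c - d_obs\n",
"    return 0.5 * jnp.sum(r ** 2)\n",
"\n",
"loss_and_grad = jax.value_and_grad(loss)\n",
"c0 = jnp.zeros(10)\n",
"value, g = loss_and_grad(c0)\n",
"\n",
"print(g.shape)  # (10,)\n",
"print(bool(jnp.allclose(g, G.T @ (G @ c0 - d_obs), atol=1e-3)))\n",
"```"
]
},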
{
"cell_type": "code",
"source": [
"from jwave.signal_processing import analytic_signal\n",
"from jaxdf.operators import gradient, functional\n",
"\n",
"def loss_func(params, source_num):\n",
" pass\n",
"\n",
"loss_with_grad = value_and_grad(loss_func, argnums=0)"
],
"metadata": {
"id": "ZI4vMq0yNhVU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def smooth_fun(gradient):\n",
" pass\n",
"\n",
"# Visualize\n",
"plt.figure(figsize=(8, 6))\n",
"plt.imshow(gradient.on_grid[:,:,0], cmap=\"RdBu_r\", vmin=-0.0003, vmax=0.0003)\n",
"plt.title(\"Smoothed gradient\")\n",
"plt.colorbar()\n",
"plt.show()"
],
"metadata": {
"id": "iFhoslwWNrZb"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Minimize the objective function\n",
"\n",
"Equipped with a function that calculates the correct gradients, we can finally run the FWI algorithm by randomly looping through the sources and updating the speed of sound estimate.\n",
"\n",
"Following the spirit of `JAX`, all we need to do is write an `update` function that calculates the gradients and applies a step of the optimization algorithm (although we could also have used full-batch methods, such as BFGS). In this function, we'll also smooth the gradients: this is not necessarily a smart thing to do, but we use it here to highlight how the user can customize any step of the algorithms developed using `jwave`."
],
"metadata": {
"id": "qlWVdvvQN1HG"
}
},
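{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of this pattern on a toy problem. It uses the same `optimizers.adam` interface (`init_fun`, `update_fun`, `get_params`) as the cell below. The quadratic loss stands in for the FWI misfit, and `smooth_gradient` is a hypothetical 3-tap smoothing filter, not `jwave.signal_processing.smooth`. The point is only to show where a custom gradient-processing step fits inside a jitted `update` function.\n",
"\n",
"```python\n",
"import jax\n",
"from jax import numpy as jnp\n",
"from jax.example_libraries import optimizers\n",
"\n",
"# Toy objective standing in for the FWI loss\n",
"target = jnp.linspace(-1.0, 1.0, 32)\n",
"\n",
"def loss(params):\n",
"    return jnp.sum((params - target) ** 2)\n",
"\n",
"def smooth_gradient(g):\n",
"    # Hypothetical smoothing step: 3-tap [1/4, 1/2, 1/4] filter\n",
"    kernel = jnp.array([0.25, 0.5, 0.25])\n",
"    return jnp.convolve(g, kernel, mode='same')\n",
"\n",
"init_fun, update_fun, get_params = optimizers.adam(0.05)\n",
"opt_state = init_fun(jnp.zeros(32))\n",
"\n",
"@jax.jit\n",
"def update(step, opt_state):\n",
"    params = get_params(opt_state)\n",
"    value, g = jax.value_and_grad(loss)(params)\n",
"    g = smooth_gradient(g)  # customize the gradient before the step\n",
"    return value, update_fun(step, g, opt_state)\n",
"\n",
"losses = []\n",
"for step in range(300):\n",
"    value, opt_state = update(step, opt_state)\n",
"    losses.append(float(value))\n",
"\n",
"print(losses[0], losses[-1])\n",
"```"
]
},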
{
"cell_type": "code",
"source": [
"losshistory = []\n",
"reconstructions = []\n",
"num_steps = 100\n",
"\n",
"# Define optimizer\n",
"init_fun, update_fun, get_params = optimizers.adam(0.1, 0.9, 0.9)\n",
"opt_state = init_fun(params)\n",
"\n",
"# Define and compile the update function\n",
"...\n",
"\n",
"\n",
"# Main loop\n",
"..."
],
"metadata": {
"id": "519da7DjNyA9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Finally, let's look at the reconstructed image and its evolution during the optimization"
],
"metadata": {
"id": "Mj1roSCAN9U9"
}
},
{
"cell_type": "code",
"source": [
"from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar\n",
"import matplotlib.font_manager as fm"
],
"metadata": {
"id": "h4X9UKlhN__A"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sos_original = get_sound_speed(params).on_grid\n",
"true_sos = sound_speed.on_grid\n",
"vmin = np.amin(true_sos)\n",
"vmax = np.amax(true_sos)\n",
"\n",
"fig, axes = plt.subplots(2, 4, figsize=(10, 5.5))\n",
"\n",
"k = 0\n",
"recs = [24, 39, 69, 99]\n",
"for row in range(2):\n",
" for col in range(4):\n",
" if k == 0:\n",
" axes[row, col].imshow(true_sos[:,:,0], cmap=\"inferno\", vmin=vmin, vmax=vmax)\n",
" axes[row, col].scatter(\n",
" sensors_positions[1],\n",
" sensors_positions[0],\n",
" c=\"g\",\n",
" marker=\".\",\n",
" label=\"sensors\",\n",
" )\n",
" axes[row, col].legend(loc=\"lower right\")\n",
" axes[row, col].set_title(\"True speed of sound\")\n",
" axes[row, col].set_axis_off()\n",
" elif k == 1:\n",
" im_original = axes[row, col].imshow(sos_original[:,:,0], cmap=\"inferno\", vmin=vmin, vmax=vmax)\n",
" axes[row, col].set_axis_off()\n",
" axes[row, col].set_title(\"Initial guess\")\n",
" \n",
" cbar_ax = fig.add_axes([0.53, 0.54, 0.01, 0.385])\n",
" cbar = plt.colorbar(im_original, cax=cbar_ax)\n",
" cbar.ax.get_yaxis().labelpad = 15\n",
" cbar.ax.set_ylabel('m/s', rotation=270)\n",
" elif k == 2:\n",
" axes[row, col].set_axis_off()\n",
" elif k == 3:\n",
" axes[row, col].plot(losshistory)\n",
" axes[row, col].set_title(\"Loss\")\n",
" #axes[row, col].set_xticks([], [])\n",
" axes[row, col].margins(x=0)\n",
" else:\n",
" axes[row, col].imshow(reconstructions[recs[k - 4]][:,:,0], cmap=\"inferno\", vmin=vmin, vmax=vmax)\n",
" axes[row, col].set_axis_off()\n",
" axes[row, col].set_title(\"Step {}\".format(recs[k - 4] + 1))\n",
" k += 1\n",
"\n",
"# Scale bar\n",
"fontprops = fm.FontProperties(size=12)\n",
"scalebar = AnchoredSizeBar(\n",
" axes[-1, -1].transData,\n",
" 100, '1 cm', 'lower right', \n",
" pad=0.3,\n",
" color='white',\n",
" frameon=False,\n",
" size_vertical=2,\n",
" fontproperties=fontprops)\n",
"axes[-1, -1].add_artist(scalebar)\n",
" \n",
"fig.tight_layout()\n",
"\n",
"plt.savefig('fwi.pdf')"
],
"metadata": {
"id": "wOjMhJJFN7X3"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "Jw_s2afkOBVP"
},
"execution_count": null,
"outputs": []
}
]
}