Skip to content

Instantly share code, notes, and snippets.

@jgomezdans
Last active December 29, 2023 21:03
Show Gist options
  • Save jgomezdans/6560e2971904794d91298f741acbbfc7 to your computer and use it in GitHub Desktop.
Save jgomezdans/6560e2971904794d91298f741acbbfc7 to your computer and use it in GitHub Desktop.
testing_frameworks
import numpy as np
import numba
import jax.numpy as jnp
import jax
import time
import torch
import pandas as pd
from functools import partial
def spatial_regularisation_naive(x: np.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation with explicit Python loops.

    For every pixel ``(i, j)`` of the 2D array ``x`` and every offset
    ``(m, n)`` with ``-dy <= m <= dy`` and ``-dx <= n <= dx``, add the squared
    difference between the pixel and its neighbour ``x[i + m, j + n]``.
    Only neighbours that lie inside the array are counted (no wrap-around).
    The centre offset ``(0, 0)`` contributes zero and needs no special case.

    Parameters
    ----------
    x : np.ndarray
        2D array of shape ``(ny, nx)``.
    dy : int
        Neighbourhood half-width vertically (offsets ``-dy .. dy``).
    dx : int
        Neighbourhood half-width horizontally (offsets ``-dx .. dx``).

    Returns
    -------
    float
        Half the summed squared differences, so each unordered neighbour pair
        is counted once.
    """
    ny, nx = x.shape
    total_cost = 0.0
    for i in range(ny):
        for j in range(nx):
            for m in range(-dy, dy + 1):
                for n in range(-dx, dx + 1):
                    # Strict upper bounds: index ny / nx is out of range, and
                    # negative indices would silently wrap in numpy.
                    if 0 <= i + m < ny and 0 <= j + n < nx:
                        total_cost += (x[i, j] - x[i + m, j + n]) ** 2
    return 0.5 * total_cost
def spatial_regularisation_numpy(x: np.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation using numpy for vectorization.

    For each offset ``(m, n)`` with ``-dy <= m <= dy`` and ``-dx <= n <= dx``
    (excluding ``(0, 0)``), sum the squared difference between every pixel and
    its neighbour at that offset. Only neighbours inside the array are counted;
    pixels on opposite edges are never paired (no periodic wrap-around).

    Parameters
    ----------
    x : np.ndarray
        2D array of shape ``(ny, nx)``.
    dy : int
        Neighbourhood half-width vertically.
    dx : int
        Neighbourhood half-width horizontally.

    Returns
    -------
    float
        Half the summed squared differences, so each unordered neighbour pair
        is counted once.
    """
    ny, nx = x.shape
    total_cost = 0.0
    for m in range(-dy, dy + 1):
        for n in range(-dx, dx + 1):
            if m == 0 and n == 0:
                continue
            if abs(m) >= ny or abs(n) >= nx:
                # No pixel has an in-bounds neighbour at this offset (and the
                # slices below would otherwise get negative, wrapping bounds).
                continue
            # Pixels whose (m, n) neighbour is in bounds, and those neighbours.
            # Unlike np.roll, slicing never pairs pixels across the border.
            centre = x[max(0, -m): ny - max(0, m), max(0, -n): nx - max(0, n)]
            neighbour = x[max(0, m): ny - max(0, -m), max(0, n): nx - max(0, -n)]
            total_cost += np.sum((centre - neighbour) ** 2)
    return 0.5 * total_cost
@numba.jit(nopython=True, parallel=True, fastmath=True)
def spatial_regularisation_numba(x: np.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation with a Numba-compiled parallel loop.

    Visits every pixel and every in-bounds neighbour within ``dy`` rows and
    ``dx`` columns (the pixel itself excluded), accumulating squared
    differences.

    Parameters
    ----------
    x : np.ndarray
        2D array of shape ``(rows, cols)``.
    dy : int
        Neighbourhood half-width vertically.
    dx : int
        Neighbourhood half-width horizontally.

    Returns
    -------
    float
        Half the accumulated squared differences.
    """
    rows, cols = x.shape
    acc = 0.0
    # Numba documents that a prange nested in another prange runs serially;
    # only the outer loop is parallelised, with ``acc`` as the reduction.
    for row in numba.prange(rows):
        for col in numba.prange(cols):
            centre = x[row, col]
            for off_r in range(-dy, dy + 1):
                r = row + off_r
                for off_c in range(-dx, dx + 1):
                    c = col + off_c
                    inside = 0 <= r < rows and 0 <= c < cols
                    if inside and (off_r != 0 or off_c != 0):
                        acc += (centre - x[r, c]) ** 2
    return 0.5 * acc
@partial(jax.jit, static_argnums=(1, 2))
def spatial_regularisation_jax(x: jnp.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation using JAX for efficient computation.

    For each offset ``(m, n)`` with ``-dy <= m <= dy`` and ``-dx <= n <= dx``
    (excluding ``(0, 0)``), sum the squared difference between every pixel and
    its in-bounds neighbour at that offset. Out-of-bounds pairs are excluded by
    slicing rather than by multiplying with a mask. With a mask, ``inf * 0``
    turns excluded pairs into NaN, and each offset needs a rolled copy and a
    full-size mask.

    Parameters
    ----------
    x : jnp.ndarray
        2D array of shape ``(ny, nx)``.
    dy : int
        Neighbourhood half-width vertically (static under ``jit``).
    dx : int
        Neighbourhood half-width horizontally (static under ``jit``).

    Returns
    -------
    float
        Half the summed squared differences, as a JAX scalar array.
    """
    # Under jit the array shape is static, so these are Python ints and the
    # slice bounds below are fixed at trace time.
    ny, nx = x.shape
    total_cost = 0.0
    for m in range(-dy, dy + 1):
        for n in range(-dx, dx + 1):
            if m == 0 and n == 0:
                continue
            if abs(m) >= ny or abs(n) >= nx:
                # No in-bounds neighbour exists at this offset.
                continue
            # Pixels whose (m, n) neighbour is in bounds, and those neighbours.
            centre = x[max(0, -m): ny - max(0, m), max(0, -n): nx - max(0, n)]
            neighbour = x[max(0, m): ny - max(0, -m), max(0, n): nx - max(0, -n)]
            total_cost += jnp.sum((centre - neighbour) ** 2)
    return 0.5 * total_cost
@torch.jit.script
def spatial_regularisation_torch(
    x: torch.Tensor, dy: int = 1, dx: int = 1
) -> torch.Tensor:
    """Calculate spatial regularisation using PyTorch.

    For each offset ``(m, n)`` with ``-dy <= m <= dy`` and ``-dx <= n <= dx``
    (excluding ``(0, 0)``), sum the squared difference between every pixel and
    its in-bounds neighbour at that offset. Out-of-bounds pairs are excluded by
    slicing, so no rolled copies or masks are allocated.

    Parameters
    ----------
    x : torch.Tensor
        2D tensor of shape ``(ny, nx)``.
    dy : int
        Neighbourhood half-width vertically.
    dx : int
        Neighbourhood half-width horizontally.

    Returns
    -------
    torch.Tensor
        0-dim tensor with the same dtype and device as ``x`` holding half the
        summed squared differences.
    """
    ny = x.size(0)
    nx = x.size(1)
    # Accumulate in the input's dtype: a bare ``torch.tensor(0.0)`` is float32
    # and would silently truncate float64 partial sums.
    total_cost = torch.tensor(0.0, dtype=x.dtype, device=x.device)
    for m in range(-dy, dy + 1):
        for n in range(-dx, dx + 1):
            if m == 0 and n == 0:
                continue
            if m >= ny or -m >= ny or n >= nx or -n >= nx:
                # No in-bounds neighbour exists at this offset.
                continue
            # Pixels whose (m, n) neighbour is in bounds, and those neighbours.
            centre = x[max(0, -m): ny - max(0, m), max(0, -n): nx - max(0, n)]
            neighbour = x[max(0, m): ny - max(0, -m), max(0, n): nx - max(0, -n)]
            total_cost = total_cost + torch.sum((centre - neighbour) ** 2)
    return 0.5 * total_cost
def _block_until_ready(result):
    """Wait for asynchronously dispatched work to finish, then return ``result``.

    JAX arrays are returned before their computation completes, so timing a
    JAX call without synchronising measures only dispatch overhead. Objects
    without a ``block_until_ready`` method (numpy scalars, CPU torch tensors)
    are returned unchanged.
    """
    if hasattr(result, "block_until_ready"):
        result.block_until_ready()
    return result


def _function_name(func) -> str:
    """Return a printable name for ``func``.

    Plain functions expose ``__name__``; the TorchScript-compiled function is
    accessed through ``.name`` instead.
    """
    try:
        return func.__name__
    except AttributeError:
        return func.name


def _time_function(func, array_in, n_neighs: int, num_runs: int):
    """Time ``func(array_in, dx=n_neighs, dy=n_neighs)`` over ``num_runs`` calls.

    One untimed warm-up call runs first so that JIT compilation is not
    included in the measurements.

    Returns
    -------
    tuple of float
        Mean and standard deviation of the wall-clock times, in seconds.

    Raises
    ------
    ValueError
        If ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")
    # Warm-up (compilation) run, synchronised so it cannot overlap the timings.
    _block_until_ready(func(array_in, dx=n_neighs, dy=n_neighs))
    timings = []
    for _ in range(num_runs):
        start_time = time.perf_counter()
        _block_until_ready(func(array_in, dx=n_neighs, dy=n_neighs))
        timings.append(time.perf_counter() - start_time)
    return float(np.mean(timings)), float(np.std(timings))


if __name__ == "__main__":
    sizes = [128, 256, 1024, 2048]
    neighbourhood = [1, 3, 5, 7, 9]
    functions = [
        spatial_regularisation_numpy,
        spatial_regularisation_numba,
        spatial_regularisation_jax,
        spatial_regularisation_torch,
    ]
    results = []
    num_runs = 5
    for npix in sizes:
        for n_neighs in neighbourhood:
            x = np.random.rand(npix, npix)
            # The same random field, converted for each library, in the same
            # order as ``functions``.
            inputs = [x, x, jnp.array(x), torch.from_numpy(x)]
            for func, array_in in zip(functions, inputs):
                func_name = _function_name(func)
                print(npix, n_neighs, func_name)
                avg_time, std_time = _time_function(
                    func, array_in, n_neighs, num_runs
                )
                results.append([npix, n_neighs, func_name, avg_time, std_time])
    df = pd.DataFrame(
        results, columns=["n_size", "dx/dy", "function", "time_mean", "time_std"]
    )
    # The library is the last underscore-separated token of the function name.
    df["library"] = df["function"].str.rsplit("_", n=1).str[-1]
    print(df)
    # Series.plot ignores ``x=``, so pivot explicitly: one line per library,
    # mean time (over neighbourhood sizes) against image size.
    df.pivot_table(index="n_size", columns="library", values="time_mean").plot(
        logy=True
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment