Skip to content

Instantly share code, notes, and snippets.

Last active December 29, 2023 21:03
Show Gist options
  • Save jgomezdans/6560e2971904794d91298f741acbbfc7 to your computer and use it in GitHub Desktop.
Save jgomezdans/6560e2971904794d91298f741acbbfc7 to your computer and use it in GitHub Desktop.
import numpy as np
import numba
import jax.numpy as jnp
import jax
import time
import torch
import pandas as pd
from functools import partial
def spatial_regularisation_naive(x: np.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation with explicit Python loops.

    For every pixel of the 2D array `x` (shape `ny` x `nx`), sum the squared
    difference between that pixel and every in-bounds neighbour at row offset
    -dy..dy and column offset -dx..dx (the pixel itself is excluded). Each
    unordered pair of pixels is visited twice, hence the final factor 0.5.

    Parameters
    ----------
    x : np.ndarray
        2D array of shape (ny, nx).
    dy : int
        Vertical neighbourhood half-width (offsets -dy..dy).
    dx : int
        Horizontal neighbourhood half-width (offsets -dx..dx).

    Returns
    -------
    float
        Associated cost.
    """
    ny, nx = x.shape
    total_cost = 0.0
    for i in range(ny):
        for j in range(nx):
            for m in range(-dy, dy + 1):
                for n in range(-dx, dx + 1):
                    if m == 0 and n == 0:
                        continue
                    # Strict upper bounds: index ny / nx is out of range, and
                    # negative indices would silently wrap in NumPy.
                    if 0 <= i + m < ny and 0 <= j + n < nx:
                        total_cost += (x[i, j] - x[i + m, j + n]) ** 2
    return float(0.5 * total_cost)
def spatial_regularisation_numpy(x: np.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation using NumPy vectorisation.

    Equivalent to the explicit-loop definition: for every non-zero offset
    (m, n) with |m| <= dy and |n| <= dx, the squared differences between each
    pixel and its in-bounds neighbour at that offset are summed. Edges do not
    wrap around. The 0.5 factor compensates for each pair being counted from
    both ends.

    Parameters
    ----------
    x : np.ndarray
        2D array of shape (ny, nx).
    dy : int
        Vertical neighbourhood half-width (offsets -dy..dy).
    dx : int
        Horizontal neighbourhood half-width (offsets -dx..dx).

    Returns
    -------
    float
        Associated cost.
    """
    ny, nx = x.shape
    total_cost = 0.0
    for m in range(-dy, dy + 1):
        # Offsets at least as large as the array have no valid pairs; skipping
        # them also keeps the slice stops below non-negative.
        if abs(m) >= ny:
            continue
        rows_centre = slice(max(0, -m), ny - max(0, m))
        rows_neighbour = slice(max(0, m), ny - max(0, -m))
        for n in range(-dx, dx + 1):
            if (m == 0 and n == 0) or abs(n) >= nx:
                continue
            cols_centre = slice(max(0, -n), nx - max(0, n))
            cols_neighbour = slice(max(0, n), nx - max(0, -n))
            # Pixel (i, j) in the first view is paired with (i + m, j + n)
            # in the second; both views cover only the overlapping region.
            diff = x[rows_centre, cols_centre] - x[rows_neighbour, cols_neighbour]
            total_cost += np.sum(diff**2)
    return float(0.5 * total_cost)
@numba.jit(nopython=True, parallel=True, fastmath=True)
def spatial_regularisation_numba(x: np.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation with a Numba-parallelised loop nest.

    Sums, over every pixel of the 2D array `x`, the squared difference to each
    in-bounds neighbour at row offsets -dy..dy and column offsets -dx..dx,
    excluding the pixel itself, and halves the total.

    Parameters
    ----------
    x : np.ndarray
        2D array of shape (ny, nx).
    dy : int
        Vertical neighbourhood half-width.
    dx : int
        Horizontal neighbourhood half-width.

    Returns
    -------
    float
        Associated cost.
    """
    n_rows, n_cols = x.shape
    acc = 0.0
    for row in numba.prange(n_rows):
        for col in numba.prange(n_cols):
            centre = x[row, col]
            for off_r in range(-dy, dy + 1):
                r = row + off_r
                # Skip neighbour rows that fall outside the array.
                if r < 0 or r >= n_rows:
                    continue
                for off_c in range(-dx, dx + 1):
                    c = col + off_c
                    if c < 0 or c >= n_cols:
                        continue
                    # The pixel is not its own neighbour.
                    if off_r == 0 and off_c == 0:
                        continue
                    acc += (centre - x[r, c]) ** 2
    return 0.5 * acc
@partial(jax.jit, static_argnums=(1, 2))
def spatial_regularisation_jax(x: jnp.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation using JAX.

    For every non-zero offset (m, n) with |m| <= dy and |n| <= dx, the squared
    differences between each pixel of the 2D array `x` and its in-bounds
    neighbour at that offset are summed (no wrap-around at the edges). The
    result is halved because each pair is counted from both ends.

    `dy` and `dx` are static under `jax.jit`, so the Python loops and slice
    bounds below are resolved at trace time.

    Parameters
    ----------
    x : jnp.ndarray
        2D array of shape (ny, nx).
    dy : int
        Vertical neighbourhood half-width (offsets -dy..dy).
    dx : int
        Horizontal neighbourhood half-width (offsets -dx..dx).

    Returns
    -------
    float
        Associated cost (as a JAX scalar).
    """
    ny, nx = x.shape
    total_cost = 0.0
    for m in range(-dy, dy + 1):
        # Offsets at least as large as the array have no valid pairs.
        if abs(m) >= ny:
            continue
        for n in range(-dx, dx + 1):
            if (m == 0 and n == 0) or abs(n) >= nx:
                continue
            # Overlapping views: element (i, j) of `centre` pairs with
            # (i + m, j + n) of the original array.
            centre = x[max(0, -m): ny - max(0, m), max(0, -n): nx - max(0, n)]
            neighbour = x[max(0, m): ny - max(0, -m), max(0, n): nx - max(0, -n)]
            total_cost += jnp.sum((centre - neighbour) ** 2)
    return 0.5 * total_cost
def spatial_regularisation_torch(
    x: torch.Tensor, dy: int = 1, dx: int = 1
) -> torch.Tensor:
    """Calculate spatial regularisation using PyTorch.

    For every non-zero offset (m, n) with |m| <= dy and |n| <= dx, the squared
    differences between each element of the 2D tensor `x` and its in-bounds
    neighbour at that offset are summed (no wrap-around at the edges). The
    result is halved because each pair is counted from both ends.

    Parameters
    ----------
    x : torch.Tensor
        2D tensor of shape (ny, nx).
    dy : int
        Vertical neighbourhood half-width (offsets -dy..dy).
    dx : int
        Horizontal neighbourhood half-width (offsets -dx..dx).

    Returns
    -------
    torch.Tensor
        Associated cost as a 0-dim tensor on `x`'s device.
    """
    ny, nx = x.shape
    # Accumulate in x's dtype and device: a default float32 accumulator would
    # silently downcast the sums of float64 inputs.
    total_cost = x.new_zeros(())
    for m in range(-dy, dy + 1):
        # Offsets at least as large as the tensor have no valid pairs.
        if abs(m) >= ny:
            continue
        for n in range(-dx, dx + 1):
            if (m == 0 and n == 0) or abs(n) >= nx:
                continue
            # Overlapping views: element (i, j) of `centre` pairs with
            # (i + m, j + n) of the original tensor.
            centre = x[max(0, -m): ny - max(0, m), max(0, -n): nx - max(0, n)]
            neighbour = x[max(0, m): ny - max(0, -m), max(0, n): nx - max(0, -n)]
            # Out-of-place add keeps the result autograd-friendly.
            total_cost = total_cost + torch.sum((centre - neighbour) ** 2)
    return 0.5 * total_cost
if __name__ == "__main__":
    sizes = [128, 256, 1024, 2048]
    neighbourhood = [1, 3, 5, 7, 9]
    # Implementations under test, paired index-for-index with the input list
    # passed to zip() below: NumPy and Numba take NumPy arrays, JAX takes a
    # jax array, PyTorch takes a tensor. The pure-Python naive version is
    # left out as it is far too slow at these sizes.
    functions = [
        spatial_regularisation_numpy,
        spatial_regularisation_numba,
        spatial_regularisation_jax,
        spatial_regularisation_torch,
    ]
    results = []
    num_runs = 5

    def _wait(value):
        """Block until an asynchronously dispatched (JAX) result is ready."""
        if hasattr(value, "block_until_ready"):
            value.block_until_ready()
        return value

    for npix in sizes:
        for n_neighs in neighbourhood:
            x = np.random.rand(npix, npix)
            xx = jnp.array(x)
            xxx = torch.from_numpy(x)
            for func, array_in in zip(functions, [x, x, xx, xxx]):
                # Compiled/jitted wrappers may not expose __name__ directly.
                func_name = getattr(func, "__name__", None) or getattr(
                    getattr(func, "__wrapped__", None), "__name__", repr(func)
                )
                print(npix, n_neighs, func_name)
                # Dry run: triggers Numba/JAX compilation outside the timed loop.
                _wait(func(array_in, dx=n_neighs, dy=n_neighs))
                total_time = []
                for _ in range(num_runs):
                    start_time = time.perf_counter()
                    # Waiting on the result makes the JAX timing include the
                    # computation, not just the asynchronous dispatch.
                    _wait(func(array_in, dx=n_neighs, dy=n_neighs))
                    end_time = time.perf_counter()
                    total_time.append(end_time - start_time)
                avg_time = np.mean(total_time)
                std_time = np.std(total_time)
                results.append([npix, n_neighs, func_name, avg_time, std_time])
    df = pd.DataFrame(
        results, columns=["n_size", "dx/dy", "function", "time_mean", "time_std"]
    )
    # The library is the last underscore-separated part of the function name.
    df["library"] = df["function"].str.rsplit("_", n=1).str[-1]
    # One time-vs-size curve per library, averaged over neighbourhood sizes.
    df.pivot_table(index="n_size", columns="library", values="time_mean").plot(
        legend=True
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment