Skip to content

Instantly share code, notes, and snippets.

Last active May 25, 2022 15:55
Show Gist options
  • Save larsoner/fbe32d57996848395854d5e59dff1e10 to your computer and use it in GitHub Desktop.
Save larsoner/fbe32d57996848395854d5e59dff1e10 to your computer and use it in GitHub Desktop.
Displacement field demo using matched points
from mne.transforms import _MatchedDisplacementFieldInterpolator
import numpy as np
import matplotlib.pyplot as plt
# Demo: warp one set of matched points onto another with a nonlinear
# displacement field. The matched points are 3D, but only their x/y
# coordinates are drawn.

# Matched landmarks: row i of ``fro`` corresponds to row i of ``to``.
to = np.array([[5, 4, 1], [6, 1, 0], [4, -1, 1], [3, 3, 0]], float)
fro = np.array([[0, 2, 2], [2, 2, 1], [2, 0, 2], [0, 0, 1]], float)
ndim = fro.shape[-1]
n_grid = 20

# Regular n_grid x n_grid lattice in the x/y plane; the remaining
# coordinate is held fixed at 1.
_ticks = np.linspace(-0.5, 2.5, n_grid)
_mesh = np.meshgrid(_ticks, _ticks, [1] * (ndim - 2))
grid = np.array(_mesh).T.reshape(-1, ndim)
assert grid.shape == (n_grid ** 2, ndim)

# Color each lattice point by its normalized x/y position
# (red <- x, green fixed at 0, blue <- y) so it can be tracked after warping.
_xy = grid[:, :2]
_xy_unit = _xy - _xy.min(axis=0)
_xy_unit = _xy_unit / _xy_unit.max(axis=0)
grid_c = np.column_stack(
    [_xy_unit[:, 0], np.zeros_like(_xy_unit[:, 0]), _xy_unit[:, 1]])

fig, axes = plt.subplots(2, figsize=(6, 6), sharex=True, sharey=True)
_cmap = plt.get_cmap('YlGn')
colors = _cmap(np.linspace(0.25, 1, to.shape[0]))


def _scatter_panel(ax, sources, lattice):
    """Draw targets (dots), source points (crosses) and the colored grid."""
    ax.scatter(to[:, 0], to[:, 1],
               c=colors, edgecolors='none', zorder=5, lw=0)
    ax.scatter(sources[:, 0], sources[:, 1],
               c=colors, marker='x', zorder=4, lw=2)
    ax.scatter(lattice[:, 0], lattice[:, 1],
               c=grid_c, marker='.', alpha=0.2, lw=2)


# Top panel: everything in its original position.
_scatter_panel(axes[0], fro, grid)

# Bottom panel: source points and lattice after applying the interpolator.
interp = _MatchedDisplacementFieldInterpolator(fro, to)
fro_t = interp(fro)
grid_t = interp(grid)
_scatter_panel(axes[1], fro_t, grid_t)

# Black diamonds mark the extra extrema points the interpolator added.
_corners = interp._extrema
axes[1].scatter(_corners[:, 0], _corners[:, 1], c='k', marker='d')
# Added: definition of the interpolator class that the demo above imports
# from mne.transforms
class _MatchedDisplacementFieldInterpolator:
    """Interpolate from matched points using a displacement field in ND.

    The source points are first pre-aligned to the target points (affine
    with uniform scaling, as fitted by ``_fit_matched_points``). The aligned
    sources, together with a set of extra anchor points that map to
    themselves, are then fed to
    :class:`scipy.interpolate.LinearNDInterpolator`, which interpolates the
    target positions linearly within a Delaunay triangulation.

    Parameters
    ----------
    fro : array-like, shape (n_points, 3)
        Source points.
    to : array-like, shape (n_points, 3)
        Target points; row ``i`` is matched to row ``i`` of ``fro``.

    Raises
    ------
    ValueError
        If ``fro`` and ``to`` differ in shape or are not ``(n_points, 3)``.

    Notes
    -----
    ``_fit_matched_points``, ``_quat_to_affine`` and ``apply_trans`` are
    expected to be available in the enclosing module.
    """

    def __init__(self, fro, to):
        from scipy.interpolate import LinearNDInterpolator

        fro = np.array(fro, float)
        to = np.array(to, float)
        # Explicit exceptions rather than ``assert`` so that validation
        # still happens when Python runs with ``-O``.
        if fro.shape != to.shape:
            raise ValueError(
                "fro and to must have the same shape, got "
                f"{fro.shape} and {to.shape}")
        # The 3-column restriction is only necessary because it's what
        # _fit_matched_points requires
        if fro.ndim != 2 or fro.shape[1] != 3:
            raise ValueError(
                "fro and to must have shape (n_points, 3), got "
                f"{fro.shape}")

        # Prealign using affine + uniform scaling
        trans, scale = _fit_matched_points(fro, to, scale=True)
        trans = _quat_to_affine(trans)
        trans[:3, :3] *= scale
        self._affine = trans
        fro = apply_trans(trans, fro)

        # Add points at extrema: pad the bounding box of the aligned source
        # points by half the per-axis extent of the targets, then take every
        # (min, max) combination across the axes, i.e. the box corners.
        delta = (to.max(axis=0) - to.min(axis=0)) / 2.
        extrema = np.array([fro.min(axis=0) - delta, fro.max(axis=0) + delta])
        self._extrema = np.array(
            np.meshgrid(*extrema.T)).T.reshape(-1, fro.shape[-1])
        # The corners appear in both the source and target sets, so they map
        # to themselves (zero displacement).
        fro_concat = np.concatenate((fro, self._extrema))
        to_concat = np.concatenate((to, self._extrema))

        # Compute the interpolator (which internally uses Delaunay)
        self._interp = LinearNDInterpolator(fro_concat, to_concat)

    def __call__(self, x):
        """Warp one point or an array of points.

        Parameters
        ----------
        x : array-like, shape (3,) or (n_points, 3)
            Point(s) to transform.

        Returns
        -------
        out : ndarray, shape (3,) or (n_points, 3)
            The interpolated position(s), matching the dimensionality of
            ``x``. With SciPy's default ``fill_value``, points outside the
            convex hull of the fitted points come back as NaN.

        Raises
        ------
        ValueError
            If ``x`` is not 1D or 2D with a last dimension of 3.
        """
        # Accept array-likes (e.g. lists), consistent with __init__
        x = np.asarray(x, float)
        if x.ndim not in (1, 2) or x.shape[-1] != 3:
            raise ValueError(
                f"x must have shape (3,) or (n_points, 3), got {x.shape}")
        singleton = x.ndim == 1
        # The interpolator always returns a 2D array; unwrap for 1D input
        out = self._interp(apply_trans(self._affine, x))
        return out[0] if singleton else out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment