

@ckesanapalli
Last active June 30, 2023 14:37
Interpolate NaN values in ND array.
import numpy as np
from scipy.interpolate import griddata


def interp_nan(data, **griddata_kwargs):
    """
    Interpolate NaN values in an N-D array.

    Parameters
    ----------
    data : np.ndarray
        Data with NaN values.
    **griddata_kwargs
        Keyword arguments passed on to scipy.interpolate.griddata
        (e.g. ``method``, ``fill_value``).

    Returns
    -------
    np.ndarray
        Array of the same shape as ``data`` with interpolated values.
        With the default ``method="linear"``, points outside the convex
        hull of the valid samples remain NaN.

    Examples
    --------
    >>> data = np.arange(1.0, 10.0)
    >>> data[1] = np.nan
    >>> data[2] = np.nan
    >>> data
    array([ 1., nan, nan,  4.,  5.,  6.,  7.,  8.,  9.])
    >>> interp_nan(data)
    array([1., 2., 3., 4., 5., 6., 7., 8., 9.])
    """
    # One integer coordinate array per axis, shape (data.ndim, *data.shape).
    # np.indices gives the same grid as np.mgrid over data.shape and, unlike
    # star-unpacking inside [], works on Python versions before 3.11.
    axes = np.indices(data.shape)
    # Valid samples are the finite entries (NaN and +/-inf are excluded).
    valid = np.isfinite(data)
    valid_values = data[valid]
    valid_axes = axes[:, valid]
    # Interpolate from the valid samples onto every grid point.
    return griddata(tuple(valid_axes), valid_values, tuple(axes), **griddata_kwargs)
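The docstring above only exercises the 1-D case. Below is a minimal sketch of the same `griddata` call on a 2-D array, using made-up data (z = x + y) with one interior hole and one corner hole. It also shows one way to handle a limitation. With `method="linear"`, cells outside the convex hull of the valid samples come back as NaN. Filling those leftovers with `method="nearest"` is an added choice for illustration, not part of the gist.

```python
import numpy as np
from scipy.interpolate import griddata

# Illustrative 2-D field z = x + y with two values knocked out.
yy, xx = np.indices((4, 5))
data = (xx + yy).astype(float)
data[1, 2] = np.nan  # interior hole: inside the convex hull of valid points
data[0, 0] = np.nan  # corner hole: outside the convex hull of valid points

valid = np.isfinite(data)
points = np.indices(data.shape)   # shape (2, 4, 5): one coordinate array per axis
known = tuple(points[:, valid])   # coordinates of the valid samples
query = tuple(points)             # coordinates of every cell

# Linear (Delaunay-based) interpolation; cells outside the hull stay NaN.
filled = griddata(known, data[valid], query, method="linear")

# Assumed fallback: replace any remaining NaN with the nearest valid sample.
nearest = griddata(known, data[valid], query, method="nearest")
filled = np.where(np.isnan(filled), nearest, filled)
```

Because z = x + y is linear, the interior hole is recovered exactly. The corner cell at (0, 0) has no enclosing triangle, so it takes the value of its nearest valid neighbour instead.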