Skip to content

Instantly share code, notes, and snippets.

Last active February 23, 2021 04:07
Show Gist options
  • Save cbur24/93bd87a72ef29e8e5250275bdaa1c117 to your computer and use it in GitHub Desktop.
Save cbur24/93bd87a72ef29e8e5250275bdaa1c117 to your computer and use it in GitHub Desktop.
Calculate geomedian and TMADs on xarray, numpy, or dask arrays
import xarray as xr
import numpy as np
import dask
import hdstats
from odc.algo import randomize, reshape_for_geomedian
from odc.algo._dask import reshape_yxbt
from datacube.utils.geometry import assign_crs
def xr_geomad(ds, axis='time', where=None, **kw):
    """
    Calculate the geomedian and TMADs (median absolute deviations from the
    geomedian) of an xarray, numpy, or dask array.

    :param ds: xr.Dataset | xr.DataArray | numpy array | dask array.
        A Dataset is reshaped with ``reshape_for_geomedian(ds, axis)`` so that
        each data variable becomes one band. Any other input must already be
        4-D with dimensions ordered ``y, x, band, time``; the last dimension
        is the one that gets reduced.
    :param axis: name of the dimension to reduce. For a DataArray it must be
        the last dimension; for a Dataset it is passed on to
        ``reshape_for_geomedian``; for bare arrays it is not used.
    :param where: optional boolean mask with shape ``(y, x)``. Output pixels
        where the mask is False are set to NaN. Not supported for dask inputs.

    Other parameters:
    **kw -- passed on to ``hdstats.nangeomedian_pcm``; ``num_threads`` is also
    forwarded to ``hdstats.emad_pcm`` / ``smad_pcm`` / ``bcmad_pcm``.
    Defaults applied here when not supplied:
       nocheck    : bool   False
       num_threads: int    1
       eps        : float  1e-6
    Other options (e.g. ``maxiters``) are not set here, so hdstats' own
    defaults apply.

    :returns: the same kind of container as ``ds`` (Dataset, DataArray, or
        numpy/dask array) with the reduced dimension removed and three extra
        bands appended after the input bands, holding the outputs of
        ``emad_pcm``, ``smad_pcm`` and ``bcmad_pcm``. In Dataset/DataArray
        outputs these bands are named ``edev``, ``sdev`` and ``bcdev``.
    :raises ValueError: if the input is not 4-D, if a DataArray's last
        dimension is not ``axis``, or if ``where`` does not match ``(y, x)``.
    :raises NotImplementedError: if ``where`` is given together with a dask
        input.
    """

    def gm_tmad(arr, **kw):
        """
        Reduce the last dimension of ``arr`` to a geomedian plus three MADs.

        arr: a 4-D numpy array ordered (y, x, band, time); the last dimension
             is reduced.
        returns: a numpy array with one less dimension than the input, whose
             band axis holds the geomedian bands followed by the emad, smad
             and bcmad layers (band count + 3).
        """
        gm = hdstats.nangeomedian_pcm(arr, **kw)
        nt = kw.get('num_threads')
        # Each MAD function yields one (y, x) layer; add a trailing band axis
        # so the layers can be stacked after the geomedian bands.
        emad = hdstats.emad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        smad = hdstats.smad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        bcmad = hdstats.bcmad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        return np.concatenate([gm, emad, smad, bcmad], axis=-1)

    def norm_input(ds, axis):
        """
        Normalise the supported inputs to a common form.

        Returns ``(dataset_or_None, dataarray_or_None, raw_array)``; the
        caller always unpacks all three values.
        """
        if isinstance(ds, xr.DataArray):
            xx = ds
            if len(xx.dims) != 4:
                raise ValueError("Expect 4 dimensions on input: y,x,band,time")
            if axis is not None and xx.dims[3] != axis:
                raise ValueError(
                    f"Can only reduce last dimension, expect: y,x,band,{axis}"
                )
            return None, xx, xx.data
        elif isinstance(ds, xr.Dataset):
            xx = reshape_for_geomedian(ds, axis)
            return ds, xx, xx.data
        else:  # assume numpy or similar
            xx_data = ds
            if xx_data.ndim != 4:
                raise ValueError("Expect 4 dimensions on input: y,x,band,time")
            return None, None, xx_data

    kw.setdefault('nocheck', False)
    kw.setdefault('num_threads', 1)
    kw.setdefault('eps', 1e-6)

    ds, xx, xx_data = norm_input(ds, axis)
    is_dask = dask.is_dask_collection(xx_data)

    if where is not None:
        if is_dask:
            raise NotImplementedError(
                "Dask version doesn't support output masking currently"
            )
        if where.shape != xx_data.shape[:2]:
            raise ValueError("Shape for `where` parameter doesn't match")
        set_nan = ~where
    else:
        set_nan = None

    if is_dask:
        import dask.array as da  # only needed on the dask path

        # gm_tmad reduces over time using every band of a pixel, so each
        # block must hold all bands and all time steps.
        if xx_data.shape[-2:] != xx_data.chunksize[-2:]:
            xx_data = xx_data.rechunk(xx_data.chunksize[:2] + (-1, -1))

        data = da.map_blocks(
            lambda x: gm_tmad(x, **kw),
            xx_data,
            name=randomize('geomad'),
            dtype=xx_data.dtype,
            # Output blocks: same (y, x) chunking; the time axis is dropped
            # and the single band chunk grows by the 3 MAD layers.
            chunks=xx_data.chunks[:-2] + (xx_data.chunks[-2][0] + 3,),
            drop_axis=3,
        )
    else:
        data = gm_tmad(xx_data, **kw)

    if set_nan is not None:
        data[set_nan, :] = np.nan

    if xx is None:
        return data

    # Rebuild coordinates without the reduced (last) dimension, extending the
    # band coordinate with names for the three MAD layers.
    dims = xx.dims[:-1]
    cc = {k: xx.coords[k] for k in dims}
    cc[dims[-1]] = np.hstack(
        [xx.coords[dims[-1]].values, ['edev', 'sdev', 'bcdev']]
    )
    xx_out = xr.DataArray(data, dims=dims, coords=cc)

    if ds is None:
        return xx_out

    ds_out = xx_out.to_dataset(dim='band')
    # Carry each input variable's attributes over to its output band; the MAD
    # bands have no input counterpart and are left without attributes.
    for b in ds.data_vars.keys():
        src, dst = ds[b], ds_out[b]
        dst.attrs.update(src.attrs)

    # Re-attach the input's CRS. NOTE(review): assumes `ds` carries a
    # datacube geobox (as produced by datacube loads) -- confirm.
    return assign_crs(ds_out, crs=ds.geobox.crs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment