Last active
February 23, 2021 04:07
-
-
Save cbur24/93bd87a72ef29e8e5250275bdaa1c117 to your computer and use it in GitHub Desktop.
Calculate geomedian and TMADs on xarray, numpy, or dask arrays
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import xarray as xr | |
import numpy as np | |
import dask | |
import hdstats | |
from odc.algo import randomize, reshape_for_geomedian | |
from odc.algo._dask import reshape_yxbt | |
from datacube.utils.geometry import assign_crs | |
def xr_geomad(ds, axis='time', where=None, **kw):
    """
    Calculate the geomedian and the three MAD statistics (TMADs) together.

    The last dimension of the input is reduced. The output keeps the first
    two dimensions and extends the third (band) dimension by three extra
    entries holding ``edev``, ``sdev`` and ``bcdev``, in that order.

    :param ds: xr.Dataset|xr.DataArray|numpy array (dask-backed data is
               also handled). Arrays must have 4 dimensions, which the
               error messages describe as y,x,band,time. A Dataset is first
               reshaped with ``reshape_for_geomedian(ds, axis)``.
    :param axis: name of the dimension to reduce. For a DataArray it must
                 be the name of the last dimension (or None to skip the
                 check). It is not used for plain array input.
    :param where: optional mask with the shape of the first two dimensions.
                  Output pixels where the mask is False are set to NaN.
                  Not supported when the data is a dask collection.
    :returns: same kind of container as the input:
              - plain array in  -> plain array out
              - DataArray in    -> DataArray out, input attrs copied
              - Dataset in      -> Dataset split along ``band``, attrs of
                                   each input variable copied, CRS taken
                                   from ``ds.geobox.crs``
    :raises ValueError: input is not 4-dimensional, the last DataArray
                        dimension is not ``axis``, or ``where`` has the
                        wrong shape.
    :raises NotImplementedError: ``where`` was given for dask-backed data.

    Other parameters:
    **kw -- passed on to hdstats.nangeomedian_pcm. Defaults applied here
            when the caller does not supply them:
        nocheck    : bool   False
        num_threads: int    1
        eps        : float  1e-6
      Any other keyword (e.g. ``maxiters``) is forwarded untouched, so the
      backend's own default applies. ``num_threads`` is additionally
      forwarded to the emad/smad/bcmad computations.
    """

    def gm_tmad(arr, **kw):
        """
        arr: a high dimensional numpy array where the last dimension will
             be reduced.
        returns: a numpy array with one less dimension than input, whose
                 last axis holds the geomedian followed by emad, smad and
                 bcmad.
        """
        gm = hdstats.nangeomedian_pcm(arr, **kw)
        # Only the thread count is shared with the MAD routines.
        nt = kw.get('num_threads', None)
        # Each MAD is indexed as 2-D; add a trailing axis so it can be
        # stacked next to the geomedian along the last axis.
        emad = hdstats.emad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        smad = hdstats.smad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        bcmad = hdstats.bcmad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        return np.concatenate([gm, emad, smad, bcmad], axis=-1)

    def norm_input(ds, axis):
        """
        Normalise the supported input kinds to a triple
        ``(dataset_or_None, dataarray_or_None, raw_array)``.
        """
        if isinstance(ds, xr.DataArray):
            xx = ds
            if len(xx.dims) != 4:
                raise ValueError("Expect 4 dimensions on input: y,x,band,time")
            if axis is not None and xx.dims[3] != axis:
                raise ValueError(f"Can only reduce last dimension, expect: y,x,band,{axis}")
            return None, xx, xx.data
        elif isinstance(ds, xr.Dataset):
            xx = reshape_for_geomedian(ds, axis)
            return ds, xx, xx.data
        else:  # assume numpy or similar
            xx_data = ds
            if xx_data.ndim != 4:
                raise ValueError("Expect 4 dimensions on input: y,x,band,time")
            return None, None, xx_data

    kw.setdefault('nocheck', False)
    kw.setdefault('num_threads', 1)
    kw.setdefault('eps', 1e-6)

    ds, xx, xx_data = norm_input(ds, axis)
    is_dask = dask.is_dask_collection(xx_data)

    if where is not None:
        if is_dask:
            raise NotImplementedError("Dask version doesn't support output masking currently")
        # Coerce to a boolean ndarray so that `~` is a logical negation
        # (on integer masks it would be a bitwise NOT and turn the mask
        # into fancy integer indices) and so list-like masks are accepted.
        where = np.asarray(where, dtype=bool)
        if where.shape != xx_data.shape[:2]:
            raise ValueError("Shape for `where` parameter doesn't match")
        set_nan = ~where
    else:
        set_nan = None

    if is_dask:
        # Imported here because the module only does `import dask`, which
        # does not bind the name `da` used below.
        import dask.array as da

        # Every block must span the whole of the last two axes: the last
        # one is reduced inside the block and the one before it grows by 3.
        if xx_data.shape[-2:] != xx_data.chunksize[-2:]:
            xx_data = xx_data.rechunk(xx_data.chunksize[:2] + (-1, -1))

        data = da.map_blocks(lambda x: gm_tmad(x, **kw),
                             xx_data,
                             name=randomize('geomedian'),
                             dtype=xx_data.dtype,
                             # +3 for the edev/sdev/bcdev slices appended
                             # to the single chunk of the band axis
                             chunks=xx_data.chunks[:-2] + (xx_data.chunks[-2][0] + 3,),
                             drop_axis=3)
    else:
        data = gm_tmad(xx_data, **kw)

    if set_nan is not None:
        data[set_nan, :] = np.nan

    if xx is None:
        # Plain array in, plain array out.
        return data

    dims = xx.dims[:-1]
    cc = {k: xx.coords[k] for k in dims}
    # Extend the last remaining coordinate with labels for the three
    # statistics appended by gm_tmad.
    cc[dims[-1]] = np.hstack([xx.coords[dims[-1]].values, ['edev', 'sdev', 'bcdev']])
    xx_out = xr.DataArray(data, dims=dims, coords=cc)

    if ds is None:
        # DataArray input: carry its attributes over and return.
        xx_out.attrs.update(xx.attrs)
        return xx_out

    # Dataset input: split back into one variable per band label and
    # restore the attributes of the variables that existed on the input.
    ds_out = xx_out.to_dataset(dim='band')
    for b in ds.data_vars:
        src, dst = ds[b], ds_out[b]
        dst.attrs.update(src.attrs)

    return assign_crs(ds_out, crs=ds.geobox.crs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment