Skip to content

Instantly share code, notes, and snippets.

@cbur24
Last active February 9, 2024 01:32
Show Gist options
  • Save cbur24/5f7ff53c3cd144094ee851299290a990 to your computer and use it in GitHub Desktop.
Save cbur24/5f7ff53c3cd144094ee851299290a990 to your computer and use it in GitHub Desktop.
Same as "xr_geomad.py", but optimized for Dask.
import dask
import dask.array as da
import hdstats
import numpy as np
import xarray as xr
from odc.algo import randomize
from odc.algo._dask import reshape_yxbt
from odc.geo.xr import assign_crs
def xr_geomad_dask(ds, **kw):
    """
    Compute the geometric median and three MAD layers of an ``xr.Dataset``.

    Same as the other ``xr_geomad`` but uses ``reshape_yxbt`` instead of
    ``reshape_for_geomedian`` for better dask performance.

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset. Each data variable is treated as one band and the
        dataset is reshaped by ``reshape_yxbt`` (with ``yx_chunks=500``)
        before reduction. Only ``xr.Dataset`` input is supported; numpy
        arrays and ``xr.DataArray`` are rejected.
    **kw
        Keyword arguments forwarded to ``hdstats.nangeomedian_pcm``.
        Defaults applied here: ``nocheck=False``, ``num_threads=1``,
        ``eps=1e-6``. Only ``num_threads`` is also forwarded to the MAD
        functions (``emad_pcm``, ``smad_pcm``, ``bcmad_pcm``).

    Returns
    -------
    xr.Dataset
        One variable per input band holding the geometric median, plus
        ``edev``, ``sdev`` and ``bcdev`` holding the outputs of
        ``hdstats.emad_pcm``, ``smad_pcm`` and ``bcmad_pcm`` respectively.
        Attributes of the input variables are copied onto the matching
        output variables, and the input CRS is assigned to the result.

    Raises
    ------
    TypeError
        If ``ds`` is not an ``xr.Dataset``.
    """

    def gm_tmad(arr, **gm_kw):
        """
        Reduce the last axis of ``arr`` to a geomedian plus three MAD layers.

        ``arr`` is a block of the ``reshape_yxbt`` output (presumably laid
        out as y, x, band, time). Returns an array with the last axis removed
        and three extra entries appended along the band axis:
        [geomedian bands..., emad, smad, bcmad].
        """
        gm = hdstats.nangeomedian_pcm(arr, **gm_kw)
        # Only the thread count is passed on to the MAD functions; the other
        # geomedian options (eps, nocheck) are not forwarded to them.
        nt = gm_kw.pop('num_threads', None)
        # Each MAD is 2-D; add a trailing axis so it can be stacked onto
        # the band axis of the geomedian.
        emad = hdstats.emad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        smad = hdstats.smad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        bcmad = hdstats.bcmad_pcm(arr, gm, num_threads=nt)[:, :, np.newaxis]
        return np.concatenate([gm, emad, smad, bcmad], axis=-1)

    def norm_input(ds):
        """Validate ``ds`` and return (ds, reshaped DataArray, its data)."""
        if not isinstance(ds, xr.Dataset):
            # Previously this fell through and returned None, which failed
            # later with an opaque tuple-unpacking error.
            raise TypeError(
                "xr_geomad_dask only supports xr.Dataset input, "
                f"got {type(ds).__name__}"
            )
        xx = reshape_yxbt(ds, yx_chunks=500)
        return ds, xx, xx.data

    kw.setdefault('nocheck', False)
    kw.setdefault('num_threads', 1)
    kw.setdefault('eps', 1e-6)

    ds, xx, xx_data = norm_input(ds)

    if dask.is_dask_collection(xx_data):
        data = da.map_blocks(
            lambda x: gm_tmad(x, **kw),
            xx_data,
            name=randomize('geomedian'),
            dtype=xx_data.dtype,
            # Keep the y/x chunking, grow the band axis by 3 for the MAD
            # layers and drop the reduced last axis. Uses only the first
            # band chunk size, so this assumes the band axis is one chunk.
            chunks=xx_data.chunks[:-2] + (xx_data.chunks[-2][0] + 3,),
            drop_axis=3,
        )
    else:
        # In-memory input: previously `data` was left undefined here.
        data = gm_tmad(xx_data, **kw)

    # Output dims are the input dims minus the reduced last axis; the band
    # coordinate is extended with labels for the three MAD layers.
    dims = xx.dims[:-1]
    band_dim = dims[-1]
    coords = {k: xx.coords[k] for k in dims}
    coords[band_dim] = np.hstack(
        [xx.coords[band_dim].values, ['edev', 'sdev', 'bcdev']]
    )
    xx_out = xr.DataArray(data, dims=dims, coords=coords)

    ds_out = xx_out.to_dataset(dim=band_dim)
    for name in ds.data_vars:
        ds_out[name].attrs.update(ds[name].attrs)

    # Use the `.odc` accessor registered by the imported `odc.geo.xr`
    # rather than `.geobox`, which needs datacube's xarray extension.
    return assign_crs(ds_out, crs=ds.odc.crs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment