Last active
February 9, 2024 01:32
-
-
Save cbur24/5f7ff53c3cd144094ee851299290a990 to your computer and use it in GitHub Desktop.
Same as "xr_geomad.py" but optimized for dask
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dask
import dask.array as da
import hdstats
import numpy as np
import xarray as xr
from odc.algo import randomize
from odc.algo._dask import reshape_yxbt
from odc.geo.xr import assign_crs
def xr_geomad_dask(ds, **kw):
    """
    Lazily compute the geomedian plus three deviation bands for a Dataset.

    Same as the other ``xr_geomad`` but uses ``reshape_yxbt`` instead of
    ``reshape_for_geomedian`` for better dask performance.
    Will not work on numpy arrays.

    :param ds: ``xarray.Dataset`` whose data variables are dask-backed.
    :param kw: keyword options forwarded to ``hdstats.nangeomedian_pcm``.
        ``nocheck=False``, ``num_threads=1`` and ``eps=1e-6`` are filled in
        when the caller does not supply them; ``num_threads`` is also passed
        to the ``emad``/``smad``/``bcmad`` calls.
    :returns: ``xarray.Dataset`` split along the ``band`` dimension: the
        input bands followed by ``edev``, ``sdev`` and ``bcdev``. Attributes
        of each input variable are copied onto the matching output variable
        and the CRS is taken from ``ds.geobox.crs``.
    :raises TypeError: if ``ds`` is not an ``xarray.Dataset``.
    :raises ValueError: if the reshaped data is not a dask collection.
    """
    def gm_tmad(arr, **opts):
        """
        arr: a high dimensional numpy array where the last dimension will be reduced.
        returns: a numpy array with one less dimension than input; the three
                 deviation statistics are appended along the new last axis.
        """
        gm = hdstats.nangeomedian_pcm(arr, **opts)
        nt = opts.get('num_threads')
        # Each deviation statistic comes back without a band axis, so add one
        # before stacking it after the geomedian bands.
        deviations = [
            stat(arr, gm, num_threads=nt)[:, :, np.newaxis]
            for stat in (hdstats.emad_pcm, hdstats.smad_pcm, hdstats.bcmad_pcm)
        ]
        return np.concatenate([gm, *deviations], axis=-1)

    if not isinstance(ds, xr.Dataset):
        raise TypeError(
            f"xr_geomad_dask expects an xarray.Dataset, got {type(ds).__name__}"
        )

    kw.setdefault('nocheck', False)
    kw.setdefault('num_threads', 1)
    kw.setdefault('eps', 1e-6)

    xx = reshape_yxbt(ds, yx_chunks=500)
    xx_data = xx.data

    if not dask.is_dask_collection(xx_data):
        # There is no eager code path; fail clearly instead of hitting an
        # unbound variable further down.
        raise ValueError("xr_geomad_dask only supports dask-backed input")

    # NOTE(review): the output band chunk is derived from the first band
    # chunk only, i.e. this assumes the band axis is a single chunk — confirm
    # against reshape_yxbt.
    data = da.map_blocks(lambda x: gm_tmad(x, **kw),
                         xx_data,
                         name=randomize('geomedian'),
                         dtype=xx_data.dtype,
                         chunks=xx_data.chunks[:-2] + (xx_data.chunks[-2][0] + 3,),
                         drop_axis=3)

    # The last input dimension is reduced away; the (new) last dimension
    # gains three extra labels for the deviation bands.
    dims = xx.dims[:-1]
    cc = {k: xx.coords[k] for k in dims}
    cc[dims[-1]] = np.hstack([xx.coords[dims[-1]].values,
                              ['edev', 'sdev', 'bcdev']])
    xx_out = xr.DataArray(data, dims=dims, coords=cc)

    ds_out = xx_out.to_dataset(dim='band')
    for b in ds.data_vars:
        ds_out[b].attrs.update(ds[b].attrs)

    return assign_crs(ds_out, crs=ds.geobox.crs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment