Skip to content

Instantly share code, notes, and snippets.

@maawoo
Last active January 30, 2024 11:58
Show Gist options
  • Save maawoo/0b34d371c3cc1960a1589ccaded868c2 to your computer and use it in GitHub Desktop.
Save maawoo/0b34d371c3cc1960a1589ccaded868c2 to your computer and use it in GitHub Desktop.
xr_nanquantile - A workaround for https://github.com/pydata/xarray/issues/7377
from collections.abc import Sequence
from typing import Optional

import numpy as np
from numbagg import nanquantile
from xarray import DataArray
from xarray.core.computation import apply_ufunc
from xarray.core.utils import is_scalar
def da_nanquantiles(da: DataArray,
                    dim: Optional[str | Sequence[str]] = None,
                    quantiles: Optional[float | Sequence[float]] = None
                    ) -> DataArray:
    """
    Calculate quantiles along given dimension(s) of a DataArray, ignoring NaN values.

    If multiple quantiles are given, the returned DataArray has a new leading
    dimension 'quantile' whose coordinate holds the quantile values. If a single
    scalar quantile is given, the 'quantile' dimension is squeezed out and only
    remains as a scalar coordinate.

    Parameters
    ----------
    da : DataArray
        The DataArray to calculate quantiles for.
    dim : str or sequence of str, optional
        Dimension(s) to reduce. If None (default), all dimensions of `da` are
        reduced.
    quantiles : float or sequence of float, optional
        The quantiles to calculate, each in the range [0, 1]. If None (default),
        the quantiles (0.05, 0.95) are computed.

    Returns
    -------
    result : DataArray
        If `quantiles` is a scalar, the result has the dimensions of `da` that
        remain after the reduction, plus a scalar 'quantile' coordinate. If
        multiple quantiles are given, the first dimension of the result is
        'quantile', followed by the remaining (non-reduced) dimensions.
        The attributes of `da` are copied onto the result; values are float64.

    Raises
    ------
    ValueError
        If any requested quantile lies outside the range [0, 1].

    Notes
    -----
    This is a simple workaround for https://github.com/pydata/xarray/issues/7377.
    The structure of this function is based on:
    https://github.com/pydata/xarray/blob/6c5840e1198707cdcf7dc459f27ea9510eb76388/xarray/core/variable.py#L2128-L2271
    Instead of `numpy.nanquantile`, the `nanquantile` function of the numbagg
    package is used:
    https://github.com/numbagg/numbagg/blob/50357a2545c831a911c2fa92fc7de485a2f610aa/numbagg/funcs.py#L184
    """
    if dim is None:
        dim = da.dims
    if is_scalar(dim):
        dim = [dim]
    if quantiles is None:
        quantiles = (0.05, 0.95)

    # Remember whether a single quantile was requested so the 'quantile'
    # dimension can be squeezed out at the end.
    scalar = is_scalar(quantiles)
    quantiles = np.atleast_1d(np.asarray(quantiles, dtype=np.float64))
    if np.any((quantiles < 0) | (quantiles > 1)):
        raise ValueError("Quantiles must be in the range [0, 1]")

    def _wrapper(x, **kwargs):
        # The result of nanquantile has the quantile axis first; apply_ufunc
        # expects output core dims last, so move it to the end.
        # NOTE(review): the copy presumably guards against numbagg reordering
        # its input in place — confirm before removing.
        return np.moveaxis(nanquantile(x.copy(), **kwargs), source=0, destination=-1)

    # apply_ufunc moves the input core dims (`dim`) to the trailing axes, so
    # the reduction axes are the last len(dim) axes: (-1, -2, ..., -len(dim)).
    axis = tuple(range(-1, -len(dim) - 1, -1))

    result = apply_ufunc(
        _wrapper,
        da,
        input_core_dims=[dim],
        exclude_dims=set(dim),
        output_core_dims=[["quantile"]],
        output_dtypes=[np.float64],
        dask_gufunc_kwargs=dict(output_sizes={"quantile": len(quantiles)}),
        dask="parallelized",
        kwargs={'quantiles': quantiles, 'axis': axis},
    )
    result = result.assign_coords(quantile=DataArray(quantiles, dims=("quantile",)))
    result = result.transpose("quantile", ...)
    if scalar:
        result = result.squeeze("quantile")
    result.attrs = da.attrs.copy()
    return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment