Last active
January 30, 2024 11:58
-
-
Save maawoo/0b34d371c3cc1960a1589ccaded868c2 to your computer and use it in GitHub Desktop.
xr_nanquantile - A workaround for https://github.com/pydata/xarray/issues/7377
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import Sequence
from typing import Optional

import numpy as np
from numbagg import nanquantile
from xarray import DataArray
from xarray.core.computation import apply_ufunc
from xarray.core.utils import is_scalar
def da_nanquantiles(da: DataArray,
                    dim: Optional[str | Sequence[str]] = None,
                    quantiles: Optional[float | Sequence[float]] = None
                    ) -> DataArray:
    """
    Calculate quantiles along given dimension(s) of a DataArray, ignoring NaN values.

    If multiple quantiles are given, the returned DataArray has a new leading
    dimension 'quantile' with the quantile values as coordinates. If a single
    (scalar) quantile is given, the 'quantile' dimension is squeezed out and only
    remains as a scalar 'quantile' coordinate.

    Parameters
    ----------
    da : DataArray
        The DataArray to calculate quantiles for.
    dim : str or sequence of str, optional
        Dimension(s) to reduce. If None (default), all dimensions of `da` are
        reduced.
    quantiles : float or sequence of float, optional
        The quantiles to calculate, each within the closed interval [0, 1].
        If None (default), the quantiles (0.05, 0.95) are calculated.

    Returns
    -------
    result : DataArray
        The dimensions listed in `dim` are removed; the remaining dimensions of
        `da` are kept. If `quantiles` is a sequence, a leading 'quantile'
        dimension is added with the quantile values as coordinate. If
        `quantiles` is a scalar, there is no 'quantile' dimension, only a scalar
        'quantile' coordinate. The attributes of `da` are copied to the result.

    Raises
    ------
    ValueError
        If any requested quantile is outside [0, 1] or is NaN.

    Notes
    -----
    This is a simple workaround for https://github.com/pydata/xarray/issues/7377.
    The structure of this function is based on:
    https://github.com/pydata/xarray/blob/6c5840e1198707cdcf7dc459f27ea9510eb76388/xarray/core/variable.py#L2128-L2271
    Instead of `numpy.nanquantile`, the following `nanquantile` method of the numbagg
    package is used:
    https://github.com/numbagg/numbagg/blob/50357a2545c831a911c2fa92fc7de485a2f610aa/numbagg/funcs.py#L184
    """
    if dim is None:
        dim = da.dims
    if is_scalar(dim):
        dim = [dim]
    # Materialise once: the same dimension list feeds input_core_dims,
    # exclude_dims and the axis computation below.
    dim = list(dim)

    if quantiles is None:
        quantiles = (0.05, 0.95)
    # Remember whether a single quantile was requested before it is turned into
    # a 1-D array, so the 'quantile' dimension can be squeezed out at the end.
    scalar = is_scalar(quantiles)
    quantiles = np.atleast_1d(np.asarray(quantiles, dtype=np.float64))
    # Written as "not all inside" so NaN quantiles are rejected as well.
    if not np.all((quantiles >= 0) & (quantiles <= 1)):
        raise ValueError("Quantiles must be in the range [0, 1]")

    def _wrapper(x, **kwargs):
        # The quantile axis comes out first; apply_ufunc expects output core
        # dimensions at the end, so move it there. The input is copied so the
        # caller's (or dask's) buffer is never handed to numbagg directly.
        # NOTE(review): confirm whether numbagg actually mutates its input.
        return np.moveaxis(nanquantile(x.copy(), **kwargs), source=0, destination=-1)

    # apply_ufunc moves the core dimensions to the last axes of the input,
    # so the reduced dimensions occupy axes -1, -2, ..., -len(dim).
    axis = tuple(range(-1, -len(dim) - 1, -1))

    result = apply_ufunc(
        _wrapper,
        da,
        input_core_dims=[dim],
        exclude_dims=set(dim),
        output_core_dims=[["quantile"]],
        output_dtypes=[np.float64],
        dask_gufunc_kwargs={"output_sizes": {"quantile": len(quantiles)}},
        dask="parallelized",
        kwargs={"quantiles": quantiles, "axis": axis},
    )

    result = result.assign_coords(quantile=DataArray(quantiles, dims=("quantile",)))
    result = result.transpose("quantile", ...)
    if scalar:
        result = result.squeeze("quantile")
    result.attrs = da.attrs.copy()
    return result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment