Skip to content

Instantly share code, notes, and snippets.

@mathause
Last active February 10, 2021 15:14
Show Gist options
  • Save mathause/3c317d9991ce5b8cc1eb51ba9d9d7353 to your computer and use it in GitHub Desktop.
Save mathause/3c317d9991ce5b8cc1eb51ba9d9d7353 to your computer and use it in GitHub Desktop.
import numpy as np
import scipy as sp
import scipy.stats
import xarray as xr
# Open xarray's "air_temperature" tutorial dataset and load it eagerly.
ds = xr.tutorial.open_dataset('air_temperature').load()
# Keep only the `air` variable; from here on `ds` is that single variable.
ds = ds.air
# Bare expression: presumably shown as notebook-cell output to record the
# xarray version the timings below were taken with.
xr.__version__
def mannwhitney_(v1, v2):
    """Return the p-value of a Mann-Whitney U test between two samples.

    NaN entries are treated as missing values.

    Parameters
    ----------
    v1, v2 : array-like
        The two samples to compare; either may contain NaN.

    Returns
    -------
    float
        ``nan`` if either sample consists solely of NaN, otherwise the
        p-value reported by SciPy (the masked-array variant is used as
        soon as any NaN is present).
    """
    # Compute each NaN mask once; it is needed for both checks below.
    nan1 = np.isnan(v1)
    nan2 = np.isnan(v2)

    # return NaN for all-nan vectors
    if nan1.all() or nan2.all():
        # `np.NaN` was removed in NumPy 2.0; `np.nan` exists in every version.
        return np.nan

    # use masked-stats if any is nan
    if nan1.any() or nan2.any():
        v1 = np.ma.masked_invalid(v1)
        v2 = np.ma.masked_invalid(v2)
        _, p_val = sp.stats.mstats.mannwhitneyu(v1, v2)
    else:
        _, p_val = sp.stats.mannwhitneyu(v1, v2)

    return p_val
dim = ["time"]


def mw(d):
    """Run ``mannwhitney_`` along the ``dim`` core dimension(s) of ``d``.

    ``d`` is passed as both samples, so every vector along ``dim`` is
    tested against itself. The core dimension is consumed and one float
    remains per position of the other dimensions.

    Parameters
    ----------
    d : xarray object
        Input data; may be dask-backed (``dask="parallelized"``).

    Returns
    -------
    xarray object
        Result of ``xr.apply_ufunc`` without the ``dim`` dimension(s).
    """
    # Both inputs share the same core dimension(s).
    core_dims_in = [dim, dim]
    # The wrapped function returns a scalar, so the output has no core dims.
    core_dims_out = [[]]

    result = xr.apply_ufunc(
        mannwhitney_,
        d,
        d,
        input_core_dims=core_dims_in,
        output_core_dims=core_dims_out,
        exclude_dims=set(dim),
        # mannwhitney_ handles one pair of vectors at a time.
        vectorize=True,
        dask="parallelized",
        output_dtypes=[float],
    )
    return result
# Benchmark 1: chunk along "lat" only (chunk size 10), then time the compute.
dsc = ds.chunk({"lat": 10})
d = mw(dsc)
%timeit d.compute()
# 907 ms
# Benchmark 2: chunk along "lon" only (chunk size 10).
dsc = ds.chunk({"lon": 10})
d = mw(dsc)
%timeit d.compute()
# 891 ms
# Benchmark 3: chunk along both "lat" (12) and "lon" (25).
dsc = ds.chunk({"lat": 12, "lon": 25})
d = mw(dsc)
%timeit d.compute()
# 853 ms
# Baseline: the un-chunked, in-memory array (presumably evaluated eagerly,
# so the whole call is timed rather than a deferred .compute()).
%timeit mw(ds)
# 1.26 s
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment