@kim366
Last active June 15, 2023 02:10
import dask
import numpy as np
import xarray


def apply_chunked(data, fn, dim):
    """
    Some xarray operations, such as resample, disregard the fact that the loaded dataset is
    distributed via dask. This function applies any such operation to each chunk individually.

    Arguments:
        data: the xarray.Dataset (chunked with dask along dim)
        fn: any unary function acting on an xarray.Dataset
        dim: the name of the dimension to split chunks across, as a str
    """
    # split dataset into chunks: cumulative chunk sizes give the index boundaries along dim
    index_bounds = np.concatenate([[0], np.cumsum(data.chunks[dim])])
    # each pair of neighbouring boundaries becomes one slice
    slices = [slice(*bounds) for bounds in np.lib.stride_tricks.sliding_window_view(index_bounds, 2)]
    chunks = [data.isel({dim: chunk_slice}) for chunk_slice in slices]

    # compute on chunks: wrap each call in dask.delayed and evaluate them in one dask.compute call
    tasks = [dask.delayed(fn)(chunk) for chunk in chunks]
    processed_chunks, = dask.compute(tasks)

    # reconstruct by concatenating the processed chunks along dim
    return xarray.concat(processed_chunks, dim=dim, data_vars='minimal', coords='minimal', compat='override')
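
The index arithmetic in the splitting step can be looked at in isolation. The following is a minimal sketch using only the standard library; `chunk_slices` is a hypothetical helper name that is not part of the gist, xarray, or dask, and the chunk sizes `(3, 3, 2)` are made up to mimic the shape of `data.chunks[dim]`.

```python
from itertools import accumulate


def chunk_slices(chunk_sizes):
    """Turn per-chunk lengths into consecutive slices along one dimension."""
    # Running totals give the boundaries between chunks, starting at 0.
    bounds = [0, *accumulate(chunk_sizes)]
    # Pair each boundary with the next one to get one slice per chunk.
    return [slice(start, stop) for start, stop in zip(bounds[:-1], bounds[1:])]


# Hypothetical chunk sizes, in the shape data.chunks[dim] would have.
slices = chunk_slices((3, 3, 2))

# Apply the slices to a plain list to see which positions land in which chunk.
values = list(range(8))
parts = [values[s] for s in slices]
```

Each slice could be passed to `data.isel({dim: ...})` in the same way the function does with its NumPy-based boundaries.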