Debugging performance bottleneck in xarray futures to dask-backed array
"import functools\n",
"import dask.array\n",
"import numpy as np\n",
"import xarray as xr\n",
"from dask import distributed as dd\n",
"def dataarrays_from_delayed(futures, client=None):\n",
" \"\"\"\n",
" Returns a list of xarray dataarrays from a list of futures of dataarrays\n",
" Parameters\n",
" ----------\n",
" futures : list\n",
" list of :py:class:`dask.distributed.Future` objects holding\n",
" :py:class:`xarray.DataArray` objects.\n",
" client : object, optional\n",
" :py:class:`dask.distributed.Client` to use in gathering\n",
" metadata on futures. If not provided, client is inferred\n",
" from context.\n",
" Returns\n",
" -------\n",
" arrays : list\n",
" list of :py:class:`xarray.DataArray` objects with\n",
" :py:class:`dask.array.Array` backends.\n",
" Examples\n",
" --------\n",
" Given a mapped xarray DataArray, pull the metadata into memory while\n",
" leaving the data on the workers:\n",
" .. code-block:: python\n",
" >>> import numpy as np\n",
" >>> def build_arr(multiplier):\n",
" ... return multiplier * xr.DataArray(\n",
" ... np.arange(2), dims=['x'], coords=[['a', 'b']])\n",
" ...\n",
" >>> client = dd.Client()\n",
" >>> fut = client.map(build_arr, range(3))\n",
" >>> arrs = dataarrays_from_delayed(fut)\n",
" >>> arrs[-1] # doctest: +ELLIPSIS\n",
" <xarray.DataArray ...(x: 2)>\n",
" dask.array<...shape=(2,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" This list of arrays can now be manipulated using normal xarray tools:\n",
" .. code-block:: python\n",
" >>> xr.concat(arrs, dim='simulation') # doctest: +ELLIPSIS\n",
" <xarray.DataArray ...(simulation: 3, x: 2)>\n",
" dask.array<...shape=(3, 2), dtype=int64, chunksize=(1, 2), chunktype=numpy.ndarray>\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" Dimensions without coordinates: simulation\n",
" >>> client.close()\n",
" \"\"\"\n",
" if client is None:\n",
" client = dd.get_client()\n",
" delayed_arrays = client.map(lambda x: x.data, futures)\n",
" dask_array_metadata = client.gather(\n",
" client.map(lambda x: (x.data.shape, x.data.dtype), futures)\n",
" )\n",
" dask_arrays = [\n",
" dask.array.from_delayed(delayed_arrays[i], *dask_array_metadata[i])\n",
" for i in range(len(futures))\n",
" ]\n",
" # using dict(x.coords) b/c gathering coords can blow up memory for some reason\n",
" array_metadata = client.gather(\n",
" client.map(\n",
" lambda x: {\n",
" \"dims\": x.dims,\n",
" \"coords\": dict(x.coords),\n",
" \"attrs\": x.attrs,\n",
" \"name\": x.name,\n",
" },\n",
" futures,\n",
" )\n",
" )\n",
" data_arrays = [\n",
" xr.DataArray(dask_arrays[i], **array_metadata[i]) for i in range(len(futures))\n",
" ]\n",
" return data_arrays\n",
"def dataarray_from_delayed(futures, dim=None, client=None):\n",
" \"\"\"\n",
" Returns a DataArray from a list of futures\n",
" Parameters\n",
" ----------\n",
" futures : list\n",
" list of :py:class:`dask.distributed.Future` objects holding\n",
" :py:class:`xarray.DataArray` objects.\n",
" dim : str, optional\n",
" dimension along which to concat :py:class:`xarray.DataArray`.\n",
" Inferred by default.\n",
" client : object, optional\n",
" :py:class:`dask.distributed.Client` to use in gathering\n",
" metadata on futures. If not provided, client is inferred\n",
" from context.\n",
" Returns\n",
" -------\n",
" array : object\n",
" :py:class:`xarray.DataArray` concatenated along ``dim`` with\n",
" a :py:class:`dask.array.Array` backend.\n",
" Examples\n",
" --------\n",
" Given a mapped xarray DataArray, pull the metadata into memory while\n",
" leaving the data on the workers:\n",
" .. code-block:: python\n",
" >>> import numpy as np, pandas as pd\n",
" >>> def build_arr(multiplier):\n",
" ... return multiplier * xr.DataArray(\n",
" ... np.arange(2), dims=['x'], coords=[['a', 'b']])\n",
" ...\n",
" >>> client = dd.Client()\n",
" >>> fut = client.map(build_arr, range(3))\n",
" >>> da = dataarray_from_delayed(\n",
" ... fut,\n",
" ... dim=pd.Index(range(3), name='simulation'))\n",
" ...\n",
" >>> da # doctest: +ELLIPSIS\n",
" <xarray.DataArray ...(simulation: 3, x: 2)>\n",
" dask.array<...shape=(3, 2), dtype=int64, chunksize=(1, 2), chunktype=numpy.ndarray>\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" * simulation (simulation) int64 0 1 2\n",
" >>> client.close()\n",
" \"\"\"\n",
" data_arrays = dataarrays_from_delayed(futures, client=client)\n",
" da = xr.concat(data_arrays, dim=dim)\n",
" return da\n",
"def datasets_from_delayed(futures, client=None):\n",
" \"\"\"\n",
" Returns a list of xarray datasets from a list of futures of datasets\n",
" Parameters\n",
" ----------\n",
" futures : list\n",
" list of :py:class:`dask.distributed.Future` objects holding\n",
" :py:class:`xarray.Dataset` objects.\n",
" client : object, optional\n",
" :py:class:`dask.distributed.Client` to use in gathering\n",
" metadata on futures. If not provided, client is inferred\n",
" from context.\n",
" Returns\n",
" -------\n",
" datasets : list\n",
" list of :py:class:`xarray.Dataset` objects with\n",
" :py:class:`dask.array.Array` backends for each variable.\n",
" Examples\n",
" --------\n",
" Given a mapped :py:class:`xarray.Dataset`, pull the metadata into memory\n",
" while leaving the data on the workers:\n",
" .. code-block:: python\n",
" >>> import numpy as np\n",
" >>> def build_ds(multiplier):\n",
" ... return multiplier * xr.Dataset({\n",
" ... 'var1': xr.DataArray(\n",
" ... np.arange(2), dims=['x'], coords=[['a', 'b']])})\n",
" ...\n",
" >>> client = dd.Client()\n",
" >>> fut = client.map(build_ds, range(3))\n",
" >>> arrs = datasets_from_delayed(fut)\n",
" >>> arrs[-1] # doctest: +ELLIPSIS\n",
" <xarray.Dataset>\n",
" Dimensions: (x: 2)\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" Data variables:\n",
" var1 (x) int64 dask.array<chunksize=(2,), meta=np.ndarray>\n",
" This list of arrays can now be manipulated using normal xarray tools:\n",
" .. code-block:: python\n",
" >>> xr.concat(arrs, dim='y') # doctest: +ELLIPSIS\n",
" <xarray.Dataset>\n",
" Dimensions: (x: 2, y: 3)\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" Dimensions without coordinates: y\n",
" Data variables:\n",
" var1 (y, x) int64 dask.array<chunksize=(1, 2), meta=np.ndarray>\n",
" >>> client.close()\n",
" \"\"\"\n",
" if client is None:\n",
" client = dd.get_client()\n",
" data_var_keys = client.gather(\n",
" client.map(lambda x: list(x.data_vars.keys()), futures)\n",
" )\n",
" delayed_arrays = [\n",
" {k: (client.submit(lambda x: x[k].data, futures[i])) for k in data_var_keys[i]}\n",
" for i in range(len(futures))\n",
" ]\n",
" dask_array_metadata = [\n",
" {\n",
" k: (\n",
" client.submit(\n",
" lambda x: (x[k].data.shape, x[k].data.dtype), futures[i]\n",
" ).result()\n",
" )\n",
" for k in data_var_keys[i]\n",
" }\n",
" for i in range(len(futures))\n",
" ]\n",
" dask_data_arrays = [\n",
" {\n",
" k: (\n",
" dask.array.from_delayed(\n",
" delayed_arrays[i][k], *dask_array_metadata[i][k]\n",
" )\n",
" )\n",
" for k in data_var_keys[i]\n",
" }\n",
" for i in range(len(futures))\n",
" ]\n",
" # using dict(x.coords) b/c gathering coords can blow up memory for some reason\n",
" array_metadata = [\n",
" {\n",
" k: client.submit(\n",
" lambda x: {\n",
" \"dims\": x[k].dims,\n",
" \"coords\": dict(x[k].coords),\n",
" \"attrs\": x[k].attrs,\n",
" },\n",
" futures[i],\n",
" ).result()\n",
" for k in data_var_keys[i]\n",
" }\n",
" for i in range(len(futures))\n",
" ]\n",
" data_arrays = [\n",
" {\n",
" k: (xr.DataArray(dask_data_arrays[i][k], **array_metadata[i][k]))\n",
" for k in data_var_keys[i]\n",
" }\n",
" for i in range(len(futures))\n",
" ]\n",
" datasets = [xr.Dataset(arr) for arr in data_arrays]\n",
" dataset_metadata = client.gather(client.map(lambda x: x.attrs, futures))\n",
" for i in range(len(futures)):\n",
" datasets[i].attrs.update(dataset_metadata[i])\n",
" return datasets\n",
"def dataset_from_delayed(futures, dim=None, client=None):\n",
" \"\"\"\n",
" Returns an :py:class:`xarray.Dataset` from a list of futures\n",
" Parameters\n",
" ----------\n",
" futures : list\n",
" list of :py:class:`dask.distributed.Future` objects holding\n",
" :py:class:`xarray.Dataset` objects.\n",
" dim : str, optional\n",
" dimension along which to concat :py:class:`xarray.Dataset`.\n",
" Inferred by default.\n",
" client : object, optional\n",
" :py:class:`dask.distributed.Client` to use in gathering\n",
" metadata on futures. If not provided, client is inferred\n",
" from context.\n",
" Returns\n",
" -------\n",
" dataset : object\n",
" :py:class:`xarray.Dataset` concatenated along ``dim`` with\n",
" :py:class:`dask.array.Array` backends for each variable.\n",
" Examples\n",
" --------\n",
" Given a mapped :py:class:`xarray.Dataset`, pull the metadata into memory\n",
" while leaving the data on the workers:\n",
" .. code-block:: python\n",
" >>> import numpy as np, pandas as pd\n",
" >>> def build_ds(multiplier):\n",
" ... return multiplier * xr.Dataset({\n",
" ... 'var1': xr.DataArray(\n",
" ... np.arange(2), dims=['x'], coords=[['a', 'b']])})\n",
" ...\n",
" >>> client = dd.Client()\n",
" >>> fut = client.map(build_ds, range(3))\n",
" >>> ds = dataset_from_delayed(fut, dim=pd.Index(range(3), name='y'))\n",
" >>> ds\n",
" <xarray.Dataset>\n",
" Dimensions: (x: 2, y: 3)\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" * y (y) int64 0 1 2\n",
" Data variables:\n",
" var1 (y, x) int64 dask.array<chunksize=(1, 2), meta=np.ndarray>\n",
" >>> client.close()\n",
" \"\"\"\n",
" datasets = datasets_from_delayed(futures, client=client)\n",
" ds = xr.concat(datasets, dim=dim)\n",
" return ds\n"
"import functools\n",
"import dask.array\n",
"import numpy as np\n",
"import xarray as xr\n",
"from dask import distributed as dd\n",
"def dataarrays_from_delayed_old(futures, client=None):\n",
" \"\"\"\n",
" Returns a list of xarray dataarrays from a list of futures of dataarrays\n",
" Parameters\n",
" ----------\n",
" futures : list\n",
" list of :py:class:`dask.distributed.Future` objects holding\n",
" :py:class:`xarray.DataArray` objects.\n",
" client : object, optional\n",
" :py:class:`dask.distributed.Client` to use in gathering\n",
" metadata on futures. If not provided, client is inferred\n",
" from context.\n",
" Returns\n",
" -------\n",
" arrays : list\n",
" list of :py:class:`xarray.DataArray` objects with\n",
" :py:class:`dask.array.Array` backends.\n",
" Examples\n",
" --------\n",
" Given a mapped xarray DataArray, pull the metadata into memory while\n",
" leaving the data on the workers:\n",
" .. code-block:: python\n",
" >>> import numpy as np\n",
" >>> def build_arr(multiplier):\n",
" ... return multiplier * xr.DataArray(\n",
" ... np.arange(2), dims=['x'], coords=[['a', 'b']])\n",
" ...\n",
" >>> client = dd.Client()\n",
" >>> fut = client.map(build_arr, range(3))\n",
" >>> arrs = dataarrays_from_delayed_old(fut)\n",
" >>> arrs[-1] # doctest: +ELLIPSIS\n",
" <xarray.DataArray ...(x: 2)>\n",
" dask.array<...shape=(2,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" This list of arrays can now be manipulated using normal xarray tools:\n",
" .. code-block:: python\n",
" >>> xr.concat(arrs, dim='simulation') # doctest: +ELLIPSIS\n",
" <xarray.DataArray ...(simulation: 3, x: 2)>\n",
" dask.array<...shape=(3, 2), dtype=int64, chunksize=(1, 2), chunktype=numpy.ndarray>\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" Dimensions without coordinates: simulation\n",
" >>> client.close()\n",
" \"\"\"\n",
" if client is None:\n",
" client = dd.get_client()\n",
" delayed_arrays = client.map(lambda x: x.data, futures)\n",
" dask_array_metadata = client.gather(\n",
" client.map(lambda x: (x.data.shape, x.data.dtype), futures)\n",
" )\n",
" dask_arrays = [\n",
" dask.array.from_delayed(delayed_arrays[i], *dask_array_metadata[i])\n",
" for i in range(len(futures))\n",
" ]\n",
" # gathers x.coords directly; the newer version wraps it in dict(x.coords) b/c gathering coords can blow up memory for some reason\n",
" array_metadata = client.gather(\n",
" client.map(\n",
" lambda x: {\n",
" \"dims\": x.dims,\n",
" \"coords\": x.coords,\n",
" \"attrs\": x.attrs,\n",
" \"name\": x.name,\n",
" },\n",
" futures,\n",
" )\n",
" )\n",
" data_arrays = [\n",
" xr.DataArray(dask_arrays[i], **array_metadata[i]) for i in range(len(futures))\n",
" ]\n",
" return data_arrays\n",
"def dataarray_from_delayed_old(futures, dim=None, client=None):\n",
" \"\"\"\n",
" Returns a DataArray from a list of futures\n",
" Parameters\n",
" ----------\n",
" futures : list\n",
" list of :py:class:`dask.distributed.Future` objects holding\n",
" :py:class:`xarray.DataArray` objects.\n",
" dim : str, optional\n",
" dimension along which to concat :py:class:`xarray.DataArray`.\n",
" Inferred by default.\n",
" client : object, optional\n",
" :py:class:`dask.distributed.Client` to use in gathering\n",
" metadata on futures. If not provided, client is inferred\n",
" from context.\n",
" Returns\n",
" -------\n",
" array : object\n",
" :py:class:`xarray.DataArray` concatenated along ``dim`` with\n",
" a :py:class:`dask.array.Array` backend.\n",
" Examples\n",
" --------\n",
" Given a mapped xarray DataArray, pull the metadata into memory while\n",
" leaving the data on the workers:\n",
" .. code-block:: python\n",
" >>> import numpy as np, pandas as pd\n",
" >>> def build_arr(multiplier):\n",
" ... return multiplier * xr.DataArray(\n",
" ... np.arange(2), dims=['x'], coords=[['a', 'b']])\n",
" ...\n",
" >>> client = dd.Client()\n",
" >>> fut = client.map(build_arr, range(3))\n",
" >>> da = dataarray_from_delayed_old(\n",
" ... fut,\n",
" ... dim=pd.Index(range(3), name='simulation'))\n",
" ...\n",
" >>> da # doctest: +ELLIPSIS\n",
" <xarray.DataArray ...(simulation: 3, x: 2)>\n",
" dask.array<...shape=(3, 2), dtype=int64, chunksize=(1, 2), chunktype=numpy.ndarray>\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" * simulation (simulation) int64 0 1 2\n",
" >>> client.close()\n",
" \"\"\"\n",
" data_arrays = dataarrays_from_delayed_old(futures, client=client)\n",
" da = xr.concat(data_arrays, dim=dim)\n",
" return da\n",
"def datasets_from_delayed_old(futures, client=None):\n",
" \"\"\"\n",
" Returns a list of xarray datasets from a list of futures of datasets\n",
" Parameters\n",
" ----------\n",
" futures : list\n",
" list of :py:class:`dask.distributed.Future` objects holding\n",
" :py:class:`xarray.Dataset` objects.\n",
" client : object, optional\n",
" :py:class:`dask.distributed.Client` to use in gathering\n",
" metadata on futures. If not provided, client is inferred\n",
" from context.\n",
" Returns\n",
" -------\n",
" datasets : list\n",
" list of :py:class:`xarray.Dataset` objects with\n",
" :py:class:`dask.array.Array` backends for each variable.\n",
" Examples\n",
" --------\n",
" Given a mapped :py:class:`xarray.Dataset`, pull the metadata into memory\n",
" while leaving the data on the workers:\n",
" .. code-block:: python\n",
" >>> import numpy as np\n",
" >>> def build_ds(multiplier):\n",
" ... return multiplier * xr.Dataset({\n",
" ... 'var1': xr.DataArray(\n",
" ... np.arange(2), dims=['x'], coords=[['a', 'b']])})\n",
" ...\n",
" >>> client = dd.Client()\n",
" >>> fut = client.map(build_ds, range(3))\n",
" >>> arrs = datasets_from_delayed_old(fut)\n",
" >>> arrs[-1] # doctest: +ELLIPSIS\n",
" <xarray.Dataset>\n",
" Dimensions: (x: 2)\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" Data variables:\n",
" var1 (x) int64 dask.array<chunksize=(2,), meta=np.ndarray>\n",
" This list of arrays can now be manipulated using normal xarray tools:\n",
" .. code-block:: python\n",
" >>> xr.concat(arrs, dim='y') # doctest: +ELLIPSIS\n",
" <xarray.Dataset>\n",
" Dimensions: (x: 2, y: 3)\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" Dimensions without coordinates: y\n",
" Data variables:\n",
" var1 (y, x) int64 dask.array<chunksize=(1, 2), meta=np.ndarray>\n",
" >>> client.close()\n",
" \"\"\"\n",
" if client is None:\n",
" client = dd.get_client()\n",
" data_var_keys = client.gather(\n",
" client.map(lambda x: list(x.data_vars.keys()), futures)\n",
" )\n",
" delayed_arrays = [\n",
" {k: (client.submit(lambda x: x[k].data, futures[i])) for k in data_var_keys[i]}\n",
" for i in range(len(futures))\n",
" ]\n",
" dask_array_metadata = [\n",
" {\n",
" k: (\n",
" client.submit(\n",
" lambda x: (x[k].data.shape, x[k].data.dtype), futures[i]\n",
" ).result()\n",
" )\n",
" for k in data_var_keys[i]\n",
" }\n",
" for i in range(len(futures))\n",
" ]\n",
" dask_data_arrays = [\n",
" {\n",
" k: (\n",
" dask.array.from_delayed(\n",
" delayed_arrays[i][k], *dask_array_metadata[i][k]\n",
" )\n",
" )\n",
" for k in data_var_keys[i]\n",
" }\n",
" for i in range(len(futures))\n",
" ]\n",
" # gathers x[k].coords directly; the newer version wraps it in dict(x[k].coords) b/c gathering coords can blow up memory for some reason\n",
" array_metadata = [\n",
" {\n",
" k: client.submit(\n",
" lambda x: {\n",
" \"dims\": x[k].dims,\n",
" \"coords\": x[k].coords,\n",
" \"attrs\": x[k].attrs,\n",
" },\n",
" futures[i],\n",
" ).result()\n",
" for k in data_var_keys[i]\n",
" }\n",
" for i in range(len(futures))\n",
" ]\n",
" data_arrays = [\n",
" {\n",
" k: (xr.DataArray(dask_data_arrays[i][k], **array_metadata[i][k]))\n",
" for k in data_var_keys[i]\n",
" }\n",
" for i in range(len(futures))\n",
" ]\n",
" datasets = [xr.Dataset(arr) for arr in data_arrays]\n",
" dataset_metadata = client.gather(client.map(lambda x: x.attrs, futures))\n",
" for i in range(len(futures)):\n",
" datasets[i].attrs.update(dataset_metadata[i])\n",
" return datasets\n",
"def dataset_from_delayed_old(futures, dim=None, client=None):\n",
" \"\"\"\n",
" Returns an :py:class:`xarray.Dataset` from a list of futures\n",
" Parameters\n",
" ----------\n",
" futures : list\n",
" list of :py:class:`dask.distributed.Future` objects holding\n",
" :py:class:`xarray.Dataset` objects.\n",
" dim : str, optional\n",
" dimension along which to concat :py:class:`xarray.Dataset`.\n",
" Inferred by default.\n",
" client : object, optional\n",
" :py:class:`dask.distributed.Client` to use in gathering\n",
" metadata on futures. If not provided, client is inferred\n",
" from context.\n",
" Returns\n",
" -------\n",
" dataset : object\n",
" :py:class:`xarray.Dataset` concatenated along ``dim`` with\n",
" :py:class:`dask.array.Array` backends for each variable.\n",
" Examples\n",
" --------\n",
" Given a mapped :py:class:`xarray.Dataset`, pull the metadata into memory\n",
" while leaving the data on the workers:\n",
" .. code-block:: python\n",
" >>> import numpy as np, pandas as pd\n",
" >>> def build_ds(multiplier):\n",
" ... return multiplier * xr.Dataset({\n",
" ... 'var1': xr.DataArray(\n",
" ... np.arange(2), dims=['x'], coords=[['a', 'b']])})\n",
" ...\n",
" >>> client = dd.Client()\n",
" >>> fut = client.map(build_ds, range(3))\n",
" >>> ds = dataset_from_delayed_old(fut, dim=pd.Index(range(3), name='y'))\n",
" >>> ds\n",
" <xarray.Dataset>\n",
" Dimensions: (x: 2, y: 3)\n",
" Coordinates:\n",
" * x (x) <U1 'a' 'b'\n",
" * y (y) int64 0 1 2\n",
" Data variables:\n",
" var1 (y, x) int64 dask.array<chunksize=(1, 2), meta=np.ndarray>\n",
" >>> client.close()\n",
" \"\"\"\n",
" datasets = datasets_from_delayed_old(futures, client=client)\n",
" ds = xr.concat(datasets, dim=dim)\n",
" return ds\n"
"cell_type": "code",
