Skip to content

Instantly share code, notes, and snippets.

@jakirkham
Last active February 24, 2020 23:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jakirkham/8a95d6a04c75342d8b89e82fe130407d to your computer and use it in GitHub Desktop.
Save jakirkham/8a95d6a04c75342d8b89e82fe130407d to your computer and use it in GitHub Desktop.
Attempt at a reproducer for a multi-array return
import sys
import dask
import dask.array
from dask.delayed import delayed
import distributed
# Cluster backend: dask-cuda's LocalCUDACluster with the "ucx" protocol when
# dask_cuda is importable, otherwise distributed's LocalCluster with "tcp".
# ``client_kwargs`` holds the connection options used by get_client().
try:
    from dask_cuda import LocalCUDACluster as DaskCluster
except ImportError:
    from distributed import LocalCluster as DaskCluster
    client_kwargs = {"protocol": "tcp"}
else:
    client_kwargs = {"protocol": "ucx"}

# Array module: CuPy when importable, otherwise NumPy.
try:
    import cupy as xnumpy
except ImportError:
    import numpy as xnumpy
@delayed
def double_halve(x):
    """Lazily compute the pair ``(2 * x, x / 2)``.

    Because this is wrapped with ``@delayed`` (without ``nout``), calling it
    builds a single task whose result is the whole 2-tuple.
    """
    doubled = 2 * x
    halved = x / 2
    return doubled, halved
def get_client():
    """Start a local Dask cluster and return a ``distributed.Client`` bound to it.

    The transport options in ``client_kwargs`` (the ``protocol``) are given
    to the cluster constructor, because the cluster is what starts the
    scheduler and workers. ``distributed.Client`` only uses extra keyword
    arguments when it creates its own LocalCluster. When it is handed an
    existing cluster object, those arguments are ignored, so the requested
    protocol would never take effect.

    Returns
    -------
    distributed.Client
        A client connected to the newly started cluster.
    """
    cluster = DaskCluster(**client_kwargs)
    client = distributed.Client(cluster)
    return client
def main(*argv):
    """Run the multi-array-return reproducer and print the gathered results.

    Builds a 20-element random Dask array split into chunks of 5, applies
    ``double_halve`` to each chunk, computes everything on the cluster and
    prints the gathered list (one ``(2 * x, x / 2)`` tuple per chunk).

    Parameters
    ----------
    *argv
        Command-line arguments; currently unused.

    Returns
    -------
    int
        Process exit status (0 on success).
    """
    client = get_client()
    try:
        # Use the selected array module's RandomState so chunks are created
        # as CuPy or NumPy arrays, matching ``xnumpy``.
        rs = dask.array.random.RandomState(RandomState=xnumpy.random.RandomState)
        a = rs.random((20,), chunks=(5,))
        arrs = [double_halve(e) for e in a.blocks]
        futures = client.compute(arrs)
        results = client.gather(futures)
        print("")
        print(results)
        print("")
    finally:
        # Client.close() does not shut down a cluster object that was passed
        # to the Client, so close that cluster explicitly as well.
        cluster = getattr(client, "cluster", None)
        client.close()
        if cluster is not None:
            cluster.close()
    return 0
if __name__ == "__main__":
    # Forward the full argv (program name included) and use main()'s return
    # value as the process exit status.
    status = main(*sys.argv)
    sys.exit(status)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment