Skip to content

Instantly share code, notes, and snippets.

@pentschev
Created June 30, 2019 18:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pentschev/d66014ec27b331a74a15f7c48dca0121 to your computer and use it in GitHub Desktop.
Save pentschev/d66014ec27b331a74a15f7c48dca0121 to your computer and use it in GitHub Desktop.
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from dask.array.utils import assert_eq
import dask.array as da
import cupy as cp
# CUDA kernel computing ``y = x1 + x2`` where the row vector ``x2`` is
# broadcast across every row of the 2-D array ``x1``.
#
# Launch arguments, in order: (x1, x2, y, x1_pitch, y_pitch, rows, cols)
#   x1, x2, y  -- float32 device buffers
#   x1_pitch   -- distance between consecutive rows of x1, in elements
#   y_pitch    -- distance between consecutive rows of y, in elements
#   rows, cols -- logical extent of x1 / y
# The four integers must be passed as 32-bit scalars (e.g. ``cp.int32``):
# CuPy forwards a plain Python ``int`` as ``long long``, which does not match
# the ``int`` parameters declared below.
add_broadcast_kernel = cp.RawKernel(
    r'''
extern "C" __global__
void add_broadcast_kernel(const float* x1, const float* x2, float* y,
                          const int x1_pitch, const int y_pitch,
                          const int rows, const int cols)
{
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y * blockDim.y + threadIdx.y;

    // The grid is rounded up to whole blocks, so threads past the edge of
    // the array must not touch memory.
    if (row >= rows || col >= cols) {
        return;
    }

    // 64-bit offsets: row * pitch can exceed INT_MAX for large arrays.
    const long long r = row;
    y[r * y_pitch + col] = x1[r * x1_pitch + col] + x2[col];
}
''',
    'add_broadcast_kernel',
)


def dispatch_add_broadcast(x1, x2, y):
    """Launch ``add_broadcast_kernel`` to store ``x1 + x2`` into ``y``.

    ``x2`` is broadcast along the first axis of ``x1``. The result is written
    into ``y`` in place and ``y`` is also returned.

    Args:
        x1: 2-D float32 device array of shape ``(rows, cols)``.
        x2: float32 device array holding exactly ``cols`` values, e.g. of
            shape ``(1, cols)``. The kernel reads it as ``x2[col]``, i.e. as
            one flat run of elements.
        y: float32 device array with the same shape as ``x1``; overwritten.

    Returns:
        ``y``, after the kernel launch has been queued.

    Raises:
        ValueError: if ``x1`` is not 2-D, or the shapes of ``x2`` / ``y`` do
            not match ``x1``.
        TypeError: if any operand is not float32 (the kernel reinterprets the
            raw device pointers as ``float``).
    """
    rows, cols = x1.shape
    if y.shape != x1.shape:
        raise ValueError(
            "y must have the same shape as x1: %r != %r" % (y.shape, x1.shape)
        )
    if x2.size != cols:
        raise ValueError(
            "x2 must hold exactly %d elements, got %d" % (cols, x2.size)
        )
    for name, operand in (("x1", x1), ("x2", x2), ("y", y)):
        if operand.dtype != cp.float32:
            raise TypeError(
                "%s must be float32, got %s" % (name, operand.dtype)
            )

    # Nothing to compute, and CUDA rejects a launch with a zero-sized grid.
    if rows == 0 or cols == 0:
        return y

    block_x, block_y = 32, 32
    # Ceiling division so partial blocks at the right/bottom edges are still
    # launched; the kernel masks the out-of-range threads.
    grid_size = (
        (cols + block_x - 1) // block_x,
        (rows + block_y - 1) // block_y,
    )

    # Row pitch in elements. x1 and y get separate pitches so that a strided
    # view (e.g. a chunk sliced out of a larger array) can be combined with a
    # contiguous buffer. NOTE: this assumes the last axis has unit stride.
    x1_pitch = x1.strides[0] // x1.strides[1]
    y_pitch = y.strides[0] // y.strides[1]

    add_broadcast_kernel(
        grid_size,
        (block_x, block_y),
        (
            x1,
            x2,
            y,
            cp.int32(x1_pitch),
            cp.int32(y_pitch),
            cp.int32(rows),
            cp.int32(cols),
        ),
    )
    return y


def main():
    """Check the broadcast-add kernel directly and through Dask.

    Starts a local CUDA cluster, computes ``x1 + x2`` with plain CuPy as the
    reference, then compares it against (a) a direct call to
    ``dispatch_add_broadcast`` and (b) the same function mapped over chunked
    Dask arrays backed by the CuPy buffers.
    """
    # Context managers guarantee the workers and the client connection are
    # shut down even when one of the assertions below fails. Entering the
    # Client also registers it as the scheduler used by the Dask computation.
    with LocalCUDACluster() as cluster, Client(cluster):
        x1 = cp.arange(4096 * 1024, dtype=cp.float32).reshape((4096, 1024))
        x2 = cp.arange(1024, dtype=cp.float32).reshape(1, 1024)
        y = cp.zeros((4096, 1024), dtype=cp.float32)

        # Reference result from CuPy's built-in broadcasting.
        res_cupy = x1 + x2

        # Direct kernel launch on the full arrays.
        res_add_broadcast = dispatch_add_broadcast(
            x1, x2, cp.zeros((4096, 1024), dtype=cp.float32)
        )
        assert_eq(res_cupy, res_add_broadcast)

        # asarray=False keeps the chunks as CuPy arrays instead of letting
        # Dask coerce them to NumPy.
        d_x1 = da.from_array(x1, chunks=(1024, 512), asarray=False)
        d_x2 = da.from_array(x2, chunks=(1, 512), asarray=False)
        d_y = da.from_array(y, chunks=(1024, 512), asarray=False)

        # One kernel launch per (1024, 512) chunk.
        res = da.map_blocks(
            dispatch_add_broadcast, d_x1, d_x2, d_y, dtype=cp.float32
        )
        assert_eq(res, res_cupy)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment