Convert CuPy stream to Numba stream
# Many Python libraries provide low-level CUDA support, but when it comes to interoperability
# things get complicated. This script shows a simple example of converting a CUDA stream
# created through CuPy's API into one that is compatible with Numba.
from numba import cuda
import cupy as cp


def stream_cupy_to_numba(cp_stream):
    '''
    Notes:
        1. The CuPy stream must stay alive at least as long as the returned Numba
           stream, because CuPy handles the deallocation of the underlying CUDA stream.
        2. The returned Numba stream is assumed to live in the same CUDA context as the
           CuPy one.
        3. The implementation here closely follows that of cuda.stream() in Numba.
    '''
    from ctypes import c_void_p
    import weakref

    # get the pointer to the actual CUDA stream
    raw_str = cp_stream.ptr

    # gather the necessary ingredients
    ctx = cuda.devices.get_context()
    handle = c_void_p(raw_str)
    finalizer = None  # let CuPy handle its lifetime, not Numba

    # create a Numba stream
    # NOTE: cuda.cudadrv.driver.Stream is Numba-internal, so its constructor
    # signature may change between Numba releases.
    nb_stream = cuda.cudadrv.driver.Stream(weakref.proxy(ctx), handle, finalizer)

    return nb_stream


# do some computations with CuPy to set up the context, e.g.:
a = cp.random.random(100)
cp_stream = cp.cuda.stream.Stream()
nb_stream = stream_cupy_to_numba(cp_stream)

# now nb_stream can be used for launching Numba JIT kernels, e.g.:
# kernel[grid, block, nb_stream](a)

Updated line 28 to avoid specifying the finalizer; see the discussion around numba/numba#5347 (comment).