@alsrgv
Created January 30, 2020 07:27
Numba CUDA jitclass question
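import numba
import numba.cuda
import numpy as np


# Note: Numba 0.49 and later expose this decorator as numba.experimental.jitclass.
@numba.jitclass([
    ("x", numba.types.int32),
])
class XYZ:
    def __init__(self, x):
        self.x = x


@numba.cuda.jit
def set_x(a: np.ndarray, xyz: XYZ) -> None:
    # CUDA kernels cannot return values; the result is written into `a`.
    i = numba.cuda.grid(1)  # absolute index of this thread in the 1D grid
    if i >= a.shape[0]:
        return
    a[i] = xyz.x


a = np.zeros(5)
xyz = XYZ(10)
# Kernels are launched with a [blocks, threads_per_block] configuration:
# here one block with one thread per element of `a`.
# Passing the jitclass instance `xyz` into the kernel is what this question is about.
set_x[1, a.shape[0]](a, xyz)
print(a)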