-
-
Save sklam/e5496e412fccac6acc0e96b4413ed977 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numba.core.extending import intrinsic | |
from numba.core import types, cgutils | |
from numba import njit, prange | |
from numba.np.arrayobj import make_array, basic_indexing, normalize_indices | |
from llvmlite import ir | |
@intrinsic
def atomic_array_add(tyctx, arr, idx, val):
    """Atomically perform ``arr[idx] += val`` and return the previous value.

    Only callable from jitted code. The update is emitted as a single LLVM
    ``atomicrmw`` with ``seq_cst`` ordering, so concurrent callers (for
    example different ``prange`` iterations) updating the same element do
    not lose updates.

    Typing-time parameters (these are Numba *types*, not runtime values):
        tyctx: Numba typing context (unused).
        arr: a mutable ``types.Array`` whose dtype is an integer or float.
        idx: an integer (1-D arrays only) or a homogeneous tuple of exactly
            ``arr.ndim`` integers, i.e. an index selecting one element.
        val: a number; it is cast to ``arr.dtype`` before the update.

    Returns ``(signature, codegen)`` on a match. It returns ``None`` for any
    other combination of argument types, which Numba reports as a typing
    error at the call site.
    """
    if not isinstance(arr, types.Array) or not arr.mutable:
        return None

    dtype = arr.dtype
    if isinstance(dtype, types.Integer):
        rmw_op = "add"
    elif isinstance(dtype, types.Float):
        # LLVM's integer ``add`` is not valid on floating-point operands.
        rmw_op = "fadd"
    else:
        return None

    if isinstance(idx, types.Integer):
        if arr.ndim != 1:
            # A lone integer on an N-d array selects a sub-array, not an
            # element; the atomic op would silently hit the wrong address.
            return None
    elif isinstance(idx, types.UniTuple) and isinstance(idx.dtype, types.Integer):
        if idx.count != arr.ndim:
            return None
    else:
        return None

    if not isinstance(val, types.Number):
        return None

    def codegen(context, builder, signature, args):
        arr_val, idx_val, val_val = args
        arr_ty, idx_ty, val_ty = signature.args
        llary = make_array(arr_ty)(context, builder, arr_val)

        # Split a tuple index into per-dimension components for
        # normalize_indices / basic_indexing.
        if isinstance(idx_ty, types.BaseTuple):
            raw_types = list(idx_ty)
            raw_indices = cgutils.unpack_tuple(builder, idx_val, len(raw_types))
        else:
            raw_types = [idx_ty]
            raw_indices = [idx_val]

        index_types, indices = normalize_indices(
            context, builder, raw_types, raw_indices,
        )
        # With every dimension indexed by an integer, view_data points at
        # the single selected element.
        view_data, _, _ = basic_indexing(
            context, builder, arr_ty, llary, index_types, indices,
            boundscheck=context.enable_boundscheck,
        )
        # atomicrmw requires the operand type to equal the element type.
        operand = context.cast(builder, val_val, val_ty, arr_ty.dtype)
        # atomicrmw yields the value stored *before* the addition.
        return builder.atomic_rmw(rmw_op, view_data, operand, ordering="seq_cst")

    sig = dtype(arr, idx, val)
    return sig, codegen
@njit(parallel=True)
def foo(arr):
    """Sum every element of ``arr`` in parallel and return the total.

    Each ``prange`` iteration adds one element into a shared one-slot
    accumulator through ``atomic_array_add``, so concurrent updates do not
    race. The result has the same dtype as ``arr``.
    """
    # Flatten first: on an N-d array, arr[i] is a row rather than a scalar.
    flat = arr.ravel()
    accum = np.zeros(1, dtype=arr.dtype)
    for i in prange(flat.size):
        # The returned previous value is not needed here.
        atomic_array_add(accum, 0, flat[i])
    return accum[0]
if __name__ == "__main__":
    # Compile and run foo on 0..9; prints 45.
    total = foo(np.arange(10, dtype=np.intp))
    print(total)
    # Render the control-flow graph of the first compiled specialization.
    # display(view=True) opens an external viewer (presumably needs the
    # graphviz package installed - confirm for your environment).
    foo.inspect_cfg(foo.signatures[0]).display(view=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment