Skip to content

Instantly share code, notes, and snippets.

@sklam
Created April 12, 2021 16:37
Show Gist options
  • Save sklam/e5496e412fccac6acc0e96b4413ed977 to your computer and use it in GitHub Desktop.
Save sklam/e5496e412fccac6acc0e96b4413ed977 to your computer and use it in GitHub Desktop.
import numpy as np
from numba.core.extending import intrinsic
from numba.core import types, cgutils
from numba import njit, prange
from numba.np.arrayobj import make_array, basic_indexing, normalize_indices
from llvmlite import ir
@intrinsic
def atomic_array_add(tyctx, arr, idx, val):
    """Atomically add *val* to ``arr[idx]`` and return the element's old value.

    Callable only from jitted code.  The update is emitted as a single LLVM
    ``atomicrmw add`` with ``seq_cst`` ordering, so concurrent callers (e.g.
    the iterations of a ``prange`` loop) hitting the same element do not lose
    updates.

    Typing rules.  Anything else is rejected here, producing a numba
    TypingError instead of invalid LLVM IR:

    * ``arr`` must be a writable array with an integer dtype.  LLVM's
      integer ``add`` operation does not apply to floats; those would need
      ``fadd``.
    * ``idx`` must select exactly one element.  That is either a single
      integer for a 1-D array, or a tuple of ``arr.ndim`` integers.
      Negative indices wrap around as in ordinary indexing.  Bounds are
      checked when the context's boundscheck option is enabled.
    * ``val`` must be an integer.  It is cast to ``arr.dtype`` before the
      add.
    """
    if not isinstance(arr, types.Array) or not arr.mutable:
        return None
    if not isinstance(arr.dtype, types.Integer):
        return None
    if not isinstance(val, types.Integer):
        return None
    if isinstance(idx, types.BaseTuple):
        index_ok = len(idx) == arr.ndim and all(
            isinstance(t, types.Integer) for t in idx
        )
    else:
        # A lone integer on an N-D array would select a sub-array, not an
        # element, so it is only accepted for 1-D arrays.
        index_ok = arr.ndim == 1 and isinstance(idx, types.Integer)
    if not index_ok:
        return None

    def codegen(context, builder, signature, args):
        arr_val, idx_val, val_val = args
        arr_ty, idx_ty, val_ty = signature.args
        llary = make_array(arr_ty)(context, builder, arr_val)

        # Spread a tuple index into one integer index per dimension.
        if isinstance(idx_ty, types.BaseTuple):
            raw_types = list(idx_ty)
            raw_indices = cgutils.unpack_tuple(
                builder, idx_val, count=len(idx_ty),
            )
        else:
            raw_types = [idx_ty]
            raw_indices = [idx_val]
        index_types, indices = normalize_indices(
            context, builder, raw_types, raw_indices,
        )

        # With one integer per dimension, the data pointer returned by
        # basic_indexing addresses exactly the selected element.
        elem_ptr, _, _ = basic_indexing(
            context, builder, arr_ty, llary, index_types, indices,
            boundscheck=context.enable_boundscheck,
        )

        # atomicrmw requires the operand type to match the pointee type.
        operand = context.cast(builder, val_val, val_ty, arr_ty.dtype)
        return builder.atomic_rmw("add", elem_ptr, operand, ordering="seq_cst")

    sig = arr.dtype(arr, idx, val)
    return sig, codegen
@njit(parallel=True)
def foo(arr):
    """Sum every element of *arr* via ``atomic_array_add`` in a ``prange`` loop.

    Demonstrates the atomic intrinsic.  Every parallel iteration adds into
    the single-element accumulator ``accum[0]``.  The input is flattened
    first, so each element of an N-D array is added individually rather than
    indexing rows.  Returns the total as a scalar of ``arr.dtype``.
    """
    accum = np.zeros(1, dtype=arr.dtype)
    flat = arr.ravel()
    for i in prange(flat.size):
        # The returned previous value of accum[0] is not needed here.
        atomic_array_add(accum, 0, flat[i])
    return accum[0]
if __name__ == "__main__":
    # Demo: atomically sum 0..9 in parallel; the total should be 45.
    total = foo(np.arange(10, dtype=np.intp))
    print(total)
    # Render the control-flow graph of the compiled specialization.
    # display(view=True) opens an external viewer.  It presumably needs
    # graphviz installed (TODO confirm), so it is only run as a script,
    # never on import.
    foo.inspect_cfg(foo.signatures[0]).display(view=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment