-
-
Save sklam/830fe01343ba95828c3b24c391855c86 to your computer and use it in GitHub Desktop.
import numpy as np | |
from numba import types, njit, cgutils, typing | |
from numba.extending import intrinsic | |
@intrinsic
def index_tuple_for_array(tyctx, ary, idxary):
    """
    Implement an intrinsic to return a fixed-size tuple given the target array
    and the index-array.

    Equivalent to `tuple(idxary)`.
    Works around the lack of dynamic-sized tuples: the tuple length is taken
    from `ary.ndim`, which is available at the typing level.

    Parameters (Numba *types* at this level, not runtime values):
        tyctx:  the typing context (not used here).
        ary:    type of the target array; only its `ndim` is read.
        idxary: type of the 1-D integer array that holds the indices.

    Returns:
        `(signature, codegen)` for a `UniTuple(intp, ary.ndim)` result, or
        None (no match) when `idxary` is not a 1-D integer array.

    The generated code raises IndexError("index array size mismatch") when
    `idxary.size` differs from `ary.ndim`.
    """
    # This is the typing level.
    # We'll setup the type and constant information here.

    # Only a 1-D integer array can be read element by element as indices.
    # Returning None tells Numba this intrinsic does not apply to these types.
    if not (
        isinstance(idxary, types.Array)
        and idxary.ndim == 1
        and isinstance(idxary.dtype, types.Integer)
    ):
        return None

    nd = ary.ndim  # size of the tuple
    tupty = types.UniTuple(dtype=types.intp, count=nd)  # the tuple
    function_sig = tupty(ary, idxary)  # the function signature for this intrinsic

    def codegen(cgctx, builder, signature, args):
        # This is the implementation defined using LLVM builder
        lltupty = cgctx.get_value_type(tupty)
        tup = cgutils.get_null_value(lltupty)
        _, idxaryval = args  # the target array value itself is never read

        def array_checker(a):
            if a.size != nd:
                raise IndexError("index array size mismatch")

        # compile and call array_checker
        cgctx.compile_internal(builder, array_checker, types.none(idxary),
                               [idxaryval])

        def array_indexer(a, i):
            return a[i]

        # loop to fill the tuple; `nd` is a Python int, so this loop is
        # unrolled while the code is being generated
        for i in range(nd):
            dataidx = cgctx.get_constant(types.intp, i)
            # compile and call array_indexer
            data = cgctx.compile_internal(builder, array_indexer,
                                          idxary.dtype(idxary, types.intp),
                                          [idxaryval, dataidx])
            # The tuple slots are intp; convert the loaded element so that
            # index arrays of other integer widths (e.g. int32) do not insert
            # a value of the wrong LLVM type.
            data = cgctx.cast(builder, data, idxary.dtype, types.intp)
            tup = builder.insert_value(tup, data, i)
        return tup

    return function_sig, codegen
def test():
    """Print one element fetched through the intrinsic, then the same
    element fetched with plain NumPy tuple indexing, for comparison."""

    @njit
    def lookup(target, position):
        # Build the index tuple inside jitted code and use it to index.
        return target[index_tuple_for_array(target, position)]

    grid = np.arange(10, dtype=np.intp).reshape(5, 2)
    position = np.asarray([2, 1], dtype=np.intp)

    # Jitted result first, NumPy reference second.
    jitted_value = lookup(grid, position)
    print(jitted_value)
    print(grid[tuple(position)])


if __name__ == '__main__':
    test()
Hi! This is great, but I can't find a way to pass `nd` as an argument to `index_tuple_for_array` from the outer njit function. I am trying to simplify it a bit, like this:
@numba.extending.intrinsic
def array_to_tuple_1d(tyctx, ary, nd):
    """
    Similar to tuple(ary)

    NOTE(review): at this level `nd` is whatever Numba passes for the second
    argument; it is used directly as the tuple length below, so it has to be
    a plain number for this to work -- confirm how callers supply it.
    """
    # Typing level: describe the result tuple and the intrinsic's signature.
    result_type = numba.types.UniTuple(dtype=numba.types.int32, count=nd)
    intrinsic_sig = result_type(ary)

    def codegen(cgctx, builder, signature, args):
        # Lowering level: assemble the tuple with the LLVM builder.
        def element_at(a, i):
            return a[i]

        [array_value] = args
        element_sig = ary.dtype(ary, numba.types.intp)
        packed = numba.cgutils.get_null_value(cgctx.get_value_type(result_type))

        # One compiled element load per tuple slot.
        for slot in range(nd):
            slot_const = cgctx.get_constant(numba.types.intp, slot)
            element = cgctx.compile_internal(
                builder, element_at, element_sig, [array_value, slot_const]
            )
            packed = builder.insert_value(packed, element, slot)
        return packed

    return intrinsic_sig, codegen
@numba.njit
def arr2tup(arr):
    """Ask `array_to_tuple_1d` for a tuple as long as `arr` itself."""
    requested_length = arr.size
    return array_to_tuple_1d(arr, requested_length)


arr2tup(np.array([1, 2, 3]))
and I get `TypeError: %d format: a number is required, not Integer`. I tried some workarounds but didn't succeed.
At `def array_to_tuple_1d(tyctx, ary, nd):`, `nd` is a type (`Integer`) and not an actual value. In `numba.types.UniTuple(dtype=numba.types.int32, count=nd)`, `count` should be a value.
Thank you, I see. Can you please take a quick look at the related issue numba/numba#4265? I am looking for a way to explicitly pass the value (if that is possible).
@sklam, there is a function `to_fixed_tuple` in `numba.unsafe.ndarray`. It works in much the same way, and it does accept literals:
import numba
from numba.unsafe.ndarray import to_fixed_tuple
arr = np.array([1, 2, 3, 4, 5])
arr_length = len(arr)


@numba.njit
def foo(arr):
    """Convert `arr` to a 5-element tuple with `to_fixed_tuple`."""
    # The length is written as a constant inside the jitted function.
    tuple_length = 5
    return to_fixed_tuple(arr, tuple_length)


foo(arr)  # (1, 2, 3, 4, 5)
But I can't find a way to pass the length of `arr` into it.
This is much easier to use than "to_fixed_tuple" from numba.unsafe.ndarray.
I use this function to index into an array. The number of dimensions of the array depends on the input data and cannot be decided at compile time.
Without the help of this function, I can only access the data by creating a 1-D array of length X * Y * Z and computing the index as `a[i * Y * Z + j * Z + k]`.
I wonder why this isn't merged into Numba?