-
-
Save sklam/830fe01343ba95828c3b24c391855c86 to your computer and use it in GitHub Desktop.
numba intrinsic for creating fixed-size tuple from an array
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numba import types, njit, cgutils, typing | |
from numba.extending import intrinsic | |
@intrinsic
def index_tuple_for_array(tyctx, ary, idxary):
    """
    Implement an intrinsic that returns a fixed-size tuple given the target
    array and an index array.

    Equivalent to ``tuple(idxary)``, except that the tuple length is
    ``ary.ndim``. This works around the lack of dynamically sized tuples:
    the length must be known at compile time, so it is taken from the
    *type* of ``ary`` rather than from the runtime size of ``idxary``.

    Parameters (Numba types, seen at typing time)
    ---------------------------------------------
    ary : array type
        The array that will be indexed; only its ``ndim`` is used.
    idxary : 1-D integer array type
        Holds one index per dimension of ``ary``.

    Returns
    -------
    At runtime, a ``UniTuple(intp, ary.ndim)`` holding the elements of
    ``idxary`` converted to ``intp``.

    Raises
    ------
    IndexError
        At runtime, if ``idxary.size != ary.ndim``.
    """
    # Typing level: returning None tells Numba that this intrinsic does not
    # accept these argument types, giving a normal typing error instead of
    # a crash inside codegen.
    if not isinstance(ary, types.ArrayCompatible):
        return None
    if not (isinstance(idxary, types.Array) and idxary.ndim == 1
            and isinstance(idxary.dtype, types.Integer)):
        return None

    nd = ary.ndim  # size of the tuple, fixed at compile time
    tupty = types.UniTuple(dtype=types.intp, count=nd)  # the result type
    function_sig = tupty(ary, idxary)  # the signature of this intrinsic

    def codegen(cgctx, builder, signature, args):
        # Implementation level: build the tuple with the LLVM IR builder.
        lltupty = cgctx.get_value_type(tupty)
        tup = cgutils.get_null_value(lltupty)
        [_, idxaryval] = args

        def array_checker(a):
            if a.size != nd:
                raise IndexError("index array size mismatch")

        # Compile and call array_checker so a wrong-sized index array raises
        # instead of reading past its end.
        cgctx.compile_internal(builder, array_checker, types.none(idxary),
                               [idxaryval])

        def array_indexer(a, i):
            return a[i]

        # Fill the tuple one element at a time (nd is a Python int here, so
        # this loop is unrolled at compile time).
        for i in range(nd):
            dataidx = cgctx.get_constant(types.intp, i)
            data = cgctx.compile_internal(builder, array_indexer,
                                          idxary.dtype(idxary, types.intp),
                                          [idxaryval, dataidx])
            # The tuple elements are intp. Cast so that index arrays of other
            # integer dtypes (e.g. int32) produce a matching LLVM type.
            data = cgctx.cast(builder, data, idxary.dtype, types.intp)
            tup = builder.insert_value(tup, data, i)
        return tup

    return function_sig, codegen
def test():
    """Check ``index_tuple_for_array`` against NumPy's ``ary[tuple(idx)]``.

    Prints both results and raises AssertionError if they differ.
    """

    @njit
    def foo(ary, idx):
        return ary[index_tuple_for_array(ary, idx)]

    ary = np.arange(10, dtype=np.intp).reshape(5, 2)
    idx = np.asarray([2, 1], dtype=np.intp)
    r = foo(ary, idx)
    expected = ary[tuple(idx)]
    print(r)
    print(expected)
    # Printing alone never fails; assert so that a regression is caught.
    assert r == expected, (r, expected)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    test()
Thank you, I see. Could you please take a quick look at the related issue numba/numba#4265? I am looking for a way to pass the value explicitly (if that is possible).
@sklam, there is a function `to_fixed_tuple` in `numba.unsafe.ndarray`. It works in much the same way, and it does accept literals:
import numba
import numpy as np

try:
    # Location in newer Numba releases.
    from numba.np.unsafe.ndarray import to_fixed_tuple
except ImportError:
    # Older Numba releases.
    from numba.unsafe.ndarray import to_fixed_tuple

arr = np.array([1, 2, 3, 4, 5])
# NOTE(review): Numba treats module-level globals as compile-time constants,
# so arr_length could presumably be used inside foo in place of the
# hard-coded 5. Verify against the installed Numba version.
arr_length = len(arr)


@numba.njit
def foo(arr):
    # The tuple length must be a compile-time constant; here it is a literal.
    l = 5
    return to_fixed_tuple(arr, l)


foo(arr)  # (1, 2, 3, 4, 5)
But I can't find a way to pass `arr_length` into it.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
At `def array_to_tuple_1d(tyctx, ary, nd):`, `nd` is a type (`Integer`), not an actual value. In `numba.types.UniTuple(dtype=numba.types.int32, count=nd)`, `count` must be a value.