Skip to content

Instantly share code, notes, and snippets.

@sklam
Created March 6, 2017 18:15
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sklam/830fe01343ba95828c3b24c391855c86 to your computer and use it in GitHub Desktop.
Save sklam/830fe01343ba95828c3b24c391855c86 to your computer and use it in GitHub Desktop.
Numba intrinsic for creating a fixed-size tuple from an array
import numpy as np
from numba import types, njit, cgutils, typing
from numba.extending import intrinsic
@intrinsic
def index_tuple_for_array(tyctx, ary, idxary):
    """
    Implement an intrinsic that returns a fixed-size tuple given the target
    array and the index array.

    Equivalent to ``tuple(idxary)``, except the tuple length is fixed at
    compile time to ``ary.ndim``. This works around the lack of
    dynamically sized tuples in Numba.

    Parameters (these are Numba *types*, seen at typing time)
    ----------
    tyctx :
        Numba typing context (unused).
    ary :
        Type of the array to be indexed; its ``ndim`` sets the tuple length.
    idxary :
        Type of the index array; must be an integer-dtype array.

    Returns
    -------
    ``(signature, codegen)`` for supported argument types, or ``None``
    otherwise (Numba then reports a typing error instead of failing later
    during lowering).

    At runtime the generated code raises ``IndexError`` if
    ``idxary.size != ary.ndim``.
    """
    # This is the typing level.
    # We set up the type and constant information here.
    nd = getattr(ary, "ndim", None)  # size of the tuple
    if nd is None:
        return None
    if not (isinstance(idxary, types.Array)
            and isinstance(idxary.dtype, types.Integer)):
        # Only integer index arrays can be turned into an intp tuple.
        return None

    tupty = types.UniTuple(dtype=types.intp, count=nd)  # the tuple
    function_sig = tupty(ary, idxary)  # the signature for this intrinsic

    def codegen(cgctx, builder, signature, args):
        # This is the implementation, defined using the LLVM IR builder.
        lltupty = cgctx.get_value_type(tupty)
        tup = cgutils.get_null_value(lltupty)
        [_, idxaryval] = args

        def array_checker(a):
            if a.size != nd:
                raise IndexError("index array size mismatch")

        # Compile and call array_checker.
        cgctx.compile_internal(builder, array_checker, types.none(idxary),
                               [idxaryval])

        def array_indexer(a, i):
            return a[i]

        # Loop to fill the tuple.
        for i in range(nd):
            dataidx = cgctx.get_constant(types.intp, i)
            # Compile and call array_indexer.
            data = cgctx.compile_internal(builder, array_indexer,
                                          idxary.dtype(idxary, types.intp),
                                          [idxaryval, dataidx])
            # The tuple slots are intp. Convert other integer dtypes
            # (e.g. int32, uint8) so the LLVM types match in insert_value.
            data = cgctx.cast(builder, data, idxary.dtype, types.intp)
            tup = builder.insert_value(tup, data, i)
        return tup

    return function_sig, codegen
def test():
    """Demo and self-check for ``index_tuple_for_array``.

    Indexes a 5x2 array with an index array through the intrinsic inside a
    jitted function. Prints the jitted result and the pure-NumPy reference
    ``ary[tuple(idx)]``, then asserts that the two agree.
    """
    @njit
    def foo(ary, idx):
        return ary[index_tuple_for_array(ary, idx)]

    ary = np.arange(10, dtype=np.intp).reshape(5, 2)
    idx = np.asarray([2, 1], dtype=np.intp)
    r = foo(ary, idx)
    expected = ary[tuple(idx)]
    print(r)
    print(expected)
    # The printouts alone never flag a regression, so fail loudly on mismatch.
    assert r == expected, (r, expected)
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script.
    test()
@funkindy
Copy link

funkindy commented Jul 5, 2019

@sklam, there is a function to_fixed_tuple in numba.unsafe.ndarray. It works in much the same way, and it does accept literals:

import numba
import numpy as np

try:
    # Location of the helper in newer Numba releases.
    from numba.np.unsafe.ndarray import to_fixed_tuple
except ImportError:
    # Location used by older Numba releases.
    from numba.unsafe.ndarray import to_fixed_tuple

arr = np.array([1, 2, 3, 4, 5])
arr_length = len(arr)


@numba.njit
def foo(arr):
    # The tuple length must be known at compile time, so it is written as a
    # constant here rather than computed from ``arr`` inside the function.
    l = 5
    return to_fixed_tuple(arr, l)


foo(arr)  # (1, 2, 3, 4, 5)

But I can't find a way to pass the array length (arr_length) into it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment