@sklam
Created March 6, 2017 18:15
numba intrinsic for creating fixed-size tuple from an array
import numpy as np
from numba import types, njit
from numba.extending import intrinsic

try:
    from numba.core import cgutils  # numba >= 0.49
except ImportError:
    from numba import cgutils  # older releases


@intrinsic
def index_tuple_for_array(tyctx, ary, idxary):
    """
    Implement an intrinsic to return a fixed-size tuple given the target array
    and the index array.

    Equivalent to `tuple(idxary)`.

    Works around the lack of dynamically sized tuples.
    """
    # This is the typing level.
    # We set up the type and constant information here.
    nd = ary.ndim     # size of the tuple, known from the array *type*
    # The tuple type; this assumes idxary holds intp values.
    tupty = types.UniTuple(dtype=types.intp, count=nd)
    function_sig = tupty(ary, idxary)  # the function signature for this intrinsic

    def codegen(cgctx, builder, signature, args):
        # This is the implementation, defined using the LLVM builder.
        lltupty = cgctx.get_value_type(tupty)
        tup = cgutils.get_null_value(lltupty)

        [_, idxaryval] = args

        def array_checker(a):
            if a.size != nd:
                raise IndexError("index array size mismatch")

        # compile and call array_checker
        cgctx.compile_internal(builder, array_checker, types.none(idxary),
                               [idxaryval])

        def array_indexer(a, i):
            return a[i]

        # loop to fill the tuple (unrolled at compile time, since nd is constant)
        for i in range(nd):
            dataidx = cgctx.get_constant(types.intp, i)
            # compile and call array_indexer
            data = cgctx.compile_internal(builder, array_indexer,
                                          idxary.dtype(idxary, types.intp),
                                          [idxaryval, dataidx])
            tup = builder.insert_value(tup, data, i)
        return tup

    return function_sig, codegen


def test():
    @njit
    def foo(ary, idx):
        return ary[index_tuple_for_array(ary, idx)]

    ary = np.arange(10, dtype=np.intp).reshape(5, 2)
    idx = np.asarray([2, 1], dtype=np.intp)

    r = foo(ary, idx)
    print(r)
    print(ary[tuple(idx)])  # reference result from plain NumPy


if __name__ == '__main__':
    test()
@sklam

sklam commented Mar 6, 2017

[SELF NOTE]

To solve the general case with maximum usability:

def foo(ary, idxary):
    idx = tuple(idxary)
    return ary[idx]

a redesign of the type system that makes use of type variables is required. Type inference will need to make two passes: a bottom-up pass and a top-down pass.

Let's say getitem has the signature Array<Tdtype, Tndim>, Tuple<Tndim, intp> -> Tdtype, where Tdtype and Tndim are type variables for the dtype of the array and for its dimensionality, which is also the size of the tuple. The bottom-up pass will infer that:

ary[idx] :: Array<Tdtype, Tndim>, Tuple<Tndim, intp> -> Tdtype   # by signature of getitem
idx :: Tuple<Tndim, intp>      # by propagation 
foo :: Array<Tdtype, Tndim>, Array<intp, 1> -> Tdtype   # by rule tuple :: Array<intp, 1> -> Tuple<Tndim, intp>

This gives a generic shape for the function foo and leaves two type variables: Tdtype and Tndim. The top-down pass can then fill them in, since the arguments must be fully materialized (no type variables). Given arguments ary :: Array<float32, 2> and idxary :: Array<intp, 1>, it will assign Tdtype=float32 and Tndim=2, thus eliminating all type variables in the function. With that, we have determined the tuple size of idx without help from the user.

(This ignores overloading, which should still be doable.)
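
The two passes described above can be sketched in plain Python. This is only an illustration of the idea, not Numba's type inference: ArrayT, TupleT, TypeVar, bind, top_down and substitute are hypothetical names, and the "bottom-up" result for foo is written out by hand rather than inferred.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TypeVar:
    name: str


@dataclass(frozen=True)
class ArrayT:
    dtype: object  # a concrete dtype name (str) or a TypeVar
    ndim: object   # a concrete int or a TypeVar


@dataclass(frozen=True)
class TupleT:
    count: object  # a concrete int or a TypeVar
    dtype: object


Tdtype, Tndim = TypeVar("Tdtype"), TypeVar("Tndim")

# Hand-written result of the bottom-up pass: the generic shape of foo.
FOO_ARGS = (ArrayT(Tdtype, Tndim), ArrayT("intp", 1))
FOO_RETURN = Tdtype
IDX_TYPE = TupleT(Tndim, "intp")  # type of the local variable idx


def bind(pattern, concrete, env):
    """Match one generic field against a concrete one, recording bindings."""
    if isinstance(pattern, TypeVar):
        if env.setdefault(pattern.name, concrete) != concrete:
            raise TypeError(f"conflicting binding for {pattern.name}")
    elif pattern != concrete:
        raise TypeError(f"expected {pattern!r}, got {concrete!r}")


def top_down(arg_types):
    """Top-down pass: fill the type variables from fully concrete arguments."""
    env = {}
    for generic, actual in zip(FOO_ARGS, arg_types):
        bind(generic.dtype, actual.dtype, env)
        bind(generic.ndim, actual.ndim, env)
    return env


def substitute(ty, env):
    """Replace type variables in ty with their bindings."""
    if isinstance(ty, TypeVar):
        return env[ty.name]
    if isinstance(ty, TupleT):
        return TupleT(substitute(ty.count, env), substitute(ty.dtype, env))
    return ty


env = top_down((ArrayT("float32", 2), ArrayT("intp", 1)))
print(env)                        # bindings for Tdtype and Tndim
print(substitute(IDX_TYPE, env))  # the now fixed-size tuple type of idx
```

Once Tndim is bound, the tuple size of idx is a plain integer, which is what the intrinsic in the gist obtains directly from ary.ndim.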

@DrTodd13

Why does getitem have to have this signature? Can't the index tuple have lower dimensionality than the array? I tried it in nopython mode and a lower dimensional tuple seems to be accepted.
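
For reference, the behaviour in question can be seen in plain NumPy (this small check does not use Numba): an index tuple shorter than the number of dimensions selects a sub-array rather than a scalar, so a fully general getitem rule would presumably need a return type that depends on the tuple size as well.

```python
import numpy as np

ary = np.arange(10, dtype=np.intp).reshape(5, 2)

full = ary[(2, 1)]    # as many indices as dimensions: a scalar element
partial = ary[(2,)]   # fewer indices than dimensions: a 1-D row

print(full)
print(partial, partial.shape)
```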

@yydcool

yydcool commented Jan 31, 2019

This is much easier to use than "to_fixed_tuple" from numba.unsafe.ndarray.
I use this function to access an array whose dimensionality depends on the input data and cannot be decided at compile time.
Without the help of this function, I can only access it by creating a 1-D array of length X * Y * Z and computing the index as a[i * Y * Z + j * Z + k].
Why not merge this into numba?
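
The flattened-index workaround mentioned above generalizes to any number of dimensions. A minimal pure-Python sketch, assuming row-major (C-order) layout; flat_index is a hypothetical helper name, not part of Numba or NumPy:

```python
def flat_index(index, shape):
    """Row-major flat offset of a multi-dimensional index.

    For shape (X, Y, Z) and index (i, j, k) this evaluates to
    i * Y * Z + j * Z + k.
    """
    if len(index) != len(shape):
        raise IndexError("index size mismatch")
    flat = 0
    for i, n in zip(index, shape):
        if not 0 <= i < n:
            raise IndexError("index out of bounds")
        flat = flat * n + i  # Horner-style accumulation over the axes
    return flat


print(flat_index((1, 2, 3), (4, 5, 6)))  # 1*5*6 + 2*6 + 3
```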

@funkindy

funkindy commented Jul 3, 2019

Hi! This is great, but I can't find a way to pass nd as an argument to index_tuple_for_array from the outer njit function. I am trying to simplify it a bit, like this:

@numba.extending.intrinsic
def array_to_tuple_1d(tyctx, ary, nd):
    """
    Similar to tuple(ary)
    """

    # This is the typing level.
    # We'll setup the type and constant information here
    tupty = numba.types.UniTuple(dtype=numba.types.int32, count=nd)    # the tuple
    funtion_sig = tupty(ary)                      # the function signature for this intrinsic

    def codegen(cgctx, builder, signature, args):
        # This is the implementation defined using LLVM builder
        lltupty = cgctx.get_value_type(tupty)
        tup = numba.cgutils.get_null_value(lltupty)

        [aryval] = args

        def array_indexer(a, i):
            return a[i]

        # loop to fill the tuple
        for i in range(nd):
            dataidx = cgctx.get_constant(numba.types.intp, i)
            # compile and call array_indexer
            data = cgctx.compile_internal(builder, array_indexer,
                                          ary.dtype(ary, numba.types.intp),
                                          [aryval, dataidx])
            tup = builder.insert_value(tup, data, i)
        return tup

    return funtion_sig, codegen

@numba.njit
def arr2tup(arr):
    return array_to_tuple_1d(arr, arr.size)

arr2tup(np.array([1, 2, 3]))

and I get TypeError: %d format: a number is required, not Integer. I tried some workarounds but didn't succeed.

@sklam

sklam commented Jul 3, 2019

In def array_to_tuple_1d(tyctx, ary, nd):, nd is a type (Integer), not an actual value. In numba.types.UniTuple(dtype=numba.types.int32, count=nd), count must be an actual value.
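
The distinction can be illustrated without Numba. In the sketch below, Integer and IntegerLiteral are stand-in classes (Numba has types of the same names, and its IntegerLiteral exposes literal_value, but these are not the real classes); unituple_name and tuple_count_from are hypothetical helpers. The assumption that the tuple type formats its count with %d is inferred from the error message above, not checked against the Numba source.

```python
class Integer:
    """Stand-in for a type object: it describes 'some integer', no value."""

    def __init__(self, name):
        self.name = name


class IntegerLiteral(Integer):
    """Stand-in for a literal type: a type that also carries its value."""

    def __init__(self, value):
        super().__init__(f"Literal[int]({value})")
        self.literal_value = value


def unituple_name(dtype, count):
    # Assumed to resemble how the tuple type builds its name; %d needs a number.
    return "UniTuple(%s x %d)" % (dtype, count)


def tuple_count_from(arg_type):
    """Extract a compile-time count, which only a literal type can provide."""
    if isinstance(arg_type, IntegerLiteral):
        return int(arg_type.literal_value)
    raise TypeError("length must be a compile-time constant")


try:
    unituple_name("int32", Integer("int64"))  # a type where a value is needed
except TypeError as exc:
    print("TypeError:", exc)

print(unituple_name("int32", tuple_count_from(IntegerLiteral(5))))
```

At the typing level an intrinsic only sees objects like these, so a count is available only when the argument's type itself carries the value.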

@funkindy

funkindy commented Jul 4, 2019

Thank you, I see. Could you please take a quick look at the related issue numba/numba#4265? I am looking for a way to pass the value explicitly (if that is possible).

@funkindy

funkindy commented Jul 5, 2019

@sklam, there is a function to_fixed_tuple in numba.unsafe.ndarray. It works in much the same way, and it does accept literals:

import numba
import numpy as np
from numba.unsafe.ndarray import to_fixed_tuple

arr = np.array([1, 2, 3, 4, 5])
arr_length = len(arr)

@numba.njit
def foo(arr):
    l = 5
    return to_fixed_tuple(arr, l)

foo(arr)  # (1, 2, 3, 4, 5)

But I can't find a way to pass the length of arr into it.