@luk-f-a
Last active April 23, 2020 14:15
Function subtyping in Numba

Some use-cases of function subtyping

Numba functions already display many of the behaviours expected from true first-class functions. There are, however, three behaviours that we see as useful and that are currently missing:

  1. iterating over a sequence of functions
  2. avoiding re-compilation of higher-order functions
  3. ability to cache higher-order functions

Iterating over a sequence of functions

This is an extremely common pattern:

from numba import njit

@njit
def foo(x):
    return x + 1

@njit
def foo2(x):
    return x + 2

@njit
def bar(fcs, val):
    x = 0
    for fc in fcs:
        x += fc(val)
    return x

bar((foo, foo2), 3.5)

At the moment the above does not work in Numba. literal_unroll could be used for simple cases, but not for something like

@njit
def bar(fcs1, fcs2, val):
    x = 0
    for fc1, fc2 in zip(fcs1, fcs2):
        x += fc1(val)*fc2(val)
    return x

bar((foo, foo2), (foo2, foo), 3.5)

The current implementation of first-class function types would allow the following:

from numba import cfunc, int64

@cfunc(int64(int64))
def foo(x):
    return x + 1

@njit
def foo2(x):
    return x + 2

bar((foo, foo2), 3.5)

However, this is not practical when the user of the functions is not their creator (as is the case with libraries or frameworks). One could compile for one type and then disable compilation.

@njit
def foo(x):
    return x + 1

@njit
def foo2(x):
    return x + 2

foo.compile(int64(int64))
foo.disable_compile()
bar((foo, foo2), 3.5)

However, this would have a strong side-effect on the original function, breaking future potential calls with other types.

foo(1.2) #fails

One could extract the underlying Python function (py_func) and compile it separately

@njit
def foo(x):
    return x + 1

@njit
def foo2(x):
    return x + 2

foo_int = cfunc(int64(int64))(foo.py_func)

bar((foo_int, foo2), 3.5)

but then one must keep track of every version of foo that one has used, or keep recompiling the same function over and over (creating a problem in use case 2).
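The bookkeeping this workaround needs can be sketched in plain Python. This is only an illustration under stated assumptions: `SpecialisationCache` and `fake_compile` are hypothetical names, and `fake_compile` stands in for a real compile step such as `cfunc(sig)(py_func)`.

```python
class SpecialisationCache:
    """Remember one compiled variant per (function, signature) pair."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn  # hypothetical compile step
        self._variants = {}
        self.compile_count = 0

    def get(self, py_func, signature):
        key = (py_func, signature)
        if key not in self._variants:
            # Compile only the first time this pair is requested.
            self._variants[key] = self._compile_fn(py_func, signature)
            self.compile_count += 1
        return self._variants[key]


def fake_compile(py_func, signature):
    # Stand-in for a real compiler call; it only tags the function.
    return (signature, py_func)


def foo(x):
    return x + 1


cache = SpecialisationCache(fake_compile)
first = cache.get(foo, "int64(int64)")
second = cache.get(foo, "int64(int64)")      # reused, no recompilation
third = cache.get(foo, "float64(float64)")   # new signature, compiled once
```

Every caller has to share such a cache to avoid recompiling, which is the burden the subtyping approach aims to remove.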

In summary, while some use cases can be served with the current features, more general cases cannot be accommodated smoothly.

The ideal case

Ultimately, the following should happen

@njit
def bar(fcs, val):
    x = 0
    for fc in fcs:
        x += fc(val)
    return x

bar((foo, foo2), 3.5)
bar.disable_compile()
bar((foo2, foo), 3.5)

without the need for any annotation.

However, this might not be easy to achieve, since it could require changes to the type inference stage (my guess is that it would need to allow input types to be type variables, and not type constraints with a known value).

If a user annotation is required, then the following is much better for some (many?) use cases:

@njit
def foo(x):
    return x + 1

@njit
def foo2(x):
    return x + 2

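from numba.types import FunctionType, UniTuple, int64

# UniTuple needs an element count, and the signature must also cover `val`
@njit(int64(UniTuple(FunctionType(int64(int64)), 2), int64))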
def bar(fcs, val):
    x = 0
    for fc in fcs:
        x += fc(val)
    return x

bar((foo, foo2), 3.5)
bar((foo2, foo), 3.5) #works
bar((foo2, foo2), 3.5) #works
bar((foo2, foo3), 3.5) #works

The basic logic is that a Dispatcher is an intersection type. For example, type(foo) = Dispatcher(foo) = (int64 -> int64) & (float64 -> float64) & ...

PR #5579 implements the following subtyping rule: if foo: T1 & T2 then foo<:T1. In simple terms, if foo can be compiled for int and float then foo should be accepted as a first-class function type by a function that requires int->int. This allows a seamless transition from Dispatcher type to FunctionType, which enables first-class behaviour.

As a consequence, since Numba already implements the subtyping rule for tuples, (S1, S2) <: (T1, T2) if S1 <: T1 and S2 <: T2, a tuple of dispatchers is automatically a subtype of a tuple of FunctionType with any signature supported by all dispatchers in the tuple.
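The two rules can be illustrated with a small pure-Python model. It is a sketch rather than Numba's implementation: `ToyFunctionType`, `ToyDispatcher` and `is_subtype` are hypothetical names, and a dispatcher is modelled as a fixed set of supported signatures, whereas a real dispatcher compiles on demand.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyFunctionType:
    """A single signature such as int64 -> int64 (hypothetical stand-in)."""
    args: tuple
    ret: str


@dataclass(frozen=True)
class ToyDispatcher:
    """Intersection type: the set of signatures the function supports."""
    name: str
    signatures: frozenset


def is_subtype(sub, sup):
    """Return True if `sub` may be used where `sup` is required."""
    # Rule 1: foo : T1 & T2  implies  foo <: T1
    if isinstance(sub, ToyDispatcher) and isinstance(sup, ToyFunctionType):
        return sup in sub.signatures
    # Rule 2: (S1, S2) <: (T1, T2) if S1 <: T1 and S2 <: T2
    if isinstance(sub, tuple) and isinstance(sup, tuple):
        return len(sub) == len(sup) and all(
            is_subtype(s, t) for s, t in zip(sub, sup)
        )
    # Otherwise only identical types are related in this toy model.
    return sub == sup


int_to_int = ToyFunctionType(("int64",), "int64")
float_to_float = ToyFunctionType(("float64",), "float64")

foo = ToyDispatcher("foo", frozenset({int_to_int, float_to_float}))
foo2 = ToyDispatcher("foo2", frozenset({int_to_int}))

# Both dispatchers support int64 -> int64, so the tuple is accepted.
accepts_int_pair = is_subtype((foo, foo2), (int_to_int, int_to_int))
# foo2 does not support float64 -> float64 here, so the tuple is rejected.
accepts_float_pair = is_subtype((foo, foo2), (float_to_float, float_to_float))
```

In this model a tuple of dispatchers matches a tuple of function types only for signatures that every element supports.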

Implementation

Thanks to Numba having a solid cast machinery in place, the implementation of this feature only requires 4 lines of code:

def can_convert_to(self, typingctx, other):
    if isinstance(other, types.FunctionType):
        if self.dispatcher.get_compile_result(other.signature):
            return Conversion.safe

Currently out of scope

Function types also follow the subtyping rule T1->S2 <: S1->T2 if S1 <: T1 and S2 <: T2, i.e. functions are contravariant in their inputs and covariant in their outputs.

The current PR does not implement that, but a future PR could do it.
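As a rough illustration of that rule, the following pure-Python sketch checks function subtyping over a toy scalar hierarchy. The relation `int32 <: int64` and the names `WIDENS_TO`, `scalar_subtype` and `function_subtype` are assumptions made for this example only; they do not describe how Numba relates its scalar types.

```python
# Toy scalar subtype relation (assumed for illustration only):
# each key maps to the set of types it is a subtype of.
WIDENS_TO = {
    "int32": {"int32", "int64"},
    "int64": {"int64"},
}


def scalar_subtype(s, t):
    """Return True if scalar type s is a subtype of scalar type t."""
    return t in WIDENS_TO.get(s, {s})


def function_subtype(sub, sup):
    """Check sub <: sup, where each is an (arg_types, return_type) pair."""
    sub_args, sub_ret = sub
    sup_args, sup_ret = sup
    if len(sub_args) != len(sup_args):
        return False
    # Contravariant inputs: each input of `sup` must be a subtype of
    # the corresponding input of `sub`.
    inputs_ok = all(
        scalar_subtype(a_sup, a_sub) for a_sub, a_sup in zip(sub_args, sup_args)
    )
    # Covariant output: the result of `sub` must be a subtype of the
    # result of `sup`.
    output_ok = scalar_subtype(sub_ret, sup_ret)
    return inputs_ok and output_ok


wide_in_narrow_out = (("int64",), "int32")   # T1 -> S2
narrow_in_wide_out = (("int32",), "int64")   # S1 -> T2

forward = function_subtype(wide_in_narrow_out, narrow_in_wide_out)
backward = function_subtype(narrow_in_wide_out, wide_in_narrow_out)
```

Here `forward` holds because the first function accepts a wider input and returns a narrower output, while `backward` does not.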
