{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Numba proposal: (local) subtyping"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Motivation:** Numba compiles overloads based on the specific type of the inputs. Each `CompileResult` is associated (via the `overloads` dictionary) to a set of input argument types. There are times, however, when two or more types could share one `CompileResult`. At the moment, there is no way to prevent new compilations for each of those types."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Status**: the current proposal is a DRAFT. Different alternatives are discussed as a means to gather feedback. Once there is an agreement, I can do the work on the PR."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Summary and scope**: we propose a general method to allow the user to indicate sets of types that can safely share a `CompileResult`. To avoid large changes to Numba's type system and type inference, the proposed method will probably have to be a) local to each `CPUDispatcher` instance, b) triggered (opt-in) by the user who must ensure that the types can effectively work under the same `CompileResult`. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Out of scope / non goals**: it's not the intention to propose a general mechanism for safe subtyping across all of numba's type system. In particular, subtyping is not considered in the type inference phase. Subtyping is only considered at the boundaries, eg in the `overloads` of `Dispatcher`.\n",
"\n",
"CAVEAT: **The code provided in the many examples of this proposal would not actually run in some cases. Most of the examples are written for exposition, not for perfection. If this proposal is accepted, there will be time for a correct PR with better documentation.**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Motivation\n",
"Consider the following example:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0\n",
"2.0\n",
"There are 2 signatures\n"
]
}
],
"source": [
"from numba import njit, typeof\n",
"import numpy as np\n",
"\n",
"rec1 = np.array([1], dtype=[('a', 'f8')])[0]\n",
"rec2 = np.array([(2,3)], dtype=[('a', 'f8'), ('b', 'f8')])[0]\n",
"\n",
"@njit\n",
"def foo(rec):\n",
" return rec['a']\n",
"\n",
"print(foo(rec1))\n",
"print(foo(rec2))\n",
"print(f\"There are {len(foo.signatures)} signatures\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The second compilation is unnecesary, since `a` is present in the first position in both types. However, given the current way `CPUDispatcher.compile()` works, this fails:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0\n"
]
},
{
"ename": "TypeError",
"evalue": "No matching definition for argument type(s) Record(b[type=float64;offset=0],c[type=float64;offset=8];16;False)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-6-78aaab79af76>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfoo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrec1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mfoo\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisable_compile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfoo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrec2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/anaconda3/envs/latest_versions38/lib/python3.8/site-packages/numba/dispatcher.py\u001b[0m in \u001b[0;36m_explain_matching_error\u001b[0;34m(self, *args, **kws)\u001b[0m\n\u001b[1;32m 572\u001b[0m msg = (\"No matching definition for argument type(s) %s\"\n\u001b[1;32m 573\u001b[0m % ', '.join(map(str, args)))\n\u001b[0;32m--> 574\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 575\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 576\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_search_new_conversions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkws\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: No matching definition for argument type(s) Record(b[type=float64;offset=0],c[type=float64;offset=8];16;False)"
]
}
],
"source": [
"rec1 = np.array([1], dtype=[('a', 'f8')])[0]\n",
"rec2 = np.array([(2,3)], dtype=[('b', 'f8'), ('c', 'f8')])[0]\n",
"\n",
"@njit\n",
"def foo(rec):\n",
" return rec['a']\n",
"print(foo(rec1))\n",
"foo.disable_compile()\n",
"print(foo(rec2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a hack, one could do the following:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0\n",
"2.0\n"
]
}
],
"source": [
"rec1 = np.array([1], dtype=[('a', 'f8')])[0]\n",
"rec2 = np.array([(2,3)], dtype=[('b', 'f8'), ('c', 'f8')])[0]\n",
"\n",
"@njit\n",
"def foo(rec):\n",
" return rec['a']\n",
"\n",
"print(foo(rec1))\n",
"\n",
"cres = foo.overloads[(typeof(rec1),)]\n",
"args = (typeof(rec2),)\n",
"sig = [a._code for a in args]\n",
"foo._insert(sig, cres.entry_point, cres.objectmode, cres.interpmode)\n",
"foo.overloads[args] = cres\n",
"foo.disable_compile()\n",
"\n",
"print(foo(rec2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, this requires inserting an overload for each possible compatible record, and does not allow the general subtyping of record types.\n",
"\n",
"# 2. Possible solution\n",
"A possible better way would be to allow users to provide a object that defines a set of types that can re-use the `CompileResult` for a given type. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Currently, the `compile` method of `Dispatcher` executes the search in the `overloads` dictionary by checking the input arguments against previous compilations."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"existing = self.overloads.get(tuple(args))\n",
"if existing is not None:\n",
" return existing.entry_point\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Under this proposal, `compile` would add something like this:\n",
"```python\n",
"# first, find if the input arguments match any supertype\n",
"matched_sty = None\n",
"for super_type in self.super_types:\n",
" for arg, arg_supertype in zip(args, super_type):\n",
" if arg not in arg_supertype:\n",
" break\n",
" else:\n",
" matched_sty = super_type\n",
" break\n",
"\n",
"# second, find if the matched supertype matches an existing CompileResult\n",
"if matched_sty:\n",
" existing = None\n",
" for ovrl_sig, ovrl_compile_res in self.overloads:\n",
" for arg, arg_supertype in zip(ovrl_sig, matched_sty):\n",
" if arg not in arg_supertype:\n",
" break\n",
" else:\n",
" existing = ovrl_compile_res\n",
" break \n",
" \n",
" if existing is not None:\n",
" return existing.entry_point\n",
"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A user that wants to use this mechanism to solve the example above would provide a set of types that can share a `CompileResult`:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"foo.define_supertype(({typeof(rec1), typeof(rec2)}, ))\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the option above, only the subtypes are declared. One of them might or not be the supertype, but it is not necessary to identify the supertype. We might say that this is nominal subtyping over an anonymous supertype."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The option above requires all subtypes to be explicitly declared. To solve it in a more general way, the user would do the following:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"class RecordWithA:\n",
" \"\"\"\n",
" The set of all records whose first field is 'a' and is an integer\n",
" \"\"\"\n",
" def __init__(self):\n",
" pass\n",
" \n",
" def __contains__(self, x):\n",
" if isinstance(types.Record) and 'a' in x.fields:\n",
" if x.fields['a'].type == int64 and x.fields['a'].offset==0:\n",
" return True\n",
" return False\n",
"\n",
"rec_a_foo = RecordWithA()\n",
"\n",
"foo.define_supertype((rec_a_foo, ))\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a sort of structural subtyping, but requiring a user declaration, via a Python class. This class either explictly contains/receives the supertype or at least implies what the supertype is. In the case above, it's ``Record(a[type=int64;offset=0];8;False)``. As in the option before, this is **local** subtyping, because it is limited to the `Dispatcher` that the user chooses."
]
},
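{
"cell_type": "markdown",
"metadata": {},
"source": [
"For illustration, a hypothetical usage of the structural supertype above (this relies on the proposed `define_supertype` API, so it will not run against current Numba):\n",
"\n",
"```python\n",
"# two records that both start with an int64 field 'a' at offset 0\n",
"recA = np.array([(1,)], dtype=[('a', 'i8')])[0]\n",
"recB = np.array([(2, 3)], dtype=[('a', 'i8'), ('z', 'i8')])[0]\n",
"\n",
"foo.define_supertype((rec_a_foo, ))  # proposed API, not yet in Numba\n",
"\n",
"foo(recA)                   # compiles once for the shared supertype\n",
"foo(recB)                   # would reuse the existing CompileResult\n",
"print(len(foo.signatures))  # expected: 1\n",
"```"
]
},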
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A more structured way of doing this is discussed at the end, together with other alternatives"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Application to other types: \n",
"\n",
"NamedTuples and ordinary tuples can benefit from subtyping in a similar way as records above. Arrays probably won't benefit from it, since slices are a cheap way to obtain the same behaviour."
]
},
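{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of what this could look like for tuples (everything here is hypothetical: the class, the attribute access on the tuple type, and the `define_supertype` call simply mirror the record example above):\n",
"\n",
"```python\n",
"class TupleStartingWithFloat64:\n",
"    \"\"\"\n",
"    The set of all tuples whose first element is a float64\n",
"    \"\"\"\n",
"    def __contains__(self, x):\n",
"        if isinstance(x, types.BaseTuple) and len(x.types) > 0:\n",
"            return x.types[0] == types.float64\n",
"        return False\n",
"\n",
"# baz is a hypothetical njit function that only reads the first element of its tuple argument\n",
"baz.define_supertype((TupleStartingWithFloat64(), ))  # proposed API, not yet in Numba\n",
"```"
]
},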
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Application to first-class function type"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One of the benefits of having first-class functions is being able to avoid the recompilation of higher-order functions. This not yet possible for Dispatcher objects in the current implementation of first-class function type. With a `cfunc` this is easy to set up, because they have a precise function type `input -> output`. Dispatcher objects types, however, hold many signatures.\n",
"\n",
"Families of dispatcher types could be created by the user to signal the compiler that the given dispatchers can be compiled under a common signature, and can therefore share a single `CompileResult` of a higher-order function. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"function_set1 = {njit_function1._numba_type_, njit_function2._numba_type_}```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"class FunctionSuperType:\n",
" def __init__(self, signature):\n",
" self.signature = signature\n",
" \n",
" def __contains__(self, x):\n",
" if isinstance(x, types.Dispatcher):\n",
" if self.signature in x.signatures:\n",
" return True\n",
" try:\n",
" x.compile(self.signature)\n",
" except:\n",
" return False\n",
" return True\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This does not actually work at the moment, because `CompileResult` cannot be shared across Dispatchers, even if they support the same signature.\n",
"Making it work requires the addition of a cast mechanism between Dispatcher instances and a valid `CompileResult` that conforms to a given signature."
]
},
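{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of what such a cast could look like, extending `FunctionSuperType` above (hypothetical: the `overloads` lookup is indicative only, `CompileResultWAP` is the wrapper used later in this proposal, and the Dispatcher type is again conflated with the Dispatcher object for readability):\n",
"\n",
"```python\n",
"class FunctionSuperType:\n",
"    # __init__ and __contains__ as above\n",
"\n",
"    def cast(self, x):\n",
"        # x is (the type of) a dispatcher known, via __contains__, to support self.signature\n",
"        cres = x.overloads[self.signature]\n",
"        return CompileResultWAP(cres)  # wrap it so it can be passed as a first-class function value\n",
"```"
]
},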
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In fact, we would like to propose a change to the current behaviour of the experimental first-class function types. This proposal extends the main subtyping proposal, but can be evaluated separately. Accepting this secondary proposal on function types, however, does require accepting the main subtyping proposal in some shape or form."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Proposal for function types"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Currently, numba functions display a lot of behaviours expected from true first-class functions. There are, however, 3 behaviours which we see as useful and are currently missing:\n",
"- iterating over a sequence of functions\n",
"- avoiding re-compilation of higher-order functions\n",
"- ability to cache hither-order functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The current experimental implementation of first-class function type allows iterating over a tuple of functions. It requires, however, after PR #5529, that at least one of the elements in the tuple has a precise signature. We see this as an understandable but inconvenient restriction. Lifting this restriction will probably require changes in the way Dispatcher types work, and in the type inference stage, which we guess is unlikely to happen in the short-to-medium term. Here, we propose to move the restriction to a less inconvenient place.\n",
"\n",
"When types are not known in advance, it's hard to provide signatures in advance. If types are dynamically created, then it's impossible, and the signature has to be provided dynamically. Consider the following case\n",
"\n",
"```python\n",
"@njit\n",
"def foo(x):\n",
" return x + 1\n",
"\n",
"@njit\n",
"def bar(fc):\n",
" return fc(1)\n",
"\n",
"bar(foo)\n",
"```\n",
"\n",
"To attach a signature to `foo` would be very unnatural to the flow of this code. We don't know at the point of creation of `foo` what we want to do with it (in terms of types). In fact, `foo` and `bar` might be created by different people, if eg `bar` is part of a framework.\n",
"Since the creator of `bar` holds more information, it makes more sense that the signature is provided together with `bar` and not with `foo`. Consider the following (fictional) syntax:\n",
"\n",
"```python\n",
"@njit\n",
"def bar(fc: int64->int64):\n",
" return fc(1)\n",
"\n",
"```\n",
"\n",
"Since `foo <: (int64->int64)` because `foo` can be compiled under that signature, then it should be possible, at the time of executing `bar(foo)` to perform the following:\n",
"- compile `foo` as `int64->int64`\n",
"- take the `CompileResult` and wrap it in a `CompileResultWAP`, and pass that to `bar`\n",
"- keep the `CompileResult` in overloads associated with the key `FunctionType(int64, int64)` rather than, as currently, with the key `Dispatcher(foo)`. This will avoid future recompilations of `bar` for similar functions.\n",
"\n",
"All this can be done manually by the user, but it makes more sense to provide a mechanism to automate that behaviour, via some form of inversion of control."
]
},
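{
"cell_type": "markdown",
"metadata": {},
"source": [
"Until such syntax exists, the same intent could be expressed with the mechanism of the main proposal. A hypothetical sketch (`define_supertype` is the proposed API and `FunctionSuperType` the illustrative class from above; `int64(int64)` is the usual Numba signature notation for `int64 -> int64`):\n",
"\n",
"```python\n",
"from numba import int64\n",
"\n",
"fc_supertype = FunctionSuperType(int64(int64))  # the class sketched above\n",
"bar.define_supertype((fc_supertype, ))          # proposed API, not yet in Numba\n",
"\n",
"bar(foo)  # would compile foo as int64 -> int64, wrap its CompileResult in a\n",
"          # CompileResultWAP, and key bar's overload on FunctionType(int64, int64)\n",
"```"
]
},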
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Consider this more complex case\n",
"\n",
"```python\n",
"@njit\n",
"def foo(x):\n",
" return x + 1\n",
"\n",
"@njit\n",
"def foo2(x):\n",
" return x + 2\n",
"\n",
"@njit\n",
"def bar(fcs, val):\n",
" x = 0\n",
" for fc in fcs:\n",
" x += fc(val)\n",
" return x\n",
"\n",
"bar((foo, foo2), 3.5)\n",
"```\n",
"\n",
"and again the fictional syntax (where T is a type variable)\n",
"\n",
"```python\n",
"@njit\n",
"def bar(fcs: Unituple(T->Any), val: T):\n",
" x = 0\n",
" for fc in fcs:\n",
" x += fc(val)\n",
" return x\n",
"```\n",
"\n",
"Again it is possible for the creator of `bar` to provide this type information, but not for the creator of `foo`.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This last example is a good example of use cases 1 and 2 (iterating over a sequence of functions and avoiding re-compilation of higher-order functions). They can be supported under the subtyping proposal by making a small change: adding a cast function.\n",
"\n",
"```python\n",
"class TupleFunctionSuperType:\n",
" def __init__(self, signature):\n",
" self.signature = signature\n",
" \n",
" def __contains__(self, x):\n",
" if isinstance(x, types.BaseAnonymousTuple):\n",
" # I'm going to ignore that the Dispatcher type object is not the same object as the Dispatcher object for readabilty\n",
" # real code would obviously deal with them properly\n",
" # I'm also ignoring any restrictions on the output type of the function, since it was marked Any\n",
" for fc_ty in x.dtype:\n",
" if self.signature in fc_ty.signatures:\n",
" continue\n",
" else:\n",
" try:\n",
" x.compile(self.signature)\n",
" except:\n",
" return False\n",
" else:\n",
" continue\n",
" return True\n",
" \n",
" def cast(self, x):\n",
" f_tys = []\n",
" for fc_ty in x.dtype:\n",
" cres = fc_ty.overloads[self.signature]\n",
" cres_fcty = CompileResultWAP(cres)\n",
" f_tys.append(cres_fcty)\n",
" return Unituple(f_tys)\n",
"```\n",
"\n",
"In this case `TupleFunctionSuperType` has to be instantiated **after** `typeof(val)` is known. We discuss how to create relationships between supertypes across input parameters in the alternatives section."
]
},
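{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hypothetical usage for the `bar((foo, foo2), 3.5)` example above, reflecting that the supertype can only be built once `typeof(val)` is known (the positional form of `define_supertype` is the one discussed in the alternatives below):\n",
"\n",
"```python\n",
"val_ty = typeof(3.5)                                     # float64\n",
"fcs_sty = TupleFunctionSuperType(types.float64(val_ty))  # signature float64 -> float64\n",
"bar.define_supertype((fcs_sty, {val_ty}))                # proposed API, not yet in Numba\n",
"\n",
"bar((foo, foo2), 3.5)  # both dispatchers would be compiled as float64 -> float64 and\n",
"                       # cast to a tuple of first-class function values\n",
"```"
]
},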
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the ability to define a function supertype and the ability to cast Dispatchers to that supertype, use cases 1 and 2 can be supported.\n",
"\n",
"**Most importantly**:\n",
"- adding this feature adds only a few lines of code in numba itself, the heavy lifting is done in user code.\n",
"- the feature is completely opt-in, it does not create errors or performance regressions for people not using it.\n",
"- the feature is local and only affects the Dispatcher for which it was declared. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Alternatives"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining a supertype via Dispatcher methods \n",
"\n",
"`define_supertype()` can accept different inputs depending on whether cases like `bar(fcs: Unituple(T->Any), val: T)`, ie existential? generic? supertypes, should be supported.\n",
"In the simplest case, where each input argument in independent of each other, `define_supertype` could accept either a tuple of supertypes (and this tuple is matched via position with the function arguments) or a dictionary (supertypes are matches by name to the function arguments, and any omitted arguments are assumed to be `Any` (the supertype of all types).\n",
"\n",
"To support a case like `bar(fcs: Unituple(T->Any), val: T)` users would have to provide that behaviour themselves, but the matching code in the `Dispatcher` would have to pass the entire signature at a time, not each parameter one by one.\n",
"That is, instead of looping one parameter at a time:\n",
"\n",
"```python\n",
" for arg, arg_supertype in zip(args, super_type):\n",
" if arg not in arg_supertype:\n",
"```\n",
"\n",
"it would pass all parameters at the same time.\n",
"```python\n",
" if args not in super_type:\n",
"```\n"
]
},
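{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example (hypothetical API; `supertype_for_x` and `supertype_for_y` stand for any supertype objects as described above, for a function `foo(x, y)`):\n",
"\n",
"```python\n",
"# positional form: one supertype per argument, matched by position\n",
"foo.define_supertype((supertype_for_x, supertype_for_y))\n",
"\n",
"# dictionary form: supertypes matched by argument name;\n",
"# omitted arguments default to Any, the supertype of all types\n",
"foo.define_supertype({'x': supertype_for_x})\n",
"```"
]
},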
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The exact signature and behaviour of `define_supertype()` is up for discussion."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Same idea, but with safeguards to prevent unsafe use"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By placing constraints of the behaviour of supertype classes, certain checks could be performed to avoid some unsafe uses."
]
},
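{
"cell_type": "markdown",
"metadata": {},
"source": [
"One possible shape for such safeguards (a sketch, assuming an abstract base class that user-defined supertypes would have to inherit from, so that the Dispatcher can validate the required interface before accepting them):\n",
"\n",
"```python\n",
"from abc import ABC, abstractmethod\n",
"\n",
"class SuperType(ABC):\n",
"    @abstractmethod\n",
"    def __contains__(self, numba_type):\n",
"        \"\"\"Return True if numba_type can safely reuse this supertype's CompileResult.\"\"\"\n",
"\n",
"    def cast(self, numba_type):\n",
"        \"\"\"Optionally adapt a member type to the representation used at call time.\"\"\"\n",
"        return numba_type\n",
"```\n",
"\n",
"`define_supertype` could then reject objects that are not `SuperType` instances, and could spot-check a few known types against `__contains__` to catch obviously unsafe definitions."
]
},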
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Global subtyping, structural or nominal"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of requiring the user to provide their own `RecordSuperType` class, a general method could be incorporated in the `Record` numba type (currently defined in `npytypes` module). There are several variations of this idea, ranging from completely automated (every record type declared becomes automatically a supertype and applies automatically to all `Dispatcher`s), to different degrees of user opt-in.\n",
"\n",
"However, the more general method the harder it is to avoid errors. Compilation time might also suffer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use function the function signature to declare the supertypes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of using the `define_supertype` method to declare the applicable subtypes, the user might use the function signature. \n",
"```python\n",
"@njit(ret_ty(supertype1, supertype2))\n",
"def foo(x, y):\n",
"```\n",
"where type1 or type2 could be not only concrete types, but also supertypes or `Any`. For this case to work, supertype classes might be required to inherit from an abstract `SuperType` in order to recognize them as such.\n",
"\n",
"This option is more similar to other programming languages, and more elegant in its declaration, but requires more infrastructure on the numba side to handle all cases. It could even be done via type annotations (but this opens a whole different can of worms):\n",
"\n",
"```python\n",
"@njit\n",
"def foo(x: type1, y: type2)->ret_ty\n",
"```\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "latest_versions38",
"language": "python",
"name": "latest_versions38"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}