@shoyer
Last active July 24, 2018 17:04
nep-18-example-implementation.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "nep-18-example-implementation.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"[View in Colaboratory](https://colab.research.google.com/gist/shoyer/1f0a308a06cd96df20879a1ddb8f0006/notebook.ipynb)"
]
},
{
"metadata": {
"id": "0VUCrGudBGyN",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# NEP 18 example implementation\n",
"\n",
"Author: Stephan Hoyer (shoyer@google.com)\n",
"\n",
"Date: June 10, 2018\n",
"\n",
"See [the NEP](http://www.numpy.org/neps/nep-0018-array-function-protocol.html) for full context and details.\n",
"\n",
"## Implementation of `__array_function__` machinery\n",
"\n",
"Our goals here are:\n",
"1. Correctness\n",
"2. Performance for the typical case of no overloads\n",
"3. Performance for large numbers of arguments\n",
"    - This is important for overloading functions like `np.concatenate`, which could involve thousands or tens of thousands of arguments to check.\n",
"\n",
"Note that for maximum performance, we will probably write the actual implementations of these functions (at least `get_overloaded_types_and_args` and `try_array_function_override`) in C."
]
},
{
"metadata": {
"id": "zITJ-67gA30f",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"import functools\n",
"import inspect\n",
"import numpy as np\n",
"import six\n",
"import sys\n",
"\n",
"\n",
"class ndarray(np.ndarray):\n",
"  \"\"\"Updated version of numpy.ndarray.\"\"\"\n",
"  def __array_function__(self, func, types, args, kwargs):\n",
"    # Cannot handle items that have __array_function__ other than our own.\n",
"    for t in types:\n",
"      if (hasattr(t, '__array_function__') and\n",
"          t.__array_function__ is not ndarray.__array_function__):\n",
"        return NotImplemented\n",
"\n",
"    # Arguments contain no overrides, so we can safely call the\n",
"    # overloaded function again.\n",
"    return func(*args, **kwargs)\n",
"\n",
"\n",
"def get_overloaded_types_and_args(relevant_args):\n",
"  \"\"\"Returns the types and arguments on which to call __array_function__.\n",
"\n",
"  __array_function__ implementations should be called in order on the\n",
"  returned arguments.\n",
"  \"\"\"\n",
"  # Runtime is O(num_arguments * num_unique_types)\n",
"  overloaded_types = []\n",
"  overloaded_args = []\n",
"  for arg in relevant_args:\n",
"    arg_type = type(arg)\n",
"    if arg_type not in overloaded_types:\n",
"      try:\n",
"        array_function = arg_type.__array_function__\n",
"      except AttributeError:\n",
"        continue\n",
"\n",
"      overloaded_types.append(arg_type)\n",
"\n",
"      if array_function is not ndarray.__array_function__:\n",
"        index = len(overloaded_args)\n",
"        for i, old_arg in enumerate(overloaded_args):\n",
"          if issubclass(arg_type, type(old_arg)):\n",
"            index = i\n",
"            break\n",
"        overloaded_args.insert(index, arg)\n",
"\n",
"  return overloaded_types, overloaded_args\n",
"\n",
"\n",
"def full_name(obj):\n",
"  return f'{obj.__module__}.{obj.__qualname__}'\n",
"\n",
"\n",
"def attempt_augmented_error_message(error, append_message):\n",
"  \"\"\"Attempt to recreate an error with an appended message.\"\"\"\n",
"  try:\n",
"    return type(error)(error.args[0] + append_message, *error.args[1:])\n",
"  except Exception:\n",
"    return error\n",
"\n",
"\n",
"def try_array_function_override(func, relevant_arguments, args, kwargs):\n",
"  # TODO: consider simplifying the interface, to only require either `types`\n",
"  # (by making __array_function__ a classmethod) or `overloaded_args` (by\n",
"  # dropping `types` from the signature of __array_function__)\n",
"  types, overloaded_args = get_overloaded_types_and_args(relevant_arguments)\n",
"  if not overloaded_args:\n",
"    return False, None\n",
"\n",
"  for overloaded_arg in overloaded_args:\n",
"    # Note that we're only calling __array_function__ on the *first*\n",
"    # occurrence of each argument type. This is necessary for reasonable\n",
"    # performance with a possibly long list of overloaded arguments, for\n",
"    # which each __array_function__ implementation might reasonably need to\n",
"    # check all argument types.\n",
"    try:\n",
"      result = overloaded_arg.__array_function__(\n",
"          func, types, args, kwargs)\n",
"    except Exception as error:\n",
"      # Ensure the type of the overloaded argument ends up in the\n",
"      # traceback\n",
"      message = (\" [while calling {!r} implementation of {!r}]\"\n",
"                 .format(full_name(type(overloaded_arg)),\n",
"                         full_name(func)))\n",
"      new_error = attempt_augmented_error_message(error, message)\n",
"      # Would probably need to use six to do this sanely on Python 2:\n",
"      # https://stackoverflow.com/questions/9157210/\n",
"      raise new_error.with_traceback(error.__traceback__) from None\n",
"\n",
"    if result is not NotImplemented:\n",
"      return True, result\n",
"\n",
"  raise TypeError('no implementation found for {} on types that implement '\n",
"                  '__array_function__: {}'\n",
"                  .format(func, list(map(type, overloaded_args))))\n",
"\n",
"\n",
"def array_function_dispatch(dispatcher):\n",
"  \"\"\"Wrap a function for dispatch with the __array_function__ protocol.\"\"\"\n",
"  def decorator(func):\n",
"    @functools.wraps(func)\n",
"    def new_func(*args, **kwargs):\n",
"      relevant_arguments = dispatcher(*args, **kwargs)\n",
"      success, value = try_array_function_override(\n",
"          new_func, relevant_arguments, args, kwargs)\n",
"      if success:\n",
"        return value\n",
"      return func(*args, **kwargs)\n",
"    return new_func\n",
"  return decorator\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "FzAfeG7MzFpH",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Unit tests for `get_overloaded_types_and_args`"
]
},
{
"metadata": {
"id": "3NsvGGOONWwZ",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"def return_self(self, *args, **kwargs):\n",
"  return self\n",
"\n",
"def return_not_implemented(self, *args, **kwargs):\n",
"  return NotImplemented\n",
"\n",
"class A:\n",
"  __array_function__ = return_self\n",
"\n",
"class B(A):\n",
"  __array_function__ = return_self\n",
"\n",
"class C(A):\n",
"  __array_function__ = return_self\n",
"\n",
"class D:\n",
"  __array_function__ = return_self\n",
"\n",
"a = A()\n",
"b = B()\n",
"c = C()\n",
"d = D()\n",
"\n",
"def get_overloaded_args(relevant_args):\n",
"  types, args = get_overloaded_types_and_args(relevant_args)\n",
"  return args\n",
"\n",
"assert get_overloaded_args([1]) == []\n",
"assert get_overloaded_args([a]) == [a]\n",
"assert get_overloaded_args([a, 1]) == [a]\n",
"assert get_overloaded_args([a, a, a]) == [a]\n",
"assert get_overloaded_args([a, d, a]) == [a, d]\n",
"assert get_overloaded_args([a, b]) == [b, a]\n",
"assert get_overloaded_args([b, a]) == [b, a]\n",
"assert get_overloaded_args([a, b, c]) == [b, c, a]\n",
"assert get_overloaded_args([a, c, b]) == [c, b, a]\n",
"\n",
"class SubNDArray(ndarray):\n",
"  __array_function__ = return_self\n",
"\n",
"array = np.array(1).view(ndarray)\n",
"assert get_overloaded_types_and_args([array]) == ([ndarray], [])\n",
"assert get_overloaded_types_and_args([a, array, 1]) == ([A, ndarray], [a])\n",
"\n",
"subarray = np.array(1).view(SubNDArray)\n",
"assert get_overloaded_args([array, subarray]) == [subarray]\n",
"assert get_overloaded_args([subarray, array]) == [subarray]"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "AVtqaROdxR_T",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Example function implementations\n",
"\n",
"Note that functions like `np.concatenate` are written in C, so we'll need to write these wrappers in C, too, unless we're OK with the performance hit of doing all the wrapping logic in Python."
]
},
{
"metadata": {
"id": "t0xgmk4IMH5C",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"\n",
"def _broadcast_to_dispatcher(array, shape, subok=None):\n",
"  return (array,)\n",
"\n",
"@array_function_dispatch(_broadcast_to_dispatcher)\n",
"def broadcast_to(array, shape, subok=False):\n",
"  return np.broadcast_to(array, shape, subok)\n",
"\n",
"\n",
"def _concatenate_dispatcher(arrays, axis=None, out=None):\n",
"  for array in arrays:\n",
"    yield array\n",
"  if out is not None:\n",
"    yield out\n",
"\n",
"@array_function_dispatch(_concatenate_dispatcher)\n",
"def concatenate(arrays, axis=0, out=None):\n",
"  return np.concatenate(arrays, axis=axis, out=out)\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "BuD9RXfX7DtR",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# verify that we can pickle these functions\n",
"# note: using functools.wraps and the decorator appears to be critical!\n",
"import pickle\n",
"assert pickle.loads(pickle.dumps(broadcast_to)) is broadcast_to"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "M6mEHx206U8I",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## MyArray implementations"
]
},
{
"metadata": {
"id": "i8R0Nu4-6W6T",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 674
},
"outputId": "c6f04a25-1d73-42d6-9bff-55261c93370a"
},
"cell_type": "code",
"source": [
"HANDLED_FUNCTIONS = {}\n",
"\n",
"class MyArray:\n",
"  def __array_function__(self, func, types, args, kwargs):\n",
"    if func not in HANDLED_FUNCTIONS:\n",
"      return NotImplemented\n",
"    if not all(issubclass(t, MyArray) for t in types):\n",
"      return NotImplemented\n",
"    return HANDLED_FUNCTIONS[func](*args, **kwargs)\n",
"\n",
"def implements(numpy_function):\n",
"  \"\"\"Register an __array_function__ implementation for MyArray objects.\"\"\"\n",
"  def decorator(func):\n",
"    HANDLED_FUNCTIONS[numpy_function] = func\n",
"    return func\n",
"  return decorator\n",
"\n",
"\n",
"# dummy implementation to show how overloads work with new/unexpected arguments\n",
"@implements(concatenate)\n",
"def _(arrays):\n",
"  pass\n",
"\n",
"my_array = MyArray()\n",
"concatenate([my_array])  # works\n",
"concatenate([my_array], axis=0)  # not supported"
],
"execution_count": 5,
"outputs": [
{
"output_type": "error",
"ename": "TypeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-5-02956b4e04d5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mmy_array\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMyArray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmy_array\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# works\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmy_array\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# not supported\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-1-d50f18203b4d>\u001b[0m in \u001b[0;36mnew_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mrelevant_arguments\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdispatcher\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 106\u001b[0m success, value = try_array_function_override(\n\u001b[0;32m--> 107\u001b[0;31m new_func, relevant_arguments, args, kwargs)\n\u001b[0m\u001b[1;32m 108\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msuccess\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-1-d50f18203b4d>\u001b[0m in \u001b[0;36mtry_array_function_override\u001b[0;34m(func, relevant_arguments, args, kwargs)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;31m# Would probably need to use six to do this sanely on Python 2:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;31m# https://stackoverflow.com/questions/9157210/\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mnew_error\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mNotImplemented\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-1-d50f18203b4d>\u001b[0m in \u001b[0;36mtry_array_function_override\u001b[0;34m(func, relevant_arguments, args, kwargs)\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 79\u001b[0m result = overloaded_arg.__array_function__(\n\u001b[0;32m---> 80\u001b[0;31m func, types, args, kwargs)\n\u001b[0m\u001b[1;32m 81\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0;31m# Ensure the type of the overloaded argument ends up in the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-5-02956b4e04d5>\u001b[0m in \u001b[0;36m__array_function__\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0missubclass\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMyArray\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtypes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mNotImplemented\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mHANDLED_FUNCTIONS\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mimplements\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumpy_function\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: _() got an unexpected keyword argument 'axis' [while calling '__main__.MyArray' implementation of '__main__.concatenate']"
]
}
]
},
{
"metadata": {
"id": "-1FtHR3T_gi1",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 283
},
"outputId": "9ffb66ae-d2fd-494e-a060-8df7693638fb"
},
"cell_type": "code",
"source": [
"concatenate([my_array], new_arg=True) # not supported by NumPy"
],
"execution_count": 6,
"outputs": [
{
"output_type": "error",
"ename": "TypeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-6-566bad0d5bff>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmy_array\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnew_arg\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# not supported by NumPy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-1-d50f18203b4d>\u001b[0m in \u001b[0;36mnew_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mfunctools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mnew_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mrelevant_arguments\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdispatcher\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 106\u001b[0m success, value = try_array_function_override(\n\u001b[1;32m 107\u001b[0m new_func, relevant_arguments, args, kwargs)\n",
"\u001b[0;31mTypeError\u001b[0m: _concatenate_dispatcher() got an unexpected keyword argument 'new_arg'"
]
}
]
},
{
"metadata": {
"id": "6wIZ-SMSzRqY",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Micro benchmarks\n",
"\n",
"It's important that the overhead of `__array_function__` is minimal in the typical case of no overloads."
]
},
{
"metadata": {
"id": "C-dru3tyiJ3D",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"outputId": "11ad8126-9473-4cfb-f781-1059bf7b01d8"
},
"cell_type": "code",
"source": [
"array = np.array(1)\n",
"shape = (2,)\n",
"%timeit np.broadcast_to(array, shape)\n",
"%timeit broadcast_to(array, shape)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"The slowest run took 9.35 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"100000 loops, best of 3: 4.42 µs per loop\n",
"The slowest run took 5.82 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"100000 loops, best of 3: 6.46 µs per loop\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "UpSxEJg9KK4M",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"outputId": "442a358a-2bc3-45c7-9386-7ce48d02e699"
},
"cell_type": "code",
"source": [
"arrays = [np.array([1]), np.array([2])]\n",
"%timeit np.concatenate(arrays)\n",
"%timeit concatenate(arrays)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"The slowest run took 1221.88 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000000 loops, best of 3: 869 ns per loop\n",
"The slowest run took 10.83 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"100000 loops, best of 3: 3.98 µs per loop\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "KlSC5g5LO6Su",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "182e7ac3-64ca-415c-af0d-70f59c4659b9"
},
"cell_type": "code",
"source": [
"many_arrays = [np.array([1]), np.array([2])] * 10000\n",
"%timeit np.concatenate(many_arrays)\n",
"%timeit concatenate(many_arrays)"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"100 loops, best of 3: 3.16 ms per loop\n",
"100 loops, best of 3: 16.7 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "yCXlh20M0ohf",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 306
},
"outputId": "3476fef8-f844-4ff2-ce95-a3e60d4b283a"
},
"cell_type": "code",
"source": [
"arrays = [np.array([1]), np.array([2])]\n",
"stats = %prun -r for _ in range(100000): concatenate(arrays)\n",
"stats.print_stats()"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": [
" 800003 function calls in 0.569 seconds\n",
"\n",
" Ordered by: internal time\n",
"\n",
" ncalls tottime percall cumtime percall filename:lineno(function)\n",
" 100000 0.197 0.000 0.235 0.000 <ipython-input-1-d50f18203b4d>:22(get_overloaded_types_and_args)\n",
" 100000 0.154 0.000 0.154 0.000 {built-in method numpy.core.multiarray.concatenate}\n",
" 100000 0.069 0.000 0.522 0.000 <ipython-input-1-d50f18203b4d>:103(new_func)\n",
" 1 0.047 0.047 0.569 0.569 <string>:1(<module>)\n",
" 300000 0.038 0.000 0.038 0.000 <ipython-input-3-8b2f2296c910>:12(_concatenate_dispatcher)\n",
" 100000 0.035 0.000 0.270 0.000 <ipython-input-1-d50f18203b4d>:64(try_array_function_override)\n",
" 100000 0.029 0.000 0.183 0.000 <ipython-input-3-8b2f2296c910>:18(concatenate)\n",
" 1 0.000 0.000 0.569 0.569 {built-in method builtins.exec}\n",
" 1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}\n",
"\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<pstats.Stats at 0x7fd3456a0e80>"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"metadata": {
"id": "eG0v2L--lIjT",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 425
},
"outputId": "0ea31b40-800d-4a4e-95ac-02ed6c53be4d"
},
"cell_type": "code",
"source": [
"# other micro-benchmarks, for context\n",
"x = np.arange(10)\n",
"%timeit np.asarray(x)\n",
"%timeit x[x]\n",
"%timeit np.concatenate([x, x])\n",
"%timeit np.stack([x, x])\n",
"%timeit x.sum()\n",
"%timeit np.sum(x)\n",
"%timeit np.mean(x)\n",
"%timeit np.sin(x)\n",
"%timeit np.unique(x)\n",
"%timeit np.broadcast_to(x, (1, 10))\n",
"%timeit np.transpose(x)\n",
"%timeit np.moveaxis(x, 0, -1)"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"The slowest run took 16.74 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000000 loops, best of 3: 365 ns per loop\n",
"The slowest run took 44.88 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000000 loops, best of 3: 242 ns per loop\n",
"The slowest run took 59.54 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000000 loops, best of 3: 938 ns per loop\n",
"The slowest run took 6.92 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"100000 loops, best of 3: 5.39 µs per loop\n",
"The slowest run took 34.62 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000000 loops, best of 3: 1.68 µs per loop\n",
"The slowest run took 11.59 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"100000 loops, best of 3: 2.72 µs per loop\n",
"The slowest run took 8.03 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"100000 loops, best of 3: 7.06 µs per loop\n",
"The slowest run took 50.44 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000000 loops, best of 3: 1.22 µs per loop\n",
"The slowest run took 9.14 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"100000 loops, best of 3: 6.52 µs per loop\n",
"The slowest run took 6.86 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"100000 loops, best of 3: 4.58 µs per loop\n",
"The slowest run took 24.69 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000000 loops, best of 3: 644 ns per loop\n",
"The slowest run took 7.18 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"100000 loops, best of 3: 5.3 µs per loop\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "2tL6-I3pnpUt",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"def dummy_try_array_function_override(\n",
"    func, relevant_arguments, args, kwargs):\n",
"  return False, None\n",
"\n",
"def dummy_dispatch(dispatcher):\n",
"  def decorator(func):\n",
"    @functools.wraps(func)\n",
"    def new_func(*args, **kwargs):\n",
"      relevant_arguments = dispatcher(*args, **kwargs)\n",
"      success, value = dummy_try_array_function_override(\n",
"          new_func, relevant_arguments, args, kwargs)\n",
"      if success:\n",
"        return value\n",
"      return func(*args, **kwargs)\n",
"    return new_func\n",
"  return decorator\n",
"\n",
"def f(x):\n",
"  pass\n",
"\n",
"def _dispatcher(x):\n",
"  return (x,)\n",
"\n",
"@dummy_dispatch(_dispatcher)\n",
"def g(x):\n",
"  pass"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "cy6Haybcn49V",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "bfe495ef-e724-4a22-abfa-63497c89878f"
},
"cell_type": "code",
"source": [
"%timeit f(1)"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"The slowest run took 13.60 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"10000000 loops, best of 3: 84.8 ns per loop\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "DLITbTVYn50_",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "13e18f3a-69dc-4fec-c38e-57d58b05077e"
},
"cell_type": "code",
"source": [
"%timeit g(1)"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": [
"The slowest run took 9.63 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000000 loops, best of 3: 567 ns per loop\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "PM4KRoOQ1qmE",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Micro benchmark conclusions:\n",
"\n",
"- Adding overloads with `try_array_function_override` adds about 2-3 µs of overhead per function call.\n",
"  - This is fine for functions written in Python (e.g., `np.broadcast_to`), but *could* be significant for functions written in C (e.g., `np.concatenate`).\n",
"  - It's unclear how bad the performance degradation would be if we wrote this in C.\n",
"- The explicit decorator `array_function_dispatch` is really clean and just as fast as calling `try_array_function_override` directly.\n",
"  - The only downside is that the use of `functools.wraps` means that decorated functions lose an inspectable signature on Python 2. But this is probably worth it, given how soon NumPy will be Python 3 only."
]
}
]
}
mattip commented Jul 24, 2018

Is it time to turn this into a PR or somehow move it forward?
