Last active
November 10, 2021 14:43
-
-
Save RutgerK/6a8468a27fe3e7c1a520c36bf9eac344 to your computer and use it in GitHub Desktop.
Numba type hints
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "9c658ddd-d40b-44c8-a01b-a16a659e7fa9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'0.54.1'" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import typing\n", | |
"from typing import Annotated, Any, get_type_hints, get_args\n", | |
"\n", | |
"from numba import njit, guvectorize, vectorize, __version__\n", | |
"import numpy as np\n", | |
"import inspect\n", | |
"import functools\n", | |
"__version__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "d1b6b656-7302-4b0e-a1ea-059ba1c1cf08", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def parse_func(f):\n", | |
"\n", | |
" try:\n", | |
" print('type_hints', get_type_hints(f, include_extras=True), sep='\\n', end='\\n\\n')\n", | |
" except:\n", | |
" print('type_hints', '<not available>', sep='\\n', end='\\n\\n')\n", | |
"\n", | |
"\n", | |
" try: \n", | |
" print('argspec', inspect.getfullargspec(f), sep='\\n', end='\\n\\n')\n", | |
" except:\n", | |
" print('argspec', '<not available>', sep='\\n', end='\\n\\n')\n", | |
"\n", | |
"\n", | |
" try:\n", | |
" print('docstring', f.__doc__, sep='\\n', end='\\n\\n')\n", | |
" except:\n", | |
" print('docstring', '<not available>', sep='\\n', end='\\n\\n')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "5a7b5d47-4a50-4552-be45-3ee3c9ee24dc", | |
"metadata": {}, | |
"source": [ | |
"# Pure Python" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "0811c9e9-9c04-408d-a315-c7cb2f396e5a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"type_hints\n", | |
"{'x': typing.Annotated[typing.Any, 'custom_x'], 'y': typing.Annotated[typing.Any, 'custom_y'], 'return': typing.Annotated[typing.Any, 'custom_z']}\n", | |
"\n", | |
"argspec\n", | |
"FullArgSpec(args=['x', 'y'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={'return': typing.Annotated[typing.Any, 'custom_z'], 'x': typing.Annotated[typing.Any, 'custom_x'], 'y': typing.Annotated[typing.Any, 'custom_y']})\n", | |
"\n", | |
"docstring\n", | |
"py_func\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"x_type = Annotated[Any, \"custom_x\"]\n", | |
"y_type = Annotated[Any, \"custom_y\"]\n", | |
"z_type = Annotated[Any, \"custom_z\"]\n", | |
"\n", | |
"def py_func(x: x_type, y: y_type) -> z_type:\n", | |
" \"\"\"py_func\"\"\"\n", | |
" return x+y\n", | |
"\n", | |
"parse_func(py_func)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "95fcaa4a-243d-492d-87e6-c88aa343572b", | |
"metadata": {}, | |
"source": [ | |
"# njit" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "7dfa1432-c295-4e92-9150-f053f3ac74ef", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"type_hints\n", | |
"{'x': typing.Annotated[typing.Any, 'custom_x'], 'y': typing.Annotated[typing.Any, 'custom_y'], 'return': typing.Annotated[typing.Any, 'custom_z']}\n", | |
"\n", | |
"argspec\n", | |
"FullArgSpec(args=['x', 'y'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={'return': typing.Annotated[typing.Any, 'custom_z'], 'x': typing.Annotated[typing.Any, 'custom_x'], 'y': typing.Annotated[typing.Any, 'custom_y']})\n", | |
"\n", | |
"docstring\n", | |
"nb_njit_func\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"@njit\n", | |
"def nb_njit_func(x: x_type, y: y_type) -> z_type:\n", | |
" \"\"\"nb_njit_func\"\"\"\n", | |
" return x+y\n", | |
"\n", | |
"parse_func(nb_njit_func.py_func)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "c4fa4816-8715-41bc-b5b8-b77564e973e9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"type_hints\n", | |
"<not available>\n", | |
"\n", | |
"argspec\n", | |
"<not available>\n", | |
"\n", | |
"docstring\n", | |
"nb_njit_func\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"parse_func(nb_njit_func)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "cee6c6c0-8ddd-4313-bade-093a1a4eeb78", | |
"metadata": {}, | |
"source": [ | |
"# guvectorize" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "88e84fe9-d621-49cc-b7e2-76b2629bf7b3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"type_hints\n", | |
"{}\n", | |
"\n", | |
"argspec\n", | |
"FullArgSpec(args=['func'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={})\n", | |
"\n", | |
"docstring\n", | |
"None\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"@guvectorize\n", | |
"def nb_guvec_func(x: x_type, y: y_type) -> z_type:\n", | |
" \"\"\"nb_guvec_func\"\"\"\n", | |
" return x+y\n", | |
"\n", | |
"parse_func(nb_guvec_func)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "228d26ad-4679-4ff9-97c4-d5326475eb3f", | |
"metadata": {}, | |
"source": [ | |
"# vectorize" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "b7c946b8-edd7-4040-9b9a-cb7f6232c1f6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"type_hints\n", | |
"<not available>\n", | |
"\n", | |
"argspec\n", | |
"<not available>\n", | |
"\n", | |
"docstring\n", | |
"nb_vec_func\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"@vectorize\n", | |
"def nb_vec_func(x: x_type, y: y_type) -> z_type:\n", | |
" \"\"\"nb_vec_func\"\"\"\n", | |
" return x+y\n", | |
"\n", | |
"parse_func(nb_vec_func)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9dffe5bd-9ec8-4b26-b9a0-a701d3fe2ead", | |
"metadata": {}, | |
"source": [ | |
"# Workaround 1: parse before vectorize" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "1d9ebb87-7191-4fe8-aac3-e33ce6839f60", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"annotations = {}\n", | |
"\n", | |
"def parse_annotations(f):\n", | |
"\n", | |
" inputs = []\n", | |
" outputs = []\n", | |
"\n", | |
" for k,v in get_type_hints(f, include_extras=True).items():\n", | |
"\n", | |
" datatype, component = get_args(v)\n", | |
"\n", | |
" if k == 'return':\n", | |
" outputs.append(component)\n", | |
" else:\n", | |
" inputs.append(component)\n", | |
"\n", | |
" annotations[f.__name__] = (inputs, outputs)\n", | |
" \n", | |
" return f" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "ab5bbc73-e1e9-4272-b121-f7e869571ef2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'nb_vec_func': (['custom_x', 'custom_y'], ['custom_z'])}" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"@vectorize\n", | |
"@parse_annotations\n", | |
"def nb_vec_func(x: x_type, y: y_type) -> z_type:\n", | |
" \"\"\"nb_vec_func\"\"\"\n", | |
" return x+y\n", | |
"\n", | |
"annotations" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ecaaa3f9-6996-4477-bb9c-bbe26ae3090e", | |
"metadata": {}, | |
"source": [ | |
"# Workaround 2: wrap vectorized decorator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "eb79cfc8-498e-4778-b7e6-e27aa1fba15d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def wrap_vectorize(*args, **kwargs):\n", | |
"\n", | |
" if len(args) == 1 and not kwargs and callable(args[0]):\n", | |
" func, = args\n", | |
" args = tuple()\n", | |
" has_args = False\n", | |
" else:\n", | |
" has_args = True\n", | |
"\n", | |
" def outer(func):\n", | |
" \n", | |
" wrapped_vec = vectorize(*args, **kwargs)(func)\n", | |
" functools.update_wrapper(wrapped_vec, func)\n", | |
" \n", | |
" return wrapped_vec\n", | |
"\n", | |
" if has_args:\n", | |
" return outer\n", | |
" else:\n", | |
" return outer(func)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "25abf158-7449-4ea1-8292-1830fb023969", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@wrap_vectorize([\n", | |
" 'float32(float32, uint8)', \n", | |
" 'float64(float64, uint8)',\n", | |
"], fastmath=True)\n", | |
"def nb_vec_func_wrapped(x: x_type, y: y_type) -> z_type:\n", | |
" \"\"\"nb_vec_func\"\"\"\n", | |
" return x+y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "2d4f28d4-79f4-4c5e-9603-54285ab5a1f9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"type_hints\n", | |
"{'x': typing.Annotated[typing.Any, 'custom_x'], 'y': typing.Annotated[typing.Any, 'custom_y'], 'return': typing.Annotated[typing.Any, 'custom_z']}\n", | |
"\n", | |
"argspec\n", | |
"<not available>\n", | |
"\n", | |
"docstring\n", | |
"nb_vec_func\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"parse_func(nb_vec_func_wrapped)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "65d5b2f8-ec7c-42d7-b528-211995188d2b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"numba.np.ufunc.dufunc.DUFunc" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"type(nb_vec_func_wrapped)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "e0b2923d-8897-4d45-a23d-40a908ca9321", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<function __main__.nb_vec_func_wrapped(x: Annotated[Any, 'custom_x'], y: Annotated[Any, 'custom_y']) -> Annotated[Any, 'custom_z']>" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"nb_vec_func_wrapped.__wrapped__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "c53c0abf-feba-486d-8f9d-d0415d197c1b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[-0.14655443, 1.84584324],\n", | |
" [ 1.62229537, 1.32742006]])" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# call function with correct datatypes [float, int]\n", | |
"nb_vec_func_wrapped(np.random.randn(2,2), 2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "aa3bcc2b-d276-4561-8010-9a0de5832853", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "TypeError", | |
"evalue": "ufunc 'nb_vec_func_wrapped' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_2716/3192411579.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;31m# call function with incorrect datatypes [float, float], should raise TypeError\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mnb_vec_func_wrapped\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2.5\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[1;31mTypeError\u001b[0m: ufunc 'nb_vec_func_wrapped' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''" | |
] | |
} | |
], | |
"source": [ | |
"# call function with incorrect datatypes [float, float]\n", | |
"# should raise TypeError\n", | |
"nb_vec_func_wrapped(np.random.randn(2,2), 2.5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "10694760-ff40-49c4-a9cd-2b65d6f6da79", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.2" | |
}, | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"state": {}, | |
"version_major": 2, | |
"version_minor": 0 | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment