Skip to content

Instantly share code, notes, and snippets.

@RutgerK
Last active November 10, 2021 14:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RutgerK/6a8468a27fe3e7c1a520c36bf9eac344 to your computer and use it in GitHub Desktop.
Save RutgerK/6a8468a27fe3e7c1a520c36bf9eac344 to your computer and use it in GitHub Desktop.
Numba type hints
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "9c658ddd-d40b-44c8-a01b-a16a659e7fa9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.54.1'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import typing\n",
"from typing import Annotated, Any, get_type_hints, get_args\n",
"\n",
"from numba import njit, guvectorize, vectorize, __version__\n",
"import numpy as np\n",
"import inspect\n",
"import functools\n",
"__version__"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d1b6b656-7302-4b0e-a1ea-059ba1c1cf08",
"metadata": {},
"outputs": [],
"source": [
"def parse_func(f):\n",
"    \"\"\"Print what introspection can recover from ``f``.\n",
"\n",
"    Three sections are printed in order: the resolved type hints (with\n",
"    ``Annotated`` extras kept), the full argument spec, and the docstring.\n",
"    A section whose lookup raises is reported as ``<not available>``\n",
"    instead of propagating the error, so the remaining sections still print.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    f : object\n",
"        The function (or function-like object) to inspect.\n",
"    \"\"\"\n",
"\n",
"    def report(label, lookup):\n",
"        # Catch Exception rather than using a bare except, so that\n",
"        # KeyboardInterrupt / SystemExit still stop the cell instead of\n",
"        # being reported as missing data.\n",
"        try:\n",
"            # str() mirrors what print() would do with the value.\n",
"            text = str(lookup())\n",
"        except Exception:\n",
"            text = '<not available>'\n",
"        print(label, text, sep='\\n', end='\\n\\n')\n",
"\n",
"    report('type_hints', lambda: get_type_hints(f, include_extras=True))\n",
"    report('argspec', lambda: inspect.getfullargspec(f))\n",
"    report('docstring', lambda: f.__doc__)"
]
},
{
"cell_type": "markdown",
"id": "5a7b5d47-4a50-4552-be45-3ee3c9ee24dc",
"metadata": {},
"source": [
"# Pure Python"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0811c9e9-9c04-408d-a315-c7cb2f396e5a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"type_hints\n",
"{'x': typing.Annotated[typing.Any, 'custom_x'], 'y': typing.Annotated[typing.Any, 'custom_y'], 'return': typing.Annotated[typing.Any, 'custom_z']}\n",
"\n",
"argspec\n",
"FullArgSpec(args=['x', 'y'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={'return': typing.Annotated[typing.Any, 'custom_z'], 'x': typing.Annotated[typing.Any, 'custom_x'], 'y': typing.Annotated[typing.Any, 'custom_y']})\n",
"\n",
"docstring\n",
"py_func\n",
"\n"
]
}
],
"source": [
"# Each alias attaches an arbitrary string label to the hint via Annotated;\n",
"# the aliases are reused by every function defined in the cells below.\n",
"x_type = Annotated[Any, \"custom_x\"]\n",
"y_type = Annotated[Any, \"custom_y\"]\n",
"z_type = Annotated[Any, \"custom_z\"]\n",
"\n",
"def py_func(x: x_type, y: y_type) -> z_type:\n",
" \"\"\"py_func\"\"\"\n",
" return x+y\n",
"\n",
"# Baseline: for a plain Python function all three sections are available.\n",
"parse_func(py_func)"
]
},
{
"cell_type": "markdown",
"id": "95fcaa4a-243d-492d-87e6-c88aa343572b",
"metadata": {},
"source": [
"# njit"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7dfa1432-c295-4e92-9150-f053f3ac74ef",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"type_hints\n",
"{'x': typing.Annotated[typing.Any, 'custom_x'], 'y': typing.Annotated[typing.Any, 'custom_y'], 'return': typing.Annotated[typing.Any, 'custom_z']}\n",
"\n",
"argspec\n",
"FullArgSpec(args=['x', 'y'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={'return': typing.Annotated[typing.Any, 'custom_z'], 'x': typing.Annotated[typing.Any, 'custom_x'], 'y': typing.Annotated[typing.Any, 'custom_y']})\n",
"\n",
"docstring\n",
"nb_njit_func\n",
"\n"
]
}
],
"source": [
"@njit\n",
"def nb_njit_func(x: x_type, y: y_type) -> z_type:\n",
" \"\"\"nb_njit_func\"\"\"\n",
" return x+y\n",
"\n",
"# Inspect the .py_func attribute of the jitted object rather than the\n",
"# object itself; the recorded output shows the hints, argspec and\n",
"# docstring are all available through it.\n",
"parse_func(nb_njit_func.py_func)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c4fa4816-8715-41bc-b5b8-b77564e973e9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"type_hints\n",
"<not available>\n",
"\n",
"argspec\n",
"<not available>\n",
"\n",
"docstring\n",
"nb_njit_func\n",
"\n"
]
}
],
"source": [
"parse_func(nb_njit_func)"
]
},
{
"cell_type": "markdown",
"id": "cee6c6c0-8ddd-4313-bade-093a1a4eeb78",
"metadata": {},
"source": [
"# guvectorize"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "88e84fe9-d621-49cc-b7e2-76b2629bf7b3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"type_hints\n",
"{}\n",
"\n",
"argspec\n",
"FullArgSpec(args=['func'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={})\n",
"\n",
"docstring\n",
"None\n",
"\n"
]
}
],
"source": [
"# NOTE(review): the recorded output reports args=['func'] and a None docstring,\n",
"# i.e. the object inspected below does not look like nb_guvec_func itself.\n",
"# Presumably bare @guvectorize (no signature/layout) returned numba's inner\n",
"# decorator instead of a compiled gufunc - confirm against the numba docs.\n",
"@guvectorize\n",
"def nb_guvec_func(x: x_type, y: y_type) -> z_type:\n",
" \"\"\"nb_guvec_func\"\"\"\n",
" return x+y\n",
"\n",
"parse_func(nb_guvec_func)"
]
},
{
"cell_type": "markdown",
"id": "228d26ad-4679-4ff9-97c4-d5326475eb3f",
"metadata": {},
"source": [
"# vectorize"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b7c946b8-edd7-4040-9b9a-cb7f6232c1f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"type_hints\n",
"<not available>\n",
"\n",
"argspec\n",
"<not available>\n",
"\n",
"docstring\n",
"nb_vec_func\n",
"\n"
]
}
],
"source": [
"@vectorize\n",
"def nb_vec_func(x: x_type, y: y_type) -> z_type:\n",
" \"\"\"nb_vec_func\"\"\"\n",
" return x+y\n",
"\n",
"# Inspecting the vectorized object directly: per the recorded output only\n",
"# the docstring is recovered, the type hints and argspec lookups fail.\n",
"parse_func(nb_vec_func)"
]
},
{
"cell_type": "markdown",
"id": "9dffe5bd-9ec8-4b26-b9a0-a701d3fe2ead",
"metadata": {},
"source": [
"# Workaround 1: parse before vectorize"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1d9ebb87-7191-4fe8-aac3-e33ce6839f60",
"metadata": {},
"outputs": [],
"source": [
"# Registry filled by parse_annotations:\n",
"# function name -> (input labels, output labels).\n",
"annotations = {}\n",
"\n",
"def parse_annotations(f):\n",
"    \"\"\"Record the ``Annotated`` labels of ``f``'s type hints in ``annotations``.\n",
"\n",
"    Intended as a decorator applied while ``f`` is still a plain Python\n",
"    function, i.e. underneath ``vectorize``. For every hint returned by\n",
"    ``get_type_hints`` the first metadata item of ``Annotated[...]`` is\n",
"    stored: the ``return`` hint goes to the outputs list, every other hint\n",
"    to the inputs list, in the order ``get_type_hints`` yields them.\n",
"    Hints that are not ``Annotated`` carry no label and are stored as\n",
"    ``None``.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    f : function\n",
"        The function whose hints are recorded under ``f.__name__``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    function\n",
"        ``f`` itself, unchanged.\n",
"    \"\"\"\n",
"    inputs = []\n",
"    outputs = []\n",
"\n",
"    for name, hint in get_type_hints(f, include_extras=True).items():\n",
"\n",
"        if typing.get_origin(hint) is Annotated:\n",
"            # get_args() yields (underlying type, *metadata); Annotated\n",
"            # always carries at least one metadata item.\n",
"            component = get_args(hint)[1]\n",
"        else:\n",
"            # Unpacking get_args() blindly would misread a two-argument\n",
"            # generic such as Dict[str, int] as (type, label) and raise\n",
"            # ValueError for any other hint.\n",
"            component = None\n",
"\n",
"        if name == 'return':\n",
"            outputs.append(component)\n",
"        else:\n",
"            inputs.append(component)\n",
"\n",
"    annotations[f.__name__] = (inputs, outputs)\n",
"\n",
"    return f"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ab5bbc73-e1e9-4272-b121-f7e869571ef2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'nb_vec_func': (['custom_x', 'custom_y'], ['custom_z'])}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Decorators apply bottom-up: parse_annotations runs on the plain Python\n",
"# function and returns it unchanged, then vectorize wraps the result.\n",
"# This rebinds nb_vec_func, which was already defined in an earlier cell.\n",
"@vectorize\n",
"@parse_annotations\n",
"def nb_vec_func(x: x_type, y: y_type) -> z_type:\n",
" \"\"\"nb_vec_func\"\"\"\n",
" return x+y\n",
"\n",
"annotations"
]
},
{
"cell_type": "markdown",
"id": "ecaaa3f9-6996-4477-bb9c-bbe26ae3090e",
"metadata": {},
"source": [
"# Workaround 2: wrap vectorized decorator"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "eb79cfc8-498e-4778-b7e6-e27aa1fba15d",
"metadata": {},
"outputs": [],
"source": [
"def wrap_vectorize(*args, **kwargs):\n",
"    \"\"\"Call ``vectorize`` and copy the function's metadata onto the result.\n",
"\n",
"    Usable both as a bare decorator (``@wrap_vectorize``) and with\n",
"    arguments (``@wrap_vectorize(signatures, **options)``); in the latter\n",
"    case all arguments are forwarded to ``vectorize`` untouched. After\n",
"    vectorizing, ``functools.update_wrapper`` is applied to the returned\n",
"    object with the original Python function as the source.\n",
"\n",
"    Returns\n",
"    -------\n",
"    The vectorized function when used bare, otherwise a decorator that\n",
"    produces it.\n",
"    \"\"\"\n",
"    # A single positional callable and no keywords means bare-decorator use;\n",
"    # in that case nothing is forwarded to vectorize.\n",
"    bare = len(args) == 1 and not kwargs and callable(args[0])\n",
"    vec_args = () if bare else args\n",
"\n",
"    def outer(func):\n",
"        wrapped_vec = vectorize(*vec_args, **kwargs)(func)\n",
"        # update_wrapper returns the object it updated.\n",
"        return functools.update_wrapper(wrapped_vec, func)\n",
"\n",
"    return outer(args[0]) if bare else outer"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "25abf158-7449-4ea1-8292-1830fb023969",
"metadata": {},
"outputs": [],
"source": [
"# Explicit signatures (the second operand is uint8 in both); the signature\n",
"# list and fastmath=True are forwarded to vectorize by wrap_vectorize.\n",
"# NOTE(review): the docstring below reads 'nb_vec_func' although the function\n",
"# is nb_vec_func_wrapped - looks like a copy-paste leftover; it is echoed by\n",
"# the recorded output of a later cell, so confirm before changing it.\n",
"@wrap_vectorize([\n",
" 'float32(float32, uint8)', \n",
" 'float64(float64, uint8)',\n",
"], fastmath=True)\n",
"def nb_vec_func_wrapped(x: x_type, y: y_type) -> z_type:\n",
" \"\"\"nb_vec_func\"\"\"\n",
" return x+y"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2d4f28d4-79f4-4c5e-9603-54285ab5a1f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"type_hints\n",
"{'x': typing.Annotated[typing.Any, 'custom_x'], 'y': typing.Annotated[typing.Any, 'custom_y'], 'return': typing.Annotated[typing.Any, 'custom_z']}\n",
"\n",
"argspec\n",
"<not available>\n",
"\n",
"docstring\n",
"nb_vec_func\n",
"\n"
]
}
],
"source": [
"parse_func(nb_vec_func_wrapped)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "65d5b2f8-ec7c-42d7-b528-211995188d2b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"numba.np.ufunc.dufunc.DUFunc"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(nb_vec_func_wrapped)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "e0b2923d-8897-4d45-a23d-40a908ca9321",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function __main__.nb_vec_func_wrapped(x: Annotated[Any, 'custom_x'], y: Annotated[Any, 'custom_y']) -> Annotated[Any, 'custom_z']>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nb_vec_func_wrapped.__wrapped__"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "c53c0abf-feba-486d-8f9d-d0415d197c1b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.14655443, 1.84584324],\n",
" [ 1.62229537, 1.32742006]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# call function with correct datatypes [float, int]\n",
"nb_vec_func_wrapped(np.random.randn(2,2), 2)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "aa3bcc2b-d276-4561-8010-9a0de5832853",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "ufunc 'nb_vec_func_wrapped' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_2716/3192411579.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;31m# call function with incorrect datatypes [float, float], should raise TypeError\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mnb_vec_func_wrapped\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2.5\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;31mTypeError\u001b[0m: ufunc 'nb_vec_func_wrapped' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''"
]
}
],
"source": [
"# call function with incorrect datatypes [float, float]\n",
"# should raise TypeError\n",
"nb_vec_func_wrapped(np.random.randn(2,2), 2.5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10694760-ff40-49c4-a9cd-2b65d6f6da79",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment