Skip to content

Instantly share code, notes, and snippets.

@agoose77
Created February 13, 2020 10:22
Show Gist options
  • Save agoose77/784cad0e4db5caaf818940af56d87919 to your computer and use it in GitHub Desktop.
Save agoose77/784cad0e4db5caaf818940af56d87919 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Numba Jax"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Welcome to JupyROOT 6.19/01\n"
]
}
],
"source": [
"import numpy as np\n",
"from jaxlib import xla_client\n",
"import numba\n",
"import ROOT\n",
"import jax\n",
"from jax import core\n",
"from jax import abstract_arrays\n",
"from jax import xla\n",
"from numba import cfunc, types, carray\n",
"\n",
"from Cython.Build import cythonize\n",
"from distutils.core import Extension, Distribution\n",
"from distutils.command.build_ext import build_ext\n",
"from tempfile import mkdtemp\n",
"from os import path\n",
"\n",
"import importlib.util"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext Cython"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Store pointer to numba cfunc with (x, y, out, n) signature"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%%cython\n",
"from cpython.pycapsule cimport PyCapsule_New\n",
"from cpython.long cimport PyLong_AsVoidPtr\n",
"from jaxlib import xla_client\n",
"import ctypes\n",
"\n",
"\n",
"def register_cpu_custom_call_target(fn_name, ctypes_ptr):\n",
" ptr = ctypes.cast(ctypes_ptr, ctypes.c_void_p)\n",
" cdef void* fn = PyLong_AsVoidPtr(ptr.value);\n",
" cdef const char* name = \"xla._CUSTOM_CALL_TARGET\"\n",
" xla_client.register_cpu_custom_call_target(\n",
" fn_name, PyCapsule_New(fn, name, NULL)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"wrapper_header = \"\"\"\n",
"from cpython.long cimport PyLong_AsVoidPtr, PyLong_FromVoidPtr\n",
"from libc.stdint cimport int32_t, int64_t\n",
"import ctypes\n",
"\n",
"def get_ptr():\n",
" ptr_cls = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p))\n",
" return ptr_cls(<long>&delegate_numba) \n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"ctypes_map = {\n",
" ctypes.c_float: 'float',\n",
" ctypes.c_double: 'double',\n",
" ctypes.c_short: 'short',\n",
" ctypes.c_ushort: 'ushort',\n",
" ctypes.c_int: 'int',\n",
" ctypes.c_uint: 'uint',\n",
" ctypes.c_long: 'long',\n",
" ctypes.c_ulong: 'ulong',\n",
" ctypes.c_longlong: 'longlong',\n",
" ctypes.c_ulonglong: 'ulonglong',\n",
" ctypes.c_size_t: 'size_t',\n",
" ctypes.c_ssize_t: 'ssize_t',\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def compile_cython_extension(code):\n",
" lib_dir = mkdtemp()\n",
" module_name = \"extension\"\n",
" pyx_path = path.join(lib_dir, module_name + \".pyx\")\n",
"\n",
" with open(pyx_path, 'w', encoding='utf-8') as f:\n",
" f.write(code)\n",
"\n",
" extension = Extension(\n",
" name=module_name,\n",
" sources=[pyx_path],\n",
" include_dirs=[],\n",
" library_dirs=[lib_dir],\n",
" extra_compile_args=[],\n",
" extra_link_args=[],\n",
" libraries=[],\n",
" language='c',\n",
" )\n",
"\n",
" import sys\n",
" extensions = cythonize([extension], language_level=min(3, sys.version_info[0]))\n",
"\n",
" dist = Distribution()\n",
" config_files = dist.find_config_files()\n",
" try:\n",
" config_files.remove('setup.cfg')\n",
" except ValueError:\n",
" pass\n",
" dist.parse_config_files(config_files)\n",
"\n",
" build_extension = build_ext(dist)\n",
" build_extension.finalize_options()\n",
" build_extension.build_lib = lib_dir\n",
" build_extension.extensions = extensions\n",
" build_extension.run()\n",
"\n",
" module_path = path.join(lib_dir, build_extension.get_ext_filename(module_name))\n",
" spec = importlib.util.spec_from_file_location(module_name, module_path)\n",
" module = importlib.util.module_from_spec(spec)\n",
" spec.loader.exec_module(module)\n",
" return module"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def generate_caller_code(numba_fn):\n",
" c_type_names = [ctypes_map[t._type_] if issubclass(t, ctypes._Pointer) else ctypes_map[t] for t in numba_fn.ctypes.argtypes]\n",
" c_primitive_flags = [not issubclass(t, ctypes._Pointer) for t in numba_fn.ctypes.argtypes]\n",
"\n",
" out_c_type, arg_c_types = c_type_names[0], c_type_names[1:]\n",
" arg_primitive_flags = c_primitive_flags[1:]\n",
" assert not c_primitive_flags[0]\n",
"\n",
" arg_names = [f\"arg_{i}\" for i in range(len(arg_c_types))]\n",
"\n",
" variable_declarations = [\n",
" f\"cdef {t}* {n} = <{t}*>data[{i}]\" for i, (n, t) in enumerate(zip(arg_names, arg_c_types))\n",
" ]\n",
"\n",
" arg_list = [(f'{n}[0]' if p else n) for n, p in zip(arg_names, arg_primitive_flags)]\n",
"\n",
" func_arg_spec = [f\"{t}\" if p else f\"{t}*\" for t, p in zip(arg_c_types, arg_primitive_flags)]\n",
"\n",
" NL = \"\\n \"\n",
" ARG = \", \"\n",
" code = f\"\"\" \n",
"{wrapper_header}\n",
"\n",
"# Function signature of numba func\n",
"ctypedef void (*func_ptr)({out_c_type}*, {ARG.join(func_arg_spec)});\n",
"\n",
"cdef delegate_numba(void* out, void** data):\n",
" # Unpack arguments\n",
" cdef {out_c_type}* result = <{out_c_type}*>out;\n",
" {NL.join(variable_declarations)}\n",
"\n",
" # Call func\n",
" cdef long addr = {numba_fn.address};\n",
" cdef func_ptr func = <func_ptr>(<void*>addr);\n",
" func(result, {ARG.join(arg_list)})\n",
" \"\"\"\n",
" return code\n",
"\n",
"_loaded_modules = []\n",
"def get_caller_ptr(numba_fn):\n",
" code = generate_caller_code(numba_fn)\n",
" module = compile_cython_extension(code)\n",
" _loaded_modules.append(module)\n",
" return module.get_ptr()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create numba function"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"c_sig = types.void(types.CPointer(types.float32),\n",
" types.CPointer(types.float32),\n",
" types.CPointer(types.float32),\n",
" types.intc)\n",
"\n",
"@cfunc(c_sig)\n",
"def numba_add_n(out, a, b, n):\n",
" a_array = carray(a, (n,))\n",
" b_array = carray(b, (n,))\n",
" out_array = carray(out, (n,))\n",
" \n",
" for i in range(n):\n",
" out_array[i] = a_array[i] + b_array[i]\n",
"\n",
"@cfunc(c_sig)\n",
"def numba_mul_n(out, a, b, n):\n",
" a_array = carray(a, (n,))\n",
" b_array = carray(b, (n,))\n",
" out_array = carray(out, (n,))\n",
" \n",
" for i in range(n):\n",
" out_array[i] = a_array[i] * b_array[i]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Compiling /tmp/tmps_jc8s2f/extension.pyx because it changed.\n",
"[1/1] Cythonizing /tmp/tmps_jc8s2f/extension.pyx\n"
]
}
],
"source": [
"numba_func_name = \"numba_func\"\n",
"register_cpu_custom_call_target(numba_func_name, get_caller_ptr(numba_add_n))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"numba_p = core.Primitive(\"numba\") # Create the primitive"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def numba_prim(x, y):\n",
" return numba_p.bind(x, y)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"@numba_p.def_impl\n",
"def numba_impl(x, y):\n",
" return x + y"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"@numba_p.def_abstract_eval\n",
"def numba_abstract_eval(x, y):\n",
" assert x.shape == y.shape\n",
" return abstract_arrays.ShapedArray(x.shape, x.dtype)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def numba_xla_translation(c, xc, yc):\n",
" n_dims = c.GetShape(xc).dimensions()[0]\n",
" return c.CustomCall(numba_func_name,\n",
" operands=(\n",
" xc,\n",
" yc,\n",
" c.ConstantS32Scalar(n_dims)\n",
" ),\n",
" shape_with_layout=c.GetShape(xc),\n",
" operand_shapes_with_layout=(\n",
" c.GetShape(xc),\n",
" c.GetShape(yc),\n",
" xla_client.Shape.array_shape(np.dtype(np.int32), (), ()),\n",
" )\n",
" )\n",
"\n",
"# Now we register the XLA compilation rule with JAX\n",
"xla.backend_specific_translations['cpu'][numba_p] = numba_xla_translation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Execute and observe result"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"a = np.array([1., 2.])\n",
"b = np.array([9., 2.])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/angus/.pyenv/versions/3.8.1/envs/nuclear-phd/lib/python3.8/site-packages/jax/lib/xla_bridge.py:119: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
]
},
{
"data": {
"text/plain": [
"DeviceArray([10., 4.], dtype=float32)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jax.jit(numba_prim)(a, b)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment