Created
February 12, 2020 20:33
-
-
Save agoose77/ba52400980b106e110b1f6342e1a8397 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Numba Jax" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Welcome to JupyROOT 6.19/01\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"from jaxlib import xla_client\n", | |
"import numba\n", | |
"import ROOT\n", | |
"import jax\n", | |
"from jax import core\n", | |
"from jax import abstract_arrays\n", | |
"from jax import xla\n", | |
"from numba import cfunc, types, carray" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%load_ext Cython" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%cython\n", | |
"from cpython.long cimport PyLong_FromVoidPtr, PyLong_AsVoidPtr\n", | |
"from libc.stdint cimport int32_t, int64_t\n", | |
"import ctypes\n", | |
"\n", | |
"ctypedef void (*numba_ptr)(float*, float*, float*, int32_t)\n", | |
"\n", | |
"cdef numba_ptr numba_cfunc;\n", | |
"\n", | |
"def set_numba_f(numba_f):\n", | |
" global numba_cfunc\n", | |
" cdef size_t address = ctypes.addressof(numba_f.ctypes)\n", | |
" numba_cfunc = (<numba_ptr*>address)[0]\n", | |
" \n", | |
"# call_numba_f(out, data) where data = [a, b, n, *f] which gives f(a,b,out,n)\n", | |
"cdef call_numba_f(void* out, void** data): \n", | |
" global numba_cfunc\n", | |
" numba_cfunc(\n", | |
" <float*>(data[0]), \n", | |
" <float*>(data[1]), \n", | |
" <float*>out, \n", | |
" (<int32_t*>(data[2]))[0]\n", | |
" )\n", | |
" \n", | |
" \n", | |
"def get_call_numba_f():\n", | |
" return PyLong_FromVoidPtr(&call_numba_f)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Store pointer to numba cfunc with (x, y, out, n) signature" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%cython\n", | |
"from cpython.pycapsule cimport PyCapsule_New\n", | |
"from cpython.long cimport PyLong_AsVoidPtr\n", | |
"from jaxlib import xla_client\n", | |
"\n", | |
"\n", | |
"cpdef register_cpu_custom_call_target(fn_name, long ptr):\n", | |
" cdef void* fn = PyLong_AsVoidPtr(ptr);\n", | |
" cdef const char* name = \"xla._CUSTOM_CALL_TARGET\"\n", | |
" xla_client.register_cpu_custom_call_target(\n", | |
" fn_name, PyCapsule_New(fn, name, NULL)\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"numba_func_name = \"numba_func\"\n", | |
"register_cpu_custom_call_target(numba_func_name, get_call_numba_f())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"numba_p = core.Primitive(\"numba\") # Create the primitive" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def numba_prim(x, y):\n", | |
" return numba_p.bind(x, y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@numba_p.def_impl\n", | |
"def numba_impl(x, y):\n", | |
" return x + y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@numba_p.def_abstract_eval\n", | |
"def numba_abstract_eval(x, y):\n", | |
" assert x.shape == y.shape\n", | |
" return abstract_arrays.ShapedArray(x.shape, x.dtype)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def numba_xla_translation(c, xc, yc):\n", | |
" n_dims = c.GetShape(xc).dimensions()[0]\n", | |
" return c.CustomCall(numba_func_name,\n", | |
" operands=(\n", | |
" xc,\n", | |
" yc,\n", | |
" c.ConstantS32Scalar(n_dims)\n", | |
" ),\n", | |
" shape_with_layout=c.GetShape(xc),\n", | |
" operand_shapes_with_layout=(\n", | |
" c.GetShape(xc),\n", | |
" c.GetShape(yc),\n", | |
" xla_client.Shape.array_shape(np.dtype(np.int32), (), ()),\n", | |
" )\n", | |
" )\n", | |
"\n", | |
"# Now we register the XLA compilation rule with JAX\n", | |
"xla.backend_specific_translations['cpu'][numba_p] = numba_xla_translation" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Create numba function" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"c_sig = types.void(types.CPointer(types.float32),\n", | |
" types.CPointer(types.float32),\n", | |
" types.CPointer(types.float32),\n", | |
" types.intc)\n", | |
"\n", | |
"@cfunc(c_sig)\n", | |
"def numba_add_n(a, b, out, n):\n", | |
" a_array = carray(a, (n,))\n", | |
" b_array = carray(b, (n,))\n", | |
" out_array = carray(out, (n,))\n", | |
" \n", | |
" for i in range(n):\n", | |
" out_array[i] = a_array[i] + b_array[i]\n", | |
"\n", | |
"@cfunc(c_sig)\n", | |
"def numba_mul_n(a, b, out, n):\n", | |
" a_array = carray(a, (n,))\n", | |
" b_array = carray(b, (n,))\n", | |
" out_array = carray(out, (n,))\n", | |
" \n", | |
" for i in range(n):\n", | |
" out_array[i] = a_array[i] * b_array[i]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Register `add` global callable" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"set_numba_f(numba_add_n)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Execute and observe result" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"a = np.array([1., 2.])\n", | |
"b = np.array([9., 2.])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/angus/.pyenv/versions/3.8.1/envs/nuclear-phd/lib/python3.8/site-packages/jax/lib/xla_bridge.py:119: UserWarning: No GPU/TPU found, falling back to CPU.\n", | |
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"DeviceArray([10., 4.], dtype=float32)" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"jax.jit(numba_prim)(a, b)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Register `mul` global callable" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"set_numba_f(numba_mul_n)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"DeviceArray([9., 4.], dtype=float32)" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"jax.jit(numba_prim)(a, b)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment