Skip to content

Instantly share code, notes, and snippets.

@agoose77
Created February 12, 2020 20:33
Show Gist options
  • Save agoose77/ba52400980b106e110b1f6342e1a8397 to your computer and use it in GitHub Desktop.
Save agoose77/ba52400980b106e110b1f6342e1a8397 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Numba Jax"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Welcome to JupyROOT 6.19/01\n"
]
}
],
"source": [
"import numpy as np\n",
"from jaxlib import xla_client\n",
"import numba\n",
"import ROOT\n",
"import jax\n",
"from jax import core\n",
"from jax import abstract_arrays\n",
"from jax import xla\n",
"from numba import cfunc, types, carray"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext Cython"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%%cython\n",
"from cpython.long cimport PyLong_FromVoidPtr, PyLong_AsVoidPtr\n",
"from libc.stdint cimport int32_t, int64_t\n",
"import ctypes\n",
"\n",
"# C signature of a numba-compiled kernel: f(a, b, out, n).\n",
"ctypedef void (*numba_ptr)(float*, float*, float*, int32_t)\n",
"\n",
"# Kernel currently invoked by call_numba_f; NULL until set_numba_f() runs.\n",
"cdef numba_ptr numba_cfunc = NULL\n",
"# Strong reference to the object owning the code that numba_cfunc points at.\n",
"cdef object numba_cfunc_owner = None\n",
"\n",
"def set_numba_f(numba_f):\n",
"    \"\"\"Install ``numba_f`` as the kernel that ``call_numba_f`` forwards to.\n",
"\n",
"    ``numba_f`` must expose a ``ctypes`` attribute holding a ctypes function\n",
"    object (as numba ``@cfunc`` objects do) whose C signature matches\n",
"    ``numba_ptr``.  A reference to ``numba_f`` is retained so the raw pointer\n",
"    stored here cannot outlive the object it was taken from.\n",
"    \"\"\"\n",
"    global numba_cfunc, numba_cfunc_owner\n",
"    cdef size_t address\n",
"    # Keep the ctypes wrapper in a local so the buffer addressof() points into\n",
"    # is alive while it is dereferenced.\n",
"    c_wrapper = numba_f.ctypes\n",
"    address = ctypes.addressof(c_wrapper)\n",
"    # A ctypes function object stores the function pointer in its buffer.\n",
"    numba_cfunc = (<numba_ptr*>address)[0]\n",
"    numba_cfunc_owner = numba_f\n",
" \n",
"# XLA CPU custom-call entry point: call_numba_f(out, data).\n",
"# `out` is the output buffer and `data` holds one pointer per operand,\n",
"# data = [a, b, n] with n an int32 scalar; the call is forwarded as f(a, b, out, n).\n",
"# The explicit `void` return type matters: a cdef function declared without one\n",
"# returns a Python object, which does not match the C signature XLA calls through.\n",
"cdef void call_numba_f(void* out, void** data):\n",
"    numba_cfunc(\n",
"        <float*>(data[0]),\n",
"        <float*>(data[1]),\n",
"        <float*>out,\n",
"        (<int32_t*>(data[2]))[0]\n",
"    )\n",
" \n",
" \n",
"def get_call_numba_f():\n",
"    \"\"\"Return the address of the ``call_numba_f`` trampoline as a Python int.\"\"\"\n",
"    cdef void* trampoline = <void*>&call_numba_f\n",
"    return PyLong_FromVoidPtr(trampoline)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Register the trampoline as an XLA CPU custom-call target"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"%%cython\n",
"from cpython.pycapsule cimport PyCapsule_New\n",
"from cpython.long cimport PyLong_AsVoidPtr\n",
"from jaxlib import xla_client\n",
"\n",
"\n",
"cpdef register_cpu_custom_call_target(fn_name, ptr):\n",
"    \"\"\"Register a native function with XLA as a CPU custom-call target.\n",
"\n",
"    fn_name: name under which the target is registered.\n",
"    ptr: address of the native function as a Python int.  It is taken as a\n",
"        Python object and converted with PyLong_AsVoidPtr, so the full pointer\n",
"        width is preserved even where C ``long`` is narrower than a pointer.\n",
"    \"\"\"\n",
"    cdef void* fn = PyLong_AsVoidPtr(ptr)\n",
"    # Capsule name passed to xla_client together with the function pointer.\n",
"    cdef const char* name = \"xla._CUSTOM_CALL_TARGET\"\n",
"    xla_client.register_cpu_custom_call_target(\n",
"        fn_name, PyCapsule_New(fn, name, NULL)\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Name under which the trampoline is registered; the XLA translation rule\n",
"# passes the same name to CustomCall.\n",
"numba_func_name = \"numba_func\"\n",
"register_cpu_custom_call_target(\n",
"    numba_func_name,\n",
"    get_call_numba_f(),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# The JAX primitive that the rules in the following cells are attached to.\n",
"numba_p = core.Primitive(\"numba\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def numba_prim(x, y):\n",
"    \"\"\"Bind the ``numba`` primitive to the operands ``x`` and ``y``.\"\"\"\n",
"    operands = (x, y)\n",
"    return numba_p.bind(*operands)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"@numba_p.def_impl\n",
"def numba_impl(x, y):\n",
"    \"\"\"Eager (non-jit) evaluation of the ``numba`` primitive.\n",
"\n",
"    Evaluates the primitive through its registered XLA translation rule, so\n",
"    eager calls run whichever kernel ``set_numba_f`` installed, like the\n",
"    ``jax.jit`` path.  A hard-coded ``x + y`` would disagree with the jitted\n",
"    result as soon as a non-additive kernel is installed.\n",
"    \"\"\"\n",
"    return xla.apply_primitive(numba_p, x, y)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"@numba_p.def_abstract_eval\n",
"def numba_abstract_eval(x, y):\n",
"    \"\"\"Shape/dtype rule: the result has the shape and dtype of ``x``.\n",
"\n",
"    Raises:\n",
"        TypeError: if the operand shapes differ, or if either operand is not\n",
"            float32 (the native trampoline reads every buffer as ``float*``).\n",
"    \"\"\"\n",
"    # Explicit checks instead of `assert`, which is stripped under `python -O`.\n",
"    if x.shape != y.shape:\n",
"        raise TypeError(\n",
"            f\"numba primitive needs operands of equal shape, got {x.shape} and {y.shape}\"\n",
"        )\n",
"    if x.dtype != np.float32 or y.dtype != np.float32:\n",
"        raise TypeError(\n",
"            f\"numba primitive needs float32 operands, got {x.dtype} and {y.dtype}\"\n",
"        )\n",
"    return abstract_arrays.ShapedArray(x.shape, x.dtype)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def numba_xla_translation(c, xc, yc):\n",
"    \"\"\"Lower the ``numba`` primitive to an XLA CustomCall on the CPU backend.\n",
"\n",
"    c: the XLA computation builder; xc, yc: the two operand ops.\n",
"    The custom call receives ``(x, y, n)`` where ``n`` is an int32 scalar\n",
"    holding the total number of elements, and produces an array shaped like ``x``.\n",
"    \"\"\"\n",
"    x_shape = c.GetShape(xc)\n",
"    # Use the product of all dimensions rather than only the leading one, so\n",
"    # arrays of any rank (including rank 0, where the product is 1) are fully\n",
"    # processed.  NOTE(review): treating the buffers as flat assumes x, y and\n",
"    # the output share the same dense layout - confirm for non-default layouts.\n",
"    n_elements = int(np.prod(x_shape.dimensions(), dtype=np.int64))\n",
"    return c.CustomCall(numba_func_name,\n",
"        operands=(\n",
"            xc,\n",
"            yc,\n",
"            c.ConstantS32Scalar(n_elements)\n",
"        ),\n",
"        shape_with_layout=x_shape,\n",
"        operand_shapes_with_layout=(\n",
"            x_shape,\n",
"            c.GetShape(yc),\n",
"            xla_client.Shape.array_shape(np.dtype(np.int32), (), ()),\n",
"        )\n",
"    )\n",
"\n",
"# Hook the translation rule into JAX's table of CPU-specific lowerings\n",
"xla.backend_specific_translations['cpu'].update({numba_p: numba_xla_translation})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create numba function"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# C signature shared by the kernels below:\n",
"#   void f(float* a, float* b, float* out, int n)\n",
"c_sig = types.void(\n",
"    *(3 * [types.CPointer(types.float32)]),\n",
"    types.intc,\n",
")\n",
"\n",
"@cfunc(c_sig)\n",
"def numba_add_n(a, b, out, n):\n",
"    \"\"\"Element-wise sum: out[i] = a[i] + b[i] for every i in [0, n).\"\"\"\n",
"    # View the raw pointers as 1-D arrays of length n.\n",
"    lhs = carray(a, (n,))\n",
"    rhs = carray(b, (n,))\n",
"    result = carray(out, (n,))\n",
"\n",
"    for idx in range(n):\n",
"        result[idx] = lhs[idx] + rhs[idx]\n",
"\n",
"@cfunc(c_sig)\n",
"def numba_mul_n(a, b, out, n):\n",
"    \"\"\"Element-wise product: out[i] = a[i] * b[i] for every i in [0, n).\"\"\"\n",
"    # View the raw pointers as 1-D arrays of length n.\n",
"    lhs = carray(a, (n,))\n",
"    rhs = carray(b, (n,))\n",
"    result = carray(out, (n,))\n",
"\n",
"    for idx in range(n):\n",
"        result[idx] = lhs[idx] * rhs[idx]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Register `add` global callable"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Point the trampoline's global function pointer at the addition kernel.\n",
"set_numba_f(numba_add_n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Execute and observe result"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# float32 matches the `float*` buffers the kernels operate on, instead of\n",
"# relying on float64 inputs being implicitly downcast by JAX.\n",
"a = np.array([1., 2.], dtype=np.float32)\n",
"b = np.array([9., 2.], dtype=np.float32)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/angus/.pyenv/versions/3.8.1/envs/nuclear-phd/lib/python3.8/site-packages/jax/lib/xla_bridge.py:119: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
]
},
{
"data": {
"text/plain": [
"DeviceArray([10., 4.], dtype=float32)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numba_jit = jax.jit(numba_prim)\n",
"numba_jit(a, b)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Register `mul` global callable"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Swap the trampoline's global function pointer over to the multiplication kernel.\n",
"set_numba_f(numba_mul_n)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([9., 4.], dtype=float32)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The trampoline reads its global function pointer on every call, so the\n",
"# kernel installed last is the one that runs here.\n",
"numba_jit = jax.jit(numba_prim)\n",
"numba_jit(a, b)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment