@luk-f-a
Last active January 17, 2020 09:28
Class overload proposal for numba-scipy
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Current state\n",
"\n",
"Overloading a Python class currently requires several steps that deal directly with typing, lowering, boxing and unboxing.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Creating a new Numba type**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from numba import types\n",
"\n",
"class IntervalType(types.Type):\n",
"    def __init__(self):\n",
"        super(IntervalType, self).__init__(name='Interval')\n",
"\n",
"interval_type = IntervalType()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Type inference for Python values**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from numba.extending import typeof_impl\n",
"\n",
"@typeof_impl.register(Interval)\n",
"def typeof_index(val, c):\n",
"    return interval_type"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Type inference for operations**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from numba.extending import type_callable\n",
"\n",
"@type_callable(Interval)\n",
"def type_interval(context):\n",
"    def typer(lo, hi):\n",
"        if isinstance(lo, types.Float) and isinstance(hi, types.Float):\n",
"            return interval_type\n",
"    return typer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Defining the data model**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from numba.extending import models, register_model\n",
"\n",
"@register_model(IntervalType)\n",
"class IntervalModel(models.StructModel):\n",
"    def __init__(self, dmm, fe_type):\n",
"        members = [\n",
"            ('lo', types.float64),\n",
"            ('hi', types.float64),\n",
"        ]\n",
"        models.StructModel.__init__(self, dmm, fe_type, members)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exposing data model attributes**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from numba.extending import make_attribute_wrapper\n",
"\n",
"make_attribute_wrapper(IntervalType, 'lo', 'lo')\n",
"make_attribute_wrapper(IntervalType, 'hi', 'hi')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exposing a property**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from numba.extending import overload_attribute\n",
"\n",
"@overload_attribute(IntervalType, \"width\")\n",
"def get_width(interval):\n",
"    def getter(interval):\n",
"        return interval.hi - interval.lo\n",
"    return getter"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Implementing the constructor**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Import paths for Numba < 0.49; later releases moved cgutils to numba.core\n",
"from numba import cgutils\n",
"from numba.extending import lower_builtin\n",
"\n",
"@lower_builtin(Interval, types.Float, types.Float)\n",
"def impl_interval(context, builder, sig, args):\n",
"    typ = sig.return_type\n",
"    lo, hi = args\n",
"    interval = cgutils.create_struct_proxy(typ)(context, builder)\n",
"    interval.lo = lo\n",
"    interval.hi = hi\n",
"    return interval._getvalue()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Problem\n",
"\n",
"These steps require advanced knowledge of Numba. `numba-scipy` needs hundreds of overloads, so we need to make writing them more accessible\n",
"in order to reach a wider pool of contributors."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Proposal using Jitclass"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import numba\n",
"from numba import types\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Suppose there is a Python class that we want to \"overload\", i.e. provide a version that runs in jitted code, so that references to the class inside jitted code transparently resolve to that version."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class Interval(object):\n",
"    \"\"\"\n",
"    A half-open interval on the real number line.\n",
"    \"\"\"\n",
"    def __init__(self, lo, hi):\n",
"        self.lo = lo\n",
"        self.hi = hi\n",
"\n",
"    def __repr__(self):\n",
"        return 'Interval(%f, %f)' % (self.lo, self.hi)\n",
"\n",
"    @property\n",
"    def width(self):\n",
"        return self.hi - self.lo"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**First step**: the user would provide a jitclass that has the desired behaviour.\n",
"\n",
"In this case it is identical to the original, but for more complex objects it could be a subset of the original pure-Python implementation."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"@numba.jitclass(spec=[\n",
"    ('lo', types.float64),\n",
"    ('hi', types.float64),\n",
"])\n",
"class IntervalJit(object):\n",
"    \"\"\"\n",
"    A half-open interval on the real number line.\n",
"    \"\"\"\n",
"    def __init__(self, lo, hi):\n",
"        self.lo = lo\n",
"        self.hi = hi\n",
"\n",
"    def __repr__(self):\n",
"        return 'Interval(%f, %f)' % (self.lo, self.hi)\n",
"\n",
"    @property\n",
"    def width(self):\n",
"        return self.hi - self.lo"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Second step**: declare the jitclass as overloading the Python class.\n",
"\n",
"`overload_pyclass` would be provided by `numba.extending` or `numba_scipy.extending`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"overload_pyclass(Interval, IntervalJit)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Third step**: None. Sit back and enjoy your overloaded `Interval`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Additions to** `numba.extending` or `numba_scipy.extending` **(everything below is in a very alpha state)**.\n",
"\n",
"Not everything is new code; some are existing functions that I have copied into the notebook."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def overload_pyclass1(pyclass, jitclass):\n",
"    from numba.extending import typeof_impl\n",
"    @typeof_impl.register(pyclass)\n",
"    def typeof_index(val, c):\n",
"        return jitclass.class_type.instance_type\n",
"\n",
"\n",
"def overload_pyclass2(pyclass, jitclass):\n",
"    from numba.targets.registry import cpu_target\n",
"    # Register resolution of the class object\n",
"    typingctx = cpu_target.typing_context\n",
"    typingctx.insert_global(pyclass, jitclass.class_type)\n",
"\n",
"\n",
"def _add_linking_libs(context, call):\n",
"    \"\"\"\n",
"    Add the required libs for the callable to allow inlining.\n",
"    \"\"\"\n",
"    libs = getattr(call, \"libs\", ())\n",
"    if libs:\n",
"        context.add_linking_libs(libs)\n",
"\n",
"\n",
"def imp_dtor(context, module, instance_type):\n",
"    from llvmlite import ir as llvmir\n",
"    llvoidptr = context.get_value_type(types.voidptr)\n",
"    llsize = context.get_value_type(types.uintp)\n",
"    dtor_ftype = llvmir.FunctionType(llvmir.VoidType(),\n",
"                                     [llvoidptr, llsize, llvoidptr])\n",
"\n",
"    fname = \"_Dtor.{0}\".format(instance_type.name)\n",
"    dtor_fn = module.get_or_insert_function(dtor_ftype,\n",
"                                            name=fname)\n",
"    if dtor_fn.is_declaration:\n",
"        # Define\n",
"        builder = llvmir.IRBuilder(dtor_fn.append_basic_block())\n",
"\n",
"        alloc_fe_type = instance_type.get_data_type()\n",
"        alloc_type = context.get_value_type(alloc_fe_type)\n",
"\n",
"        ptr = builder.bitcast(dtor_fn.args[0], alloc_type.as_pointer())\n",
"        data = context.make_helper(builder, alloc_fe_type, ref=ptr)\n",
"\n",
"        context.nrt.decref(builder, alloc_fe_type, data._getvalue())\n",
"\n",
"        builder.ret_void()\n",
"\n",
"    return dtor_fn\n",
"\n",
"\n",
"def unbox_pyclass4(pyclass, jitclass):\n",
"    from numba.extending import unbox, NativeValue\n",
"    from numba import cgutils\n",
"    from numba.pythonapi import _unboxers\n",
"    del _unboxers.functions[types.ClassInstanceType]\n",
"\n",
"    @unbox(types.ClassInstanceType)\n",
"    def unbox_interval(typ, obj, c):\n",
"        \"\"\"\n",
"        Convert an Interval object to a native interval structure.\n",
"        \"\"\"\n",
"        obj_list = []\n",
"        type_inst_list = []\n",
"        for attr_name, attr_typ in typ.struct.items():\n",
"            obj_list.append(c.pyapi.object_getattr_string(obj, attr_name))\n",
"            type_inst_list.append(attr_typ)\n",
"\n",
"        type_inst_list = tuple(type_inst_list)\n",
"\n",
"        # Allocate the instance\n",
"        inst_typ = typ\n",
"        context = c.context\n",
"        builder = c.builder\n",
"        alloc_type = context.get_data_type(inst_typ.get_data_type())\n",
"        alloc_size = context.get_abi_sizeof(alloc_type)\n",
"\n",
"        meminfo = context.nrt.meminfo_alloc_dtor(\n",
"            builder,\n",
"            context.get_constant(types.uintp, alloc_size),\n",
"            imp_dtor(context, builder.module, inst_typ),\n",
"        )\n",
"        data_pointer = context.nrt.meminfo_data(builder, meminfo)\n",
"        data_pointer = builder.bitcast(data_pointer,\n",
"                                       alloc_type.as_pointer())\n",
"\n",
"        # Nullify all data\n",
"        builder.store(cgutils.get_null_value(alloc_type),\n",
"                      data_pointer)\n",
"\n",
"        inst_struct = context.make_helper(builder, inst_typ)\n",
"        inst_struct.meminfo = meminfo\n",
"        inst_struct.data = data_pointer\n",
"\n",
"        # TODO: fill attributes with actual values\n",
"        # IDEA: instead of doing this automatically, call an __unbox__ method in the jitclass\n",
"        # to allow user customization\n",
"\n",
"        # Prepare return value\n",
"        ret = inst_struct._getvalue()\n",
"\n",
"        return NativeValue(ret, is_error=c.pyapi.c_api_error())\n",
"\n",
"\n",
"def overload_pyclass(pyclass, jitclass):\n",
"    overload_pyclass1(pyclass, jitclass)\n",
"    overload_pyclass2(pyclass, jitclass)\n",
"    unbox_pyclass4(pyclass, jitclass)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tests"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Create Test1: True\n",
"Create Test2: True\n",
"Unbox Test1 \n",
" \tUnboxing: True \n",
" \tMember transfer: False\n",
"Unbox Test2 \n",
" \tUnboxing: True \n",
" \tMember transfer: False\n",
"Property Test1 \n",
" \tUnboxing: True \n",
" \tProperty calculation: False\n",
"Box Test1 \n",
" \tBoxing: False\n"
]
}
],
"source": [
"from numba import njit\n",
"\n",
"@njit\n",
"def create_interval1():\n",
"    a = Interval(2.1, 3.1)\n",
"    return a.lo, a.width\n",
"\n",
"temp = create_interval1()\n",
"print(\"Create Test1: \", temp==(2.1, 1.0))\n",
"\n",
"@njit\n",
"def create_interval2(i, j):\n",
"    a = Interval(i, j)\n",
"    return a.lo, a.width\n",
"\n",
"\n",
"temp = create_interval2(4.1,5.6)\n",
"print(\"Create Test2: \", temp==(4.1, 1.5))\n",
"\n",
"\n",
"inter = Interval(2.1, 3.1)\n",
"\n",
"@njit\n",
"def inside_interval1(interval):\n",
"    return interval.lo\n",
"\n",
"temp = inside_interval1(inter)\n",
"print('Unbox Test1', \"\\n \\tUnboxing: True\", \"\\n \\tMember transfer: \" + str(temp==2.1))\n",
"\n",
"@njit\n",
"def inside_interval2(interval, x):\n",
"    return interval.lo <= x < interval.hi\n",
"\n",
"temp = inside_interval2(inter, 2.5)\n",
"print('Unbox Test2', \"\\n \\tUnboxing: True\", \"\\n \\tMember transfer: \" + str(temp))\n",
"\n",
"\n",
"@njit\n",
"def interval_width(interval):\n",
"    return interval.width\n",
"\n",
"temp = interval_width(inter)\n",
"print('Property Test1', \"\\n \\tUnboxing: True\", \"\\n \\tProperty calculation: \" + str(temp==1.0))\n",
"\n",
"\n",
"@njit\n",
"def sum_intervals(i, j):\n",
"    #return Interval(i.lo + j.lo, i.hi + j.hi)\n",
"    a = Interval(i.lo + j.lo, i.hi + j.hi)\n",
"    return \"success\"\n",
"\n",
"\n",
"try:\n",
"    temp = sum_intervals(inter, inter)\n",
"except Exception:\n",
"    print('Box Test1', \"\\n \\tBoxing: False\")\n",
"else:\n",
"    print('Box Test1', \"\\n \\tBoxing: True\", \"\\n \\tAttributes: \" + str(temp.lo==4.2) + \" \" + str(temp.hi==6.2))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (latest_versions)",
"language": "python",
"name": "latest_versions"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}