Skip to content

Instantly share code, notes, and snippets.

@jcrist
Created January 19, 2016 00:43
Show Gist options
  • Save jcrist/fddcbb12f3c74748aea5 to your computer and use it in GitHub Desktop.
Generated functions in Numba: an example IPython session, followed by the implementation of the `generated` module it imports.
In [1]: from generated import generated

In [2]: import numba as nb

In [3]: @generated(nopython=True)
   ...: def foo(x, y):
   ...:     if isinstance(x, nb.types.Integer):
   ...:         return lambda x, y: x + y
   ...:     else:
   ...:         return lambda x, y: x - y
   ...:

In [4]: foo(1, 2)
Out[4]: 3

In [5]: foo(1., 2)
Out[5]: -1.0

In [6]: @nb.jit(nopython=True)
   ...: def bar(x, y):
   ...:     return foo(x, y) + 4
   ...:

In [7]: bar(1, 2)
Out[7]: 7

In [8]: bar(1., 2)
Out[8]: 3.0
from numba.dispatcher import Dispatcher
from numba import compiler, sigutils, jit
from numba.targets.registry import CPUTarget, dispatcher_registry
from dill import dumps, loads
class GeneratedFunction(Dispatcher):
    """Base class for generated functions.

    The wrapped ``py_func`` is not compiled itself. For every new signature
    it is called with the numba types of the arguments, and the Python
    function it returns is what actually gets compiled for that signature.
    """

    def compile(self, sig):
        """Compile (or reuse) the specialization for signature *sig*.

        Parameters
        ----------
        sig
            Any signature form accepted by ``sigutils.normalize_signature``.

        Returns
        -------
        The entry point of the compiled specialization, or the value already
        stored in ``self.overloads`` for these argument types.

        Raises
        ------
        TypeError
            If the generator (``self.py_func``) does not return a callable
            implementation for the given argument types.
        """
        # TODO: Mostly copied from numba's Dispatcher.compile. Look into how
        # to refactor numba to remove the need to duplicate code.
        with self._compile_lock:
            args, return_type = sigutils.normalize_signature(sig)

            # Don't recompile argument types that were already compiled.
            # NOTE(review): the stored value is returned as-is, mirroring the
            # numba Dispatcher.compile this was copied from; confirm what
            # ``overloads`` stores when upgrading numba.
            existing = self.overloads.get(tuple(args))
            if existing is not None:
                return existing

            # Try the dispatcher's cache before compiling anything.
            cres = self._cache.load_overload(sig, self.targetctx)
            if cres is not None:
                if not cres.objectmode and not cres.interpmode:
                    self.targetctx.insert_user_function(cres.entry_point,
                                                        cres.fndesc,
                                                        [cres.library])
                self.add_overload(cres)
                return cres.entry_point

            flags = compiler.Flags()
            self.targetdescr.options.parse_as_flags(flags, self.targetoptions)

            # Call the generator with the argument *types* to obtain the
            # implementation for this signature.
            impl = self.py_func(*args)
            if not callable(impl):
                # Fail early with a clear message instead of an opaque error
                # from deep inside the numba compiler (e.g. when the generator
                # forgot to return anything for these types).
                raise TypeError(
                    "generated function %r must return a function to compile "
                    "for argument types %s, got %r"
                    % (getattr(self.py_func, '__name__', self.py_func),
                       args, impl))

            cres = compiler.compile_extra(self.typingctx, self.targetctx, impl,
                                          args=args, return_type=return_type,
                                          flags=flags, locals=self.locals)
            # Typing errors are fatal unless the flags allow falling back to
            # Python objects (``enable_pyobject``).
            if cres.typing_error is not None and not flags.enable_pyobject:
                raise cres.typing_error

            self.add_overload(cres)
            self._cache.save_overload(sig, cres)
            return cres.entry_point

    def __reduce__(self):
        """Reduce the dispatcher for pickling.

        The generator ``py_func`` is serialized with dill (``recurse=True``)
        together with ``locals`` and ``targetoptions``. When compilation is
        still allowed no signatures are recorded; otherwise the signatures
        in ``_compileinfos`` are recorded so ``_rebuild`` can recompile them.
        """
        if self._can_compile:
            sigs = []
        else:
            sigs = [cr.signature for cr in self._compileinfos.values()]
        return (_rebuild, (self.__class__, dumps(self.py_func, recurse=True),
                           self.locals, self.targetoptions, self._can_compile,
                           sigs))
def _rebuild(cls, func_reduced, locals, targetoptions, can_compile, sigs):
    """Reconstruct a dispatcher that was pickled via ``__reduce__``.

    *func_reduced* is the dill-serialized generator function. A fresh
    instance of *cls* is built from it. Every signature in *sigs* is then
    recompiled, and only afterwards is the original ``_can_compile``
    setting restored.
    """
    dispatcher = cls(loads(func_reduced), locals, targetoptions)
    # Recompile first: the compile permission is restored only afterwards.
    for signature in sigs:
        dispatcher.compile(signature)
    dispatcher._can_compile = can_compile
    return dispatcher
class GeneratedCPUDispatcher(GeneratedFunction):
    """Generated-function dispatcher that targets the CPU.

    Identical to ``GeneratedFunction`` except that it uses numba's
    ``CPUTarget`` as its target descriptor.
    """

    # Target descriptor used by ``compile`` (``targetdescr.options``).
    targetdescr = CPUTarget()
def generated(*args, **kwargs):
    """Decorator like ``numba.jit`` that creates a generated function.

    At compile time the decorated function is called with the numba types
    of the arguments and must return the Python function to compile for
    those types. All arguments are forwarded to ``jit``. The ``target``
    keyword (default ``'cpu'``) is replaced by its ``'gen_'``-prefixed name,
    so a dispatcher must be registered under that name; this module
    registers only ``'gen_cpu'``.

    Example::

        @generated(nopython=True)
        def foo(x, y):
            if isinstance(x, nb.types.Integer):
                return lambda x, y: x + y
            else:
                return lambda x, y: x - y
    """
    requested_target = kwargs.get('target', 'cpu')
    kwargs['target'] = 'gen_' + requested_target
    return jit(*args, **kwargs)
# Register the CPU generated-function dispatcher under the target name that
# ``generated`` produces for the default target ('gen_' + 'cpu').
dispatcher_registry['gen_cpu'] = GeneratedCPUDispatcher
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment