In [1]: from generated import generated
In [2]: import numba as nb
In [3]: @generated(nopython=True)
   ...: def foo(x, y):
   ...:     if isinstance(x, nb.types.Integer):
   ...:         return lambda x, y: x + y
   ...:     else:
   ...:         return lambda x, y: x - y
   ...:
In [4]: foo(1, 2)
Out[4]: 3
In [5]: foo(1., 2)
Out[5]: -1.0
In [6]: @nb.jit(nopython=True)
   ...: def bar(x, y):
   ...:     return foo(x, y) + 4
   ...:
In [7]: bar(1, 2)
Out[7]: 7
In [8]: bar(1., 2)
Out[8]: 3.0
Created
January 19, 2016 00:43
-
-
Save jcrist/fddcbb12f3c74748aea5 to your computer and use it in GitHub Desktop.
Generated functions in numba
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numba.dispatcher import Dispatcher | |
from numba import compiler, sigutils, jit | |
from numba.targets.registry import CPUTarget, dispatcher_registry | |
from dill import dumps, loads | |
class GeneratedFunction(Dispatcher):
    """Base class for generated functions.

    ``py_func`` is not compiled directly.  When a signature is compiled,
    ``py_func`` is called with the normalized argument types of that
    signature, and the function it returns is what gets handed to the
    compiler.
    """

    def compile(self, sig):
        """Compile the overload for ``sig``.

        Returns the overload already registered for the normalized argument
        types if there is one.  Otherwise the on-disk cache is consulted, and
        only if that misses is ``py_func`` invoked to generate the function
        to compile; in both of those cases the entry point of the compile
        result is returned.

        Raises the compile result's ``typing_error`` when one is present and
        the flags do not enable pyobject mode.
        """
        # TODO: Mostly copied from numba. Look in to how to refactor numba to
        # remove the need to duplicate code.
        with self._compile_lock:
            args, return_type = sigutils.normalize_signature(sig)

            # Nothing to do if these argument types were already compiled.
            known = self.overloads.get(tuple(args))
            if known is not None:
                return known

            # Try the cache before paying for a fresh compilation.
            cached = self._cache.load_overload(sig, self.targetctx)
            if cached is not None:
                if not (cached.objectmode or cached.interpmode):
                    self.targetctx.insert_user_function(
                        cached.entry_point, cached.fndesc, [cached.library])
                self.add_overload(cached)
                return cached.entry_point

            flags = compiler.Flags()
            self.targetdescr.options.parse_as_flags(flags, self.targetoptions)

            # Generate expression: the generator receives the argument
            # *types* and hands back the function to compile for them.
            specialized = self.py_func(*args)
            result = compiler.compile_extra(
                self.typingctx, self.targetctx, specialized,
                args=args, return_type=return_type,
                flags=flags, locals=self.locals)

            if result.typing_error is not None:
                if not flags.enable_pyobject:
                    raise result.typing_error

            self.add_overload(result)
            self._cache.save_overload(sig, result)
            return result.entry_point

    def __reduce__(self):
        """Describe how to rebuild this dispatcher via ``_rebuild``.

        The generator function is serialized with ``dumps(..., recurse=True)``.
        Signatures are only carried along when ``_can_compile`` is false;
        otherwise an empty list is sent.
        """
        sigs = ([] if self._can_compile
                else [cr.signature for cr in self._compileinfos.values()])
        state = (self.__class__, dumps(self.py_func, recurse=True),
                 self.locals, self.targetoptions, self._can_compile, sigs)
        return (_rebuild, state)
def _rebuild(cls, func_reduced, locals, targetoptions, can_compile, sigs):
    """Rebuild a Dispatcher instance after it was __reduce__'d.

    ``func_reduced`` is deserialized with ``loads`` and used to construct a
    new ``cls`` instance.  Every signature in ``sigs`` is compiled before
    ``_can_compile`` is set to ``can_compile`` on the new instance.
    """
    dispatcher = cls(loads(func_reduced), locals, targetoptions)
    # Compile first, then apply the flag: the flag is assigned only after
    # all signatures have been compiled, matching the original ordering.
    for signature in sigs:
        dispatcher.compile(signature)
    dispatcher._can_compile = can_compile
    return dispatcher
class GeneratedCPUDispatcher(GeneratedFunction):
    """Generated-function dispatcher bound to a ``CPUTarget`` descriptor."""

    # Target descriptor consulted by ``compile`` for option parsing.
    targetdescr = CPUTarget()
def generated(*args, **kwargs):
    """Decorator front-end for generated functions.

    Rewrites the ``target`` keyword to ``'gen_' + target`` (``target``
    defaults to ``'cpu'``) and forwards all arguments to ``jit``.
    """
    base_target = kwargs.get('target', 'cpu')
    kwargs['target'] = 'gen_' + base_target
    return jit(*args, **kwargs)
# Register the generated-function dispatcher under the 'gen_cpu' target name,
# which is the name ``generated`` passes to ``jit`` by default.
dispatcher_registry["gen_cpu"] = GeneratedCPUDispatcher
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment