Created February 12, 2016 07:13
-
-
Save MatthieuDartiailh/baf60d17fc5d7ac102a6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
from numba import jitclass, float64, jit | |
from funcsigs import signature | |
from future.utils import exec_ | |
class Integrand(object):
    """Common base class of the integrand classes generated from the template.

    It carries no state or behaviour of its own; the generated ``Inte``
    classes simply derive from it.
    """
# Source template of the generated integrand class; it is filled in with
# str.format and then exec'd. Placeholders:
#   {args}      -- comma-separated names of the stored (non-integration)
#                  arguments, used as the constructor parameters;
#   {init}      -- constructor body; it sits on an 8-space indented line, so
#                  any continuation lines must carry that same indent;
#   {name}      -- name, in the exec namespace, of the function to call;
#   {func_args} -- argument list passed to that function, where ``x`` stands
#                  for the integration variable.
integrand_template = \
"""
class Inte(Integrand):

    def __init__(self, {args}):
        {init}

    def eval_(self, x):
        return {name}({func_args})
"""
def create_integrand(func, types=None, arg_index=0):
    """Create an integrand object.

    This object provides a fast evaluation method (eval_), allows to set
    other argument values (avoid rebuilding a new object each time).

    The returned class is a numba jitclass. Its constructor takes every
    argument of ``func`` except the integration variable (in their original
    order) and stores each one as an attribute of the same name. ``eval_(x)``
    calls ``func`` with ``x`` at position ``arg_index`` and the stored
    attributes everywhere else.

    WARNING :
    This function should not be called in the definition file of the function.

    Parameters
    ----------
    func : jitted function
        Jitted function to use as integrand. Its ``py_func`` attribute is
        inspected to discover the argument names.
    types : dict, optional
        Dictionary specifying the type of the function variables. By default
        all arguments are assumed to be float64. Types must be imported from
        numba.
    arg_index : int, optional
        Index of the function argument to use as integration variable.
        Negative values count from the end, as for Python sequences.

    Returns
    -------
    type
        The jitclass described above.

    Raises
    ------
    ValueError
        If ``func`` takes variadic or keyword-only arguments (the generated
        call forwards every argument positionally), or if ``arg_index`` is
        out of range for the signature of ``func``.

    """
    sig = signature(func.py_func)
    for param in sig.parameters.values():
        # eval_ forwards every argument positionally, so only plain
        # positional parameters can be supported.
        if param.kind not in (param.POSITIONAL_ONLY,
                              param.POSITIONAL_OR_KEYWORD):
            raise ValueError('Unsupported parameter {!r}: variadic and '
                             'keyword-only arguments cannot be used in an '
                             'integrand.'.format(param.name))

    names = list(sig.parameters)
    n_args = len(names)
    # Also rejects zero-argument functions (the range is then empty).
    if not -n_args <= arg_index < n_args:
        raise ValueError('arg_index {} is out of range for a function with '
                         '{} argument(s).'.format(arg_index, n_args))
    # Normalise negative indices the same way Python sequences do.
    arg_index %= n_args

    # Every argument except the integration variable becomes an attribute.
    parameters = [name for i, name in enumerate(names) if i != arg_index]

    spec = {name: float64 for name in parameters}
    if types:
        spec.update(types)

    # Statements are joined on a single line so that the generated source is
    # valid whatever the indentation of the {init} placeholder is.
    init = '; '.join('self.{0} = {0}'.format(name) for name in parameters)
    func_args = ', '.join('x' if i == arg_index else 'self.' + name
                          for i, name in enumerate(names))
    body = integrand_template.format(args=', '.join(parameters),
                                     init=init or 'pass',
                                     func_args=func_args,
                                     name='func')

    # The generated class resolves `func` and `Integrand` in this namespace.
    env = {'func': func, 'Integrand': Integrand}
    exec_(body, env)
    return jitclass(spec)(env['Inte'])
if __name__ == '__main__':
    # Small demonstration: ``x`` (first argument, the default arg_index=0) is
    # the integration variable, while ``a`` is stored on the integrand object.
    @jit
    def f(x, a):
        return x + a

    # NOTE(review): create_integrand's docstring warns against calling it in
    # the file that defines the function, which is what this demo does --
    # confirm this is intended.
    integrand_cls = create_integrand(f)
    integrand = integrand_cls(1)  # stores a = 1
    print(integrand.eval_(2))     # evaluates f(2, 1)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.