Skip to content

Instantly share code, notes, and snippets.

@eltjpm
Created July 23, 2013 01:00
Show Gist options
  • Save eltjpm/6059053 to your computer and use it in GitHub Desktop.
Save eltjpm/6059053 to your computer and use it in GitHub Desktop.
Index: numba/transforms.py
===================================================================
--- numba/transforms.py (revision 83292)
+++ numba/transforms.py (working copy)
@@ -630,17 +635,16 @@
def visit_MathNode(self, math_node):
"Translate a nodes.MathNode to an intrinsic or libc math call"
- from numba.type_inference.modules import mathmodule
- lowerable = is_math_function([math_node.arg], math_node.py_func)
+ lowerable = is_math_function(math_node.args, math_node.py_func)
if math_node.type.is_array or not lowerable:
# Generate a Python call
assert math_node.py_func is not None
- result = nodes.call_pyfunc(math_node.py_func, [math_node.arg])
+ result = nodes.call_pyfunc(math_node.py_func, math_node.args)
result = result.coerce(math_node.type)
else:
# Lower to intrinsic or libc math call
- args = [math_node.arg], math_node.py_func, math_node.type
+ args = math_node.args, math_node.py_func, math_node.type
if is_intrinsic(math_node.py_func):
result = resolve_intrinsic(*args)
else:
Index: numba/type_inference/modules/mathmodule.py
===================================================================
--- numba/type_inference/modules/mathmodule.py (revision 83292)
+++ numba/type_inference/modules/mathmodule.py (working copy)
@@ -10,11 +10,6 @@
import math
import cmath
-try:
- import __builtin__ as builtins
-except ImportError:
- import builtins
-
import numpy as np
from numba import *
@@ -30,11 +25,15 @@
register_math_typefunc = utils.register_with_argchecking
-def binop_type(context, x, y):
- "Binary result type for math operations"
- x_type = get_type(x)
- y_type = get_type(y)
- return context.promote_types(x_type, y_type)
+def largest_type(default_result_type, types):
+ for type in types:
+#TODO: put back array support
+# if type.is_array and type.dtype.is_int:
+# type = type.copy(dtype=double)
+ if type.is_numeric and type.kind > default_result_type.kind:
+ default_result_type = type
+
+ return default_result_type
#----------------------------------------------------------------------------
# Determine math functions
@@ -42,39 +41,49 @@
# sin(double), sinf(float), sinl(long double)
unary_libc_math_funcs = [
- 'sin',
- 'cos',
- 'tan',
- 'sqrt',
'acos',
+ 'acosh',
'asin',
- 'atan',
- 'atan2',
- 'sinh',
- 'cosh',
- 'tanh',
'asinh',
- 'acosh',
+ 'atan',
'atanh',
- 'log',
- 'log2',
- 'log10',
- 'fabs',
- 'erfc',
'ceil',
+ 'cos',
+ 'cosh',
+ 'erfc',
'exp',
'exp2',
'expm1',
- 'rint',
+ 'fabs',
+ #factorial
+ 'floor',
+ #'isinf', # linux only -- returns bool
+ #'isnan', # -- returns bool
+ 'log',
+ 'log10',
'log1p',
+ 'log2',
+ #radians
+ 'rint',
+ 'round', # linux only
+ 'sin',
+ 'sinh',
+ 'sqrt',
+ 'tan',
+ 'tanh',
+ #'trunc', # linux only -- returns int in python
]
-n_ary_libc_math_funcs = [
+binary_libc_math_funcs = [
+ 'atan2',
+ 'copysign',
+ 'fmod',
+ 'hypot',
+ #ldexp -- int argument
'pow',
- 'round',
]
-all_libc_math_funcs = unary_libc_math_funcs + n_ary_libc_math_funcs
+all_libc_math_funcs = unary_libc_math_funcs + binary_libc_math_funcs
#----------------------------------------------------------------------------
# Math Type Inferers
@@ -82,44 +91,27 @@
# TODO: Move any rewriting parts to lowering phases
-def infer_unary_math_call(context, call_node, arg, default_result_type=double):
+def infer_math_or_cmath_call(default_result_type, context, call_node, *args):
"Resolve calls to math functions to llvm.log.f32() etc"
# signature is a generic signature, build a correct one
- type = get_type(call_node.args[0])
-
- if type.is_numeric and type.kind < default_result_type.kind:
- type = default_result_type
- elif type.is_array and type.dtype.is_int:
- type = type.copy(dtype=double)
-
- # signature = minitypes.FunctionType(return_type=type, args=[type])
- # result = nodes.MathNode(py_func, signature, call_node.args[0])
+ type = largest_type(default_result_type, map(get_type, args))
nodes.annotate(context.env, call_node, is_math=True)
call_node.variable = Variable(type)
return call_node
-def infer_unary_cmath_call(context, call_node, arg):
- result = infer_unary_math_call(context, call_node, arg,
- default_result_type=complex128)
+def infer_math_call(context, call_node, *arg):
+ return infer_math_or_cmath_call(double, context, call_node, *arg)
+
+def infer_cmath_call(context, call_node, *arg):
+ result = infer_math_or_cmath_call(complex128, context, call_node, *arg)
nodes.annotate(context.env, call_node, is_cmath=True)
return result
# ______________________________________________________________________
-# pow()
-
-def pow_(context, call_node, node, power, mod=None):
- dst_type = binop_type(context, node, power)
- call_node.variable = Variable(dst_type)
- return call_node
-
-register_math_typefunc((2, 3), math.pow)
-register_math_typefunc(2, np.power)
-
-# ______________________________________________________________________
-# abs()
+# broken numpy funcs
def abs_(context, node, x):
- import builtinmodule
+ from . import builtinmodule
argtype = get_type(x)
@@ -132,7 +124,9 @@
return builtinmodule.abs_(context, node, x)
-register_math_typefunc(1, np.abs)
+#FIXME: this one is broken, core.issues.test_issue_56 fails
+#register_math_typefunc(1)(abs_, np.abs)
+#register_math_typefunc(2)(infer_binary_math_call, np.power)
#----------------------------------------------------------------------------
# Register Type Functions
@@ -140,23 +134,25 @@
def register_math(nargs, value):
register = register_math_typefunc(nargs)
- register(infer_unary_math_call, value)
+ register(infer_math_call, value)
def register_cmath(nargs, value):
register = register_math_typefunc(nargs)
- register(infer_unary_cmath_call, value)
+ register(infer_cmath_call, value)
def register_typefuncs():
- modules = [builtins, math, cmath, np]
- # print all_libc_math_funcs
- for libc_math_func in unary_libc_math_funcs:
- for module in modules:
- if hasattr(module, libc_math_func):
- if module is cmath:
- register = register_cmath
- else:
- register = register_math
+ modules = [math, cmath, np]
+ for nargs, libc_math_funcs in [(1, unary_libc_math_funcs),
+ (2, binary_libc_math_funcs)]:
+ for libc_math_func in libc_math_funcs:
+ for module in modules:
+ if hasattr(module, libc_math_func):
+ if module is cmath:
+ register = register_cmath
+ else:
+ register = register_math
- register(1, getattr(module, libc_math_func))
+ register(nargs, getattr(module, libc_math_func))
register_typefuncs()
+
Index: numba/specialize/mathcalls.py
===================================================================
--- numba/specialize/mathcalls.py (revision 83292)
+++ numba/specialize/mathcalls.py (working copy)
@@ -38,11 +38,20 @@
is_intrinsic = hasattr(llvm.core, intrinsic_name)
return is_intrinsic
+if is_win32:
+ _MAPPING = {
+ 'abs' : 'fabs',
+ 'hypot' : '_hypot',
+ 'isnan' : '_isnan',
+ 'copysign' : '_copysign',
+ }
+else:
+ _MAPPING = {
+ 'abs' : 'fabs',
+ }
def math_suffix(name, type):
- if name == 'abs':
- name = 'fabs'
-
+ name = _MAPPING.get(name, name)
if type.is_float and type.itemsize == 4:
name += 'f' # sinf(float)
elif type.is_int and type.itemsize == 16:
@@ -61,11 +70,10 @@
return math_suffix(math_name, double) in libc_math_funcs
def is_math_function(func_args, py_func):
- if len(func_args) == 0 or len(func_args) > 1 or py_func is None:
+ if not func_args or py_func is None:
return False
- type = get_type(func_args[0])
-
+ type = mathmodule.largest_type(float_, map(get_type, func_args))
if type.is_array:
type = type.dtype
valid_type = type.is_float or type.is_int or type.is_complex
@@ -104,8 +112,8 @@
def resolve_math_call(call_node, py_func):
"Resolve calls to math functions to llvm.log.f32() etc"
- signature = call_node.type(call_node.type)
- return nodes.MathNode(py_func, signature, call_node.args[0])
+ signature = call_node.type(*[call_node.type] * len(call_node.args))
+ return nodes.MathNode(py_func, signature, call_node.args)
def filter_math_funcs(math_func_names):
if is_win32:
@@ -115,7 +123,8 @@
result_func_names = []
for name in math_func_names:
- if getattr(dll, name, None) is not None:
+ cname = _MAPPING.get(name, name)
+ if getattr(dll, cname, None) is not None:
result_func_names.append(name)
return result_func_names
Index: numba/nodes/callnodes.py
===================================================================
--- numba/nodes/callnodes.py (revision 83292)
+++ numba/nodes/callnodes.py (working copy)
@@ -87,13 +87,13 @@
Represents a high-level call to a math function.
"""
- _fields = ['arg']
+ _fields = ['args']
- def __init__(self, py_func, signature, arg, **kwargs):
+ def __init__(self, py_func, signature, args, **kwargs):
super(MathNode, self).__init__(**kwargs)
self.py_func = py_func
self.signature = signature
- self.arg = arg
+ self.args = args
self.type = signature.return_type
class LLVMExternalFunctionNode(ExprNode):
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment