Skip to content

Instantly share code, notes, and snippets.

@eltjpm
Last active December 20, 2015 00:19
Show Gist options
Save eltjpm/6041224 to your computer and use it in GitHub Desktop.
Index: builtinmodule.py
===================================================================
--- builtinmodule.py (revision 79617)
+++ builtinmodule.py (revision 83202)
@@ -4,11 +4,11 @@
"""
from __future__ import print_function, division, absolute_import
+import ast
from numba import *
from numba import nodes
from numba import error
# from numba import function_util
-# from numba.specialize.mathcalls import is_math_function
from numba.symtab import Variable
from numba import typesystem
from numba.typesystem import is_obj, promote_closest, get_type
@@ -27,6 +27,9 @@
else:
return nodes.CoercionNode(node.args[0], dst_type=dst_type)
+def binop_type(context, x, y):  # result type of a binary op: promote the types of x and y
+    return context.promote_types(get_type(x), get_type(y))
+
#----------------------------------------------------------------------------
# Type Functions for Builtins
#----------------------------------------------------------------------------
@@ -75,6 +78,10 @@
def _float(context, node, x):
return cast(node, double)
+@register_builtin((0, 1), can_handle_deferred_types=True)
+def _bool(context, node, x):  # bool(...) is typed through cast(), like _float above
+    return cast(node, bool_)
+
@register_builtin((0, 1, 2), can_handle_deferred_types=True)
def complex_(context, node, a, b):
if len(node.args) == 2:
@@ -100,12 +107,11 @@
@register_builtin((2, 3))
def pow_(context, node, base, exponent, mod):
-    from . import mathmodule
-    return mathmodule.pow_(context, node, base, exponent)
+    node.variable = Variable(binop_type(context, base, exponent))  # mod does not affect the type
+    return node  # NOTE(review): Python's pow(int, negative int) is a float; confirm promote_types covers it
@register_builtin((1, 2))
def round_(context, node, number, ndigits):
- # is_math = is_math_function(node.args, round)
argtype = get_type(number)
if len(node.args) == 1 and argtype.is_int:
@@ -121,6 +127,80 @@
node.variable = Variable(dst_type)
return node # nodes.CoercionNode(node, double)

+def _has_unsupported_call_args(node):
+    """
+    Return True when the call carries keyword arguments (e.g. ``key=``)
+    or, on ASTs that have these fields, ``starargs``/``kwargs``. The
+    comparison chain built by minmax() only models plain positional
+    operands, so such calls are not inlined.
+    """
+    return bool(getattr(node, 'keywords', None) or
+                getattr(node, 'starargs', None) or
+                getattr(node, 'kwargs', None))
+
+def minmax(context, args, op):
+    """
+    Build the expression node for min(a, b, ...) / max(a, b, ...).
+
+    The operands are folded left to right. At each step the running result
+    and the next operand are coerced to their promoted type and stored once
+    in temporaries, and the running result is replaced by the operand only
+    when ``operand <op> result`` holds (``op`` is ``ast.Gt()`` for max and
+    ``ast.Lt()`` for min). That is the comparison CPython's builtins make,
+    so on ties the earlier operand wins (``max(0.0, -0.0)`` is ``0.0``) and
+    a false comparison against NaN keeps the running result
+    (``max(float('nan'), 1.0)`` is ``nan``).
+
+    Returns None when fewer than two operands are given (the single-iterable
+    form). NOTE(review): the caller presumably then compiles an ordinary
+    builtin call; confirm against register_builtin's dispatcher.
+    """
+    if len(args) < 2:
+        return None
+
+    res = args[0]
+    for arg in args[1:]:
+        lhs_type = get_type(res)
+        rhs_type = get_type(arg)
+        res_type = context.promote_types(lhs_type, rhs_type)
+        if lhs_type != res_type:
+            res = nodes.CoercionNode(res, res_type)
+        if rhs_type != res_type:
+            arg = nodes.CoercionNode(arg, res_type)
+
+        # Store both operands in temporaries so each is evaluated once; the
+        # invariant loads are shared by the comparison and the chosen branch.
+        lhs_temp = nodes.TempNode(res_type)
+        rhs_temp = nodes.TempNode(res_type)
+        res_temp = nodes.TempNode(res_type)
+        lhs = lhs_temp.load(invariant=True)
+        rhs = rhs_temp.load(invariant=True)
+        # rhs if rhs <op> lhs else lhs: the new operand must strictly win to
+        # replace the running result (CPython's tie and NaN behaviour).
+        expr = ast.IfExp(ast.Compare(rhs, [op], [lhs]), rhs, lhs)
+        body = [
+            ast.Assign([lhs_temp.store()], res),
+            ast.Assign([rhs_temp.store()], arg),
+            ast.Assign([res_temp.store()], expr),
+        ]
+        res = nodes.ExpressionNode(body, res_temp.load(invariant=True))
+
+    return res
+
+@register_builtin(None)
+def min_(context, node, *args):
+    """Inline min(a, b, ...) over positional operands; see minmax()."""
+    if _has_unsupported_call_args(node):
+        return None
+    return minmax(context, args, ast.Lt())
+
+@register_builtin(None)
+def max_(context, node, *args):
+    """Inline max(a, b, ...) over positional operands; see minmax()."""
+    if _has_unsupported_call_args(node):
+        return None
+    return minmax(context, args, ast.Gt())
+
@register_builtin(0)
def globals_(context, node):
return typesystem.dict_
from numba import autojit
import lib.numbatest
@autojit
def max1(x):
    """
    Call max() with a single iterable argument.

    No numeric promotion is applied to the elements: ``max1([1,2.0,3])``
    returns the int ``3``. The last example checks that a non-iterable
    argument raises the same TypeError as CPython's max().

    >>> max1([100])
    100
    >>> max1([1,2.0,3])
    3
    >>> max1([-1,-2,-3.0])
    -1
    >>> max1(1)
    Traceback (most recent call last):
    ...
    TypeError: 'int' object is not iterable
    """
    return max(x)


@autojit
def min1(x):
    """
    Call min() with a single iterable argument.

    No numeric promotion is applied to the elements: ``min1([1,2,3.0])``
    returns the int ``1``. The last example checks that a non-iterable
    argument raises the same TypeError as CPython's min().

    >>> min1([100])
    100
    >>> min1([1,2,3.0])
    1
    >>> min1([-1,-2.0,-3])
    -3
    >>> min1(1)
    Traceback (most recent call last):
    ...
    TypeError: 'int' object is not iterable
    """
    return min(x)


@autojit
def max2(x, y):
    """
    Call max() with two scalar arguments.

    Mixed int/float operands are promoted to a common float type, so the
    result is a float even when the int operand wins: ``max2(10, 9.9)``
    gives ``10.0`` where CPython's ``max(10, 9.9)`` gives ``10``. Operands
    with no common type raise UnpromotableTypeError.

    >>> max2(1, 2)
    2
    >>> max2(1, -2)
    1
    >>> max2(10, 10.25)
    10.25
    >>> max2(10, 9.9)
    10.0
    >>> max2(0.1, 0.25)
    0.25
    >>> max2(1, 'a')
    Traceback (most recent call last):
    ...
    UnpromotableTypeError: (int, const char *)
    """
    return max(x, y)


@autojit
def min2(x, y):
    """
    Call min() with two scalar arguments.

    Mixed int/float operands are promoted to a common float type, so the
    result is a float even when the int operand wins: ``min2(10, 10.1)``
    gives ``10.0`` where CPython's ``min(10, 10.1)`` gives ``10``. Operands
    with no common type raise UnpromotableTypeError.

    >>> min2(1, 2)
    1
    >>> min2(1, -2)
    -2
    >>> min2(10, 10.1)
    10.0
    >>> min2(10, 9.75)
    9.75
    >>> min2(0.25, 0.3)
    0.25
    >>> min2(1, 'a')
    Traceback (most recent call last):
    ...
    UnpromotableTypeError: (int, const char *)
    """
    return min(x, y)


@autojit
def max4(x):
    """
    Call max() with four arguments: int and float literals plus ``x``.

    Because one operand is the float ``2.0``, the result is promoted to
    float, so ``max4(20)`` returns ``20.0`` rather than ``20``.

    >>> max4(20)
    20.0
    """
    return max(1, 2.0, x, 14)


@autojit
def min4(x):
    """
    Call min() with four arguments: int and float literals plus ``x``.

    Because one operand is the float ``2.0``, the result is promoted to
    float, so ``min4(-2)`` returns ``-2.0`` rather than ``-2``.

    >>> min4(-2)
    -2.0
    """
    return min(1, 2.0, x, 14)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment