Last active
December 20, 2015 02:39
-
-
Save eltjpm/6058318 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Index: athena/src/numba/numba/visitors.py | |
=================================================================== | |
--- athena/src/numba/numba/visitors.py (revision 81084) | |
+++ athena/src/numba/numba/visitors.py (revision 83292) | |
@@ -27,6 +27,17 @@ | |
def __init__(self, *args, **kwargs): | |
pass | |
+def _flatmap(func, sequence): | |
+ result = [] | |
+ for elem in sequence: | |
+ res = func(elem) | |
+ if res is not None: | |
+ if isinstance(res, list): | |
+ result.extend(res) | |
+ else: | |
+ result.append(res) | |
+ return result | |
+ | |
class NumbaVisitorMixin(CooperativeBase): | |
_overloads = None | |
@@ -272,13 +283,7 @@ | |
return self.have(v1.type, v2.type, p1, p2) | |
def visitlist(self, list): | |
- newlist = [] | |
- for node in list: | |
- result = self.visit(node) | |
- if result is not None: | |
- newlist.append(result) | |
- | |
- list[:] = newlist | |
+ list[:] = _flatmap(self.visit, list) | |
return list | |
def is_complex(self, n): | |
@@ -337,7 +342,7 @@ | |
"Non-mutating visitor" | |
def visitlist(self, list): | |
- return [self.visit(item) for item in list] | |
+ return _flatmap(self.visit, list) | |
class NumbaTransformer(NumbaVisitorMixin, ast.NodeTransformer): | |
"Mutating visitor" | |
Index: athena/src/numba/numba/codegen/translate.py | |
=================================================================== | |
--- athena/src/numba/numba/codegen/translate.py (revision 81084) | |
+++ athena/src/numba/numba/codegen/translate.py (revision 81085) | |
@@ -698,11 +698,8 @@ | |
self.builder.position_at_end(node.entry_block) | |
self._init_phis(node) | |
- if len(node.body) == 1: | |
- lbody = self.visit(node.body[0]) | |
- else: | |
- self.visitlist(node.body) | |
- lbody = None | |
+ lbody = self.visitlist(node.body) | |
+ lbody = lbody[0] if len(lbody) == 1 else None | |
if not node.exit_block: | |
node.exit_block = self.builder.basic_block |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numba import int_, jit
import numpy


# NOTE(review): the loop bodies below are intentionally left in their literal
# form. This script is compiled by numba's ``jit``; presumably the exact
# statement shapes (two consecutive ``for`` loops, ``enumerate`` with and
# without a start value) are what is being exercised, so do not "simplify"
# them into sums or comprehensions. Comments are used instead of docstrings so
# that no extra statement is added to the compiled function bodies.


# Iterate over the same array twice in a row, adding ``len(arr)`` once per
# element in each loop, and return the combined total.
@jit(argtypes=(int_[:],))
def test1(arr):
    u = 0
    for x in arr:
        u += len(arr)
    v = 0
    for y in arr:
        v += len(arr)
    return u + v


# Accumulate index * element twice: first with ``enumerate`` starting at its
# default of 0, then with ``enumerate`` starting at 1, and return both sums
# added together.
@jit(argtypes=(int_[:],))
def test2(arr):
    s = 0
    for i, x in enumerate(arr):
        s += i*x
    s2 = 0
    for i2, x2 in enumerate(arr, 1):
        s2 += i2*x2
    return s+s2


if __name__ == '__main__':
    # Input is the 1-D array [1, 2, 3].
    arr = numpy.arange(1, 4)
    # Each compiled function must agree with its ``py_func`` attribute
    # (presumably the original, uncompiled Python function -- confirm against
    # the numba version in use).
    assert test1(arr) == test1.py_func(arr)
    assert test2(arr) == test2.py_func(arr)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment