Skip to content

Instantly share code, notes, and snippets.

@eltjpm
Last active December 20, 2015 02:39
Show Gist options
  • Save eltjpm/6058318 to your computer and use it in GitHub Desktop.
Save eltjpm/6058318 to your computer and use it in GitHub Desktop.
Route both visitlist() implementations in visitors.py through a new
module-level _flatmap() helper: results of self.visit() that are None are
dropped and results that are lists are spliced into the output, so a visit
method can replace one node with several. (The non-mutating visitlist
previously kept None results in its output list.)
Index: athena/src/numba/numba/visitors.py
===================================================================
--- athena/src/numba/numba/visitors.py	(revision 81084)
+++ athena/src/numba/numba/visitors.py	(revision 83292)
@@ -27,6 +27,17 @@
     def __init__(self, *args, **kwargs):
         pass
 
+def _flatmap(func, sequence):
+    result = []
+    for elem in sequence:
+        res = func(elem)
+        if res is not None:
+            if isinstance(res, list):
+                result.extend(res)
+            else:
+                result.append(res)
+    return result
+
 
 class NumbaVisitorMixin(CooperativeBase):
     _overloads = None
@@ -272,13 +283,7 @@
         return self.have(v1.type, v2.type, p1, p2)
 
     def visitlist(self, list):
-        newlist = []
-        for node in list:
-            result = self.visit(node)
-            if result is not None:
-                newlist.append(result)
-
-        list[:] = newlist
+        list[:] = _flatmap(self.visit, list)
         return list
 
     def is_complex(self, n):
@@ -337,7 +342,7 @@
     "Non-mutating visitor"
 
     def visitlist(self, list):
-        return [self.visit(item) for item in list]
+        return _flatmap(self.visit, list)
 
 class NumbaTransformer(NumbaVisitorMixin, ast.NodeTransformer):
     "Mutating visitor"
In codegen/translate.py, always visit node.body through visitlist() and take
lbody as the single returned result when visitlist() returns exactly one
element, otherwise None (previously a one-statement body was visited directly
and any longer body gave lbody = None).
NOTE(review): if visitlist() drops None results, a multi-statement body that
yields exactly one non-None result now sets lbody to that result instead of
None -- confirm this is intended.
Index: athena/src/numba/numba/codegen/translate.py
===================================================================
--- athena/src/numba/numba/codegen/translate.py	(revision 81084)
+++ athena/src/numba/numba/codegen/translate.py	(revision 81085)
@@ -698,11 +698,8 @@
         self.builder.position_at_end(node.entry_block)
         self._init_phis(node)
 
-        if len(node.body) == 1:
-            lbody = self.visit(node.body[0])
-        else:
-            self.visitlist(node.body)
-            lbody = None
+        lbody = self.visitlist(node.body)
+        lbody = lbody[0] if len(lbody) == 1 else None
 
         if not node.exit_block:
             node.exit_block = self.builder.basic_block
from numba import int_, jit
import numpy
# Exercises two consecutive for-loops over the same array inside one jitted
# function, each accumulating into its own counter. Each loop adds len(arr)
# once per element, so the result is 2 * len(arr) ** 2; the __main__ block
# compares it against the pure-Python version (test1.py_func).
@jit(argtypes=(int_[:],))
def test1(arr):
    u = 0
    for x in arr:
        u += len(arr)
    v = 0
    for y in arr:
        v += len(arr)
    return u + v
# Exercises two consecutive enumerate() loops inside one jitted function: the
# first uses the default start index (0), the second an explicit start of 1.
# Returns sum(i * arr[i]) + sum((i + 1) * arr[i]) over all indices i; the
# __main__ block compares it against the pure-Python version (test2.py_func).
@jit(argtypes=(int_[:],))
def test2(arr):
    s = 0
    for i, x in enumerate(arr):
        s += i*x
    s2 = 0
    for i2, x2 in enumerate(arr, 1):
        s2 += i2*x2
    return s+s2
if __name__ == '__main__':
    # Run each jitted function and its pure-Python original (`py_func`) on the
    # same input; the two results must agree. An explicit check is used
    # instead of `assert` so the comparison still runs under `python -O`, and
    # the failure message reports both values.
    arr = numpy.arange(1, 4)
    for func in (test1, test2):
        got = func(arr)
        expected = func.py_func(arr)
        if got != expected:
            raise AssertionError('%s: jit returned %r, Python returned %r'
                                 % (func.py_func.__name__, got, expected))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment