Skip to content

Instantly share code, notes, and snippets.

@eltjpm
Created July 19, 2013 18:38
Show Gist options
  • Save eltjpm/6041381 to your computer and use it in GitHub Desktop.
Note (environment.py): register the new TransformBuiltinLoops pipeline stage
in both pipeline-order lists shown below, directly before ControlFlowAnalysis,
so for-loops over enumerate()/zip() are rewritten into range() loops before
control-flow analysis and type inference (TypeInfer) run on the AST.
Index: numba/environment.py
===================================================================
--- numba/environment.py (revision 80669)
+++ numba/environment.py (revision 80670)
@@ -41,6 +41,7 @@
'update_signature',
'create_lfunc1',
'NormalizeASTStage',
+ 'TransformBuiltinLoops',
'ControlFlowAnalysis',
#'ConstFolding',
'TypeInfer',
@@ -74,6 +75,7 @@
default_type_infer_pipeline_order = [
'ast3to2',
+ 'TransformBuiltinLoops',
'ControlFlowAnalysis',
'TypeInfer',
]
Note (control_flow.py): assignments whose target is a nodes.TempStoreNode
(presumably what TempNode.store() yields for the temps introduced by
TransformBuiltinLoops - confirm) are now recorded as AttributeAssignment
stats in the current flow block, exactly as ast.Attribute targets already
were, provided there is a current block and a non-None assignment.
Index: numba/control_flow/control_flow.py
===================================================================
--- numba/control_flow/control_flow.py (revision 80669)
+++ numba/control_flow/control_flow.py (revision 80670)
@@ -773,8 +773,8 @@
warn_unused=warn_unused)
# TODO: Generate fake RHS for for iteration target variable
- elif (isinstance(lhs, ast.Attribute) and self.flow.block and
- assignment is not None):
+ elif (isinstance(lhs, (ast.Attribute, nodes.TempStoreNode)) and
+ self.flow.block and assignment is not None):
self.flow.block.stats.append(AttributeAssignment(assignment))
if self.flow.exceptions:
Note (tempnodes.py): the node built from a temp in __init__(temp, invariant)
now reuses temp.variable instead of allocating a fresh Variable(self.type),
so every such node refers to the same Variable object as its temp.
NOTE(review): presumably needed so that the DeferredType created by
untypedTemp() in loops.py, whose .variable is set to temp.variable, is
shared by all loads of that temp - confirm.
Index: numba/nodes/tempnodes.py
===================================================================
--- numba/nodes/tempnodes.py (revision 80669)
+++ numba/nodes/tempnodes.py (revision 80670)
@@ -50,7 +50,7 @@
def __init__(self, temp, invariant=False):
self.temp = temp
self.type = temp.type
- self.variable = Variable(self.type)
+ self.variable = temp.variable
self.invariant = invariant
def __repr__(self):
Index: numba/pipeline.py
===================================================================
--- numba/pipeline.py (revision 80669)
+++ numba/pipeline.py (revision 80670)
@@ -402,6 +402,12 @@
env)
return transform.visit(ast)
+class TransformBuiltinLoops(PipelineStage):
+    """
+    Pipeline stage that applies loops.TransformBuiltinLoops, rewriting
+    for-loops over the builtins enumerate() and zip() into range() loops.
+    """
+
+    def transform(self, ast, env):
+        """Return the AST produced by visiting *ast* with the specializer."""
+        specializer = self.make_specializer(loops.TransformBuiltinLoops,
+                                            ast, env)
+        rewritten = specializer.visit(ast)
+        return rewritten
+
#----------------------------------------------------------------------------
# Specializing/Lowering Transforms
#----------------------------------------------------------------------------
Index: numba/specialize/loops.py
===================================================================
--- numba/specialize/loops.py (revision 79617)
+++ numba/specialize/loops.py (revision 83202)
@@ -2,6 +2,10 @@
from __future__ import print_function, division, absolute_import
import ast
import textwrap
+try:
+ import __builtin__ as builtins
+except ImportError:
+ import builtins
import numba
from numba import *
@@ -57,6 +61,12 @@
while_node = nodes.build_while(**vars(while_node))
return while_node
+def untypedTemp():
+    """
+    Create a TempNode whose type is not known yet.
+
+    The node's type is a ``typesystem.DeferredType(None)`` whose
+    ``variable`` attribute is bound to the new temp's own variable.
+    """
+    deferred_type = typesystem.DeferredType(None)
+    temp_node = nodes.TempNode(deferred_type)
+    deferred_type.variable = temp_node.variable
+    return temp_node
#------------------------------------------------------------------------
# Transform for loops
@@ -261,6 +265,145 @@
return node
#------------------------------------------------------------------------
+# Transform for loops over builtins
+#------------------------------------------------------------------------
+
+class TransformBuiltinLoops(visitors.NumbaTransformer):
+    """
+    Rewrite ``for`` loops over the builtins ``enumerate()`` and ``zip()``
+    into index-based ``range()`` loops.
+
+    A loop is rewritten only when the iterated call's function is a bare
+    name that is not in the function's symbol table and that resolves
+    (through the function's globals, then the builtins module) to the
+    builtin ``enumerate`` or ``zip`` object.  All other loops are visited
+    normally.
+    """
+
+    def rewrite_enumerate(self, node):
+        """
+        Rewrite a loop like
+
+            for i, x in enumerate(array[, start]):
+                ...
+
+        into
+
+            _arr = array
+            [_s = start]
+            for _i in range(len(_arr)):
+                i = _i [+ _s]
+                x = _arr[_i]
+                ...
+
+        Returns a list of the visited replacement statements.
+        """
+        call = node.iter
+        if (len(call.args) not in (1, 2) or call.keywords or
+                call.starargs or call.kwargs):
+            self.error(call, 'expected 1 or 2 arguments to enumerate()')
+
+        target = node.target
+        if (not isinstance(target, (ast.Tuple, ast.List)) or
+                len(target.elts) != 2):
+            self.error(call, 'expected 2 iteration variables')
+
+        array = call.args[0]
+        start = call.args[1] if len(call.args) > 1 else None
+        idx = target.elts[0]
+        var = target.elts[1]
+
+        # The iterable (and start) are evaluated exactly once, into temps
+        # whose types are not known yet (see untypedTemp).
+        array_temp = untypedTemp()
+        if start is not None:
+            start_temp = untypedTemp()  # TODO: only allow integer start
+        idx_temp = nodes.TempNode(typesystem.Py_ssize_t)
+
+        # for _i in range(len(_arr)):
+        node.target = idx_temp.store()
+        node.iter = ast.Call(ast.Name('range', ast.Load()),
+                             [ast.Call(ast.Name('len', ast.Load()),
+                                       [array_temp.load(True)],
+                                       [], None, None)],
+                             [], None, None)
+
+        # i = _i [+ _s]
+        new_idx = idx_temp.load()
+        if start is not None:
+            new_idx = ast.BinOp(new_idx, ast.Add(), start_temp.load(True))
+        node.body.insert(0, ast.Assign([idx], new_idx))
+
+        # x = _arr[_i]
+        value = ast.Subscript(array_temp.load(True),
+                              ast.Index(idx_temp.load()),
+                              ast.Load())
+        node.body.insert(1, ast.Assign([var], value))
+
+        # _arr = array; [_s = start]; for ...
+        body = [ast.Assign([array_temp.store()], array), node]
+        if start is not None:
+            body.insert(1, ast.Assign([start_temp.store()], start))
+        # Build a real list: on Python 3, map() returns a lazy iterator.
+        return [self.visit(stmt) for stmt in body]
+
+    def rewrite_zip(self, node):
+        """
+        Rewrite a loop like
+
+            for x, y... in zip(xs, ys...):
+                ...
+
+        into
+
+            _xs = xs; _ys = ys...
+            for _i in range(min(len(_xs), len(_ys)...)):
+                x = _xs[_i]; y = _ys[_i]...
+                ...
+
+        With a single iterable the bound is just ``len(_xs)``, because
+        ``min()`` of a single integer is not a valid call.
+
+        Returns a list of the visited replacement statements.
+        """
+        call = node.iter
+        if not call.args or call.keywords or call.starargs or call.kwargs:
+            self.error(call, 'expected at least 1 argument to zip()')
+
+        target = node.target
+        if (not isinstance(target, (ast.Tuple, ast.List)) or
+                len(target.elts) != len(call.args)):
+            self.error(call, 'expected %d iteration variables' % len(call.args))
+
+        # One temp per iterable; iterate call.args directly because
+        # xrange does not exist on Python 3.
+        temps = [untypedTemp() for _ in call.args]
+        idx_temp = nodes.TempNode(typesystem.Py_ssize_t)
+
+        # len(_xs), or min(len(_xs), len(_ys)...) for several iterables
+        lengths = [ast.Call(ast.Name('len', ast.Load()),
+                            [tmp.load(True)], [], None, None)
+                   for tmp in temps]
+        if len(lengths) == 1:
+            len_call = lengths[0]
+        else:
+            len_call = ast.Call(ast.Name('min', ast.Load()),
+                                lengths, [], None, None)
+
+        # for _i in range(...):
+        node.target = idx_temp.store()
+        node.iter = ast.Call(ast.Name('range', ast.Load()),
+                             [len_call], [], None, None)
+
+        # x = _xs[_i]; y = _ys[_i]...
+        node.body = [ast.Assign([tgt],
+                                ast.Subscript(tmp.load(True),
+                                              ast.Index(idx_temp.load()),
+                                              ast.Load()))
+                     for tgt, tmp in zip(target.elts, temps)] + node.body
+
+        # _xs = xs; _ys = ys...; for ...
+        body = [ast.Assign([tmp.store()], arg)
+                for tmp, arg in zip(temps, call.args)]
+        body.append(node)
+        # Build a real list: on Python 3, map() returns a lazy iterator.
+        return [self.visit(stmt) for stmt in body]
+
+    # Keyed by id() of the builtin object: the looked-up global may be any
+    # object, including an unhashable one, and id() can always be looked up.
+    # Values are plain functions (bound at class-creation time), so callers
+    # must pass self explicitly.
+    HANDLERS = {
+        id(enumerate): rewrite_enumerate,
+        id(zip): rewrite_zip,
+    }
+
+    def visit_For(self, node):
+        """Rewrite loops over builtin enumerate()/zip(); visit others."""
+        if (isinstance(node.iter, ast.Call) and
+                isinstance(node.iter.func, ast.Name)):
+            name = node.iter.func.id
+            # A name in the symbol table shadows any global or builtin.
+            if name not in self.symtab:
+                obj = (self.func_globals[name]
+                       if name in self.func_globals else
+                       getattr(builtins, name, None))
+                rewriter = self.HANDLERS.get(id(obj))
+                if rewriter:
+                    return rewriter(self, node)
+
+        self.visitchildren(node)
+        return node
+
+#------------------------------------------------------------------------
# Transform for loops over Objects
#------------------------------------------------------------------------
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment