Skip to content

Instantly share code, notes, and snippets.

@eltjpm
Created July 19, 2013 18:44
Show Gist options
  • Save eltjpm/6041422 to your computer and use it in GitHub Desktop.
Save eltjpm/6041422 to your computer and use it in GitHub Desktop.
Index: numba/specialize/loops.py
===================================================================
--- numba/specialize/loops.py (revision 79617)
+++ numba/specialize/loops.py (revision 83202)
@@ -95,36 +105,31 @@
else:
have_step = True
- start, stop, step = [nodes.CloneableNode(n)
- for n in (start, stop, step)]
+ start, stop, step = map(nodes.CloneableNode, (start, stop, step))
if have_step:
- compute_nsteps = """
- $length = {{stop}} - {{start}}
- {{nsteps}} = $length / {{step}}
- if {{nsteps_load}} * {{step}} != $length: #$length % {{step}}:
- # Test for truncation
- {{nsteps}} = {{nsteps_load}} + 1
- # print "nsteps", {{nsteps_load}}
- """
+ templ = textwrap.dedent("""
+ {{temp}} = 0
+ {{nsteps}} = ({{stop}} - {{start}} + {{step}} -
+ (1 if {{step}} >= 0 else -1)) / {{step}}
+ while {{temp_load}} < {{nsteps_load}}:
+ {{target}} = {{start}} + {{temp_load}} * {{step}}
+ {{body}}
+ {{temp}} = {{temp_load}} + 1
+ """)
else:
- compute_nsteps = "{{nsteps}} = {{stop}} - {{start}}"
+ templ = textwrap.dedent("""
+ {{temp}} = {{start}}
+ {{nsteps}} = {{stop}}
+ while {{temp_load}} < {{nsteps_load}}:
+ {{target}} = {{temp_load}}
+ {{body}}
+ {{temp}} = {{temp_load}} + 1
+ """)
if node.orelse:
- else_clause = "else: {{else_body}}"
- else:
- else_clause = ""
+ templ += "\nelse: {{else_body}}"
- templ = textwrap.dedent("""
- %s
- {{temp}} = 0
- while {{temp_load}} < {{nsteps_load}}:
- {{target}} = {{start}} + {{temp_load}} * {{step}}
- {{body}}
- {{temp}} = {{temp_load}} + 1
- %s
- """) % (textwrap.dedent(compute_nsteps), else_clause)
-
# Leave the bodies empty, they are already analyzed
body = ast.Suite(body=[])
else_body = ast.Suite(body=[])
@@ -196,8 +201,7 @@
# Replace node.target with a temporary
#--------------------------------------------------------------------
- target_name = orig_target.id + '.idx'
- target_temp = nodes.TempNode(Py_ssize_t)
+ target_temp = nodes.TempNode(typesystem.Py_ssize_t)
node.target = target_temp.store()
#--------------------------------------------------------------------
from numba import autojit
import numpy as np
import unittest
# Sums the values produced by a three-argument range().  The explicit
# for/range loop is intentional: TestForLoop checks this autojit-wrapped
# function against its plain-Python original (``.py_func``).
@autojit
def for_loop_fn_1(start, stop, inc):
    total = 0
    for item in range(start, stop, inc):
        total += item
    return total
# Sums the values produced by a two-argument range() (no explicit step).
# The explicit for/range loop is intentional: TestForLoop checks this
# autojit-wrapped function against its plain-Python original (``.py_func``).
@autojit
def for_loop_fn_1a(start, stop):
    total = 0
    for item in range(start, stop):
        total += item
    return total
# Sums the values produced by a one-argument range() (stop only).
# The explicit for/range loop is intentional: TestForLoop checks this
# autojit-wrapped function against its plain-Python original (``.py_func``).
@autojit
def for_loop_fn_1b(stop):
    total = 0
    for item in range(stop):
        total += item
    return total
class TestForLoop(unittest.TestCase):
    """Check the autojit-wrapped range loops against their ``py_func`` originals."""

    def test_compiled_for_loop_fn_many(self):
        """Sweep small signed bounds and steps through all three loop helpers.

        Every (lo, hi) pair in [-10, 10] x [-10, 10] is exercised, which
        covers empty, ascending and "backwards" ranges.  For the three-
        argument form every non-zero step in [-20, 20] is tried as well.
        Each wrapped function must return the same value as its ``py_func``.
        """
        # ``range`` (rather than the Python 2-only ``xrange``) keeps the test
        # runnable on both Python 2 and Python 3; the spans are tiny, so the
        # list built on Python 2 is of no consequence.
        for lo in range(-10, 11):
            for hi in range(-10, 11):
                for step in range(-20, 21):
                    # A zero step is rejected by range(), so skip it.
                    if step == 0:
                        continue
                    self.assertEqual(for_loop_fn_1(lo, hi, step),
                                     for_loop_fn_1.py_func(lo, hi, step),
                                     'failed for %d/%d/%d' % (lo, hi, step))

                # The two- and one-argument forms do not depend on ``step``,
                # so they are checked once per (lo, hi) pair instead of once
                # per step value.
                self.assertEqual(for_loop_fn_1a(lo, hi),
                                 for_loop_fn_1a.py_func(lo, hi),
                                 'failed for %d/%d' % (lo, hi))
                self.assertEqual(for_loop_fn_1b(hi),
                                 for_loop_fn_1b.py_func(hi),
                                 'failed for %d' % hi)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment