Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
"""Run with python -c 'import pyximport; pyximport.install(); import cellbench; cellbench.main()'
"""
from libc.stdint cimport uint32_t
from libc.math cimport sqrt, modf
from libc.math cimport round as c_round
ctypedef uint32_t Label
cdef inline size_t cellidx(short start, short end, short lensent,
        Label nonterminals):
    """Return an index to a triangular array, given start < end.

    The result of this function is the index to chart[start][end][0].

    The cell number (rows of decreasing length ``lensent``,
    ``lensent - 1``, ... laid out one after another) is computed first and
    then scaled by ``nonterminals``, the number of slots per cell.
    """
    # Widen to size_t *before* multiplying: Label is a 32-bit unsigned type,
    # so multiplying at that width would wrap for large charts and only then
    # be converted to the size_t return type.
    return <size_t>nonterminals * (lensent * start
            - ((start - 1) * start // 2) + end - start - 1)
cdef inline short cellstart(size_t cell, short lensent,
        Label nonterminals):
    """Retrieve start position for a given chart cell.

    ``cell`` is a flat index as produced by ``cellidx``; it is first divided
    by ``nonterminals`` to obtain the cell number. Rows are then skipped one
    by one (row ``start`` holds ``lensent - start`` cells) until the row
    containing the cell is reached.
    """
    # Offset of the first cell of the current row. This counts cells, so it
    # must be wider than short: a sentence of more than 255 words already has
    # more than 32767 cells.
    cdef size_t rowstart = 0
    cdef short start = 0
    cell = cell // nonterminals
    # ``lensent`` doubles as the length of the current row. The ``> 0`` guard
    # only matters for an out-of-range cell, where it stops the scan instead
    # of letting the row length go negative.
    while lensent > 0 and rowstart + <size_t>lensent <= cell:
        rowstart += <size_t>lensent
        lensent -= 1
        start += 1
    return start
# alternative implementation in closed form (slower)
# requires: from libc.math cimport sqrt
cdef inline short cellstart_c(size_t cell, short lensent,
        Label nonterminals):
    """Closed-form counterpart of ``cellstart``.

    Computes the start position of a chart cell with a square root instead
    of scanning the rows; the result is truncated to an integer.
    """
    # Cell number, i.e. the flat index without the per-cell label slots.
    cdef size_t cellno = cell // nonterminals
    cdef double radicand = 0.25 + lensent * (lensent + 1) - 2 * cellno
    return int(lensent + 0.5 - sqrt(radicand))
cdef inline short cellend(size_t cell, short lensent,
        Label nonterminals):
    """Retrieve end position for a given chart cell.

    ``cell`` is a flat index as produced by ``cellidx``; it is first divided
    by ``nonterminals`` to obtain the cell number. The row containing the
    cell is located by skipping rows (row ``start`` holds ``lensent - start``
    cells); the end position is the start plus the offset within that row
    plus one.
    """
    # Offset of the first cell of the current row. This counts cells, so it
    # must be wider than short: a sentence of more than 255 words already has
    # more than 32767 cells.
    cdef size_t rowstart = 0
    cdef short start = 0
    cell = cell // nonterminals
    # ``lensent`` doubles as the length of the current row. The ``> 0`` guard
    # only matters for an out-of-range cell, where it stops the scan instead
    # of letting the row length go negative.
    while lensent > 0 and rowstart + <size_t>lensent <= cell:
        rowstart += <size_t>lensent
        lensent -= 1
        start += 1
    return start + (cell - rowstart) + 1
# alternative implementation in closed form (slower)
# requires: from libc.math cimport sqrt, modf
cdef inline short cellend_c(size_t cell, short lensent,
        Label nonterminals):
    """Retrieve end position for a given chart cell (closed form).

    Evaluates the same square-root expression as ``cellstart_c`` and splits
    it with ``modf``: the integral part is the start position, and the
    fractional part, scaled by the length of that row and rounded, gives the
    offset of the cell within the row.
    """
    # ``start`` is a double because modf() stores the integral part through
    # a double pointer.
    cdef double fractional, start
    fractional = modf(lensent + 0.5 - sqrt(
            0.25 + lensent * (lensent + 1) - 2 * (cell // nonterminals)),
            &start)
    return <short>start + (<short>c_round((lensent - start) * fractional)) + 1
def a(data):
    """Benchmark driver: sum ``cellstart`` over every cell index in data.

    Uses a fixed sentence length of 40 and 10000 nonterminals.
    """
    return sum(
        cellstart(cell, 40, 10000)
        for cell in data)
def b(data):
    """Benchmark driver: sum ``cellstart_c`` over every cell index in data.

    Uses a fixed sentence length of 40 and 10000 nonterminals.
    """
    return sum(
        cellstart_c(cell, 40, 10000)
        for cell in data)
def c(data):
    """Benchmark driver: sum ``cellend`` over every cell index in data.

    Uses a fixed sentence length of 40 and 10000 nonterminals.
    """
    return sum(
        cellend(cell, 40, 10000)
        for cell in data)
def d(data):
    """Benchmark driver: sum ``cellend_c`` over every cell index in data.

    Uses a fixed sentence length of 40 and 10000 nonterminals.
    """
    return sum(
        cellend_c(cell, 40, 10000)
        for cell in data)
def main():
    """Check the iterative and closed-form lookups agree, then time them.

    Draws 10000 random flat cell indices for a 40-word sentence with 10000
    nonterminals, raises ValueError if the two implementations of the start
    or end lookup disagree on the sample, and prints the time taken by 1000
    runs of each of the four drivers ``a``, ``b``, ``c`` and ``d``.
    """
    import random
    import timeit
    # The upper bound is the index of the last cell, chart[39][40][0], so the
    # sample covers every start position. Using cellidx(0, 40, ...) here
    # would only ever produce cells that start at 0, for which the iterative
    # versions never enter their loop.
    data = [random.randint(0, cellidx(39, 40, 40, 10000))
            for _ in range(10000)]
    if a(data) != b(data):
        raise ValueError('cellstart and cellstart_c disagree')
    if c(data) != d(data):
        raise ValueError('cellend and cellend_c disagree')
    print(timeit.timeit('a(data)', number=1000, globals=dict(data=data, a=a)))
    print(timeit.timeit('b(data)', number=1000, globals=dict(data=data, b=b)))
    print(timeit.timeit('c(data)', number=1000, globals=dict(data=data, c=c)))
    print(timeit.timeit('d(data)', number=1000, globals=dict(data=data, d=d)))
@kilian-gebhardt
Copy link

kilian-gebhardt commented Sep 14, 2018

You actually want line 70 to be data = [random.randint(0, cellidx(39, 40, 40, 10000)) for _ in range(10000)], otherwise you'll always get a cell starting at 0. Even when switching to isqrt from Wikipedia in cellstart_c and increasing sentence length to 80 I don't see any improvements over the iterative version.

return int(lensent - isqrt(lensent * (lensent + 1) - 2 * (cell // nonterminals)))

cdef inline short isqrt(short num):
    """Return the integer square root (floor) of ``num``.

    Digit-by-digit binary method: ``bit`` walks down the powers of four and
    the root is built up one binary digit at a time. Returns 0 for
    ``num <= 0``.
    """
    cdef:
        # Work in int so that no intermediate value can exceed the range of
        # short, even when the argument is close to the maximum.
        int rem = num
        int res = 0
        int bit = 1 << 14  # highest power of four that fits in a short

    # "bit" starts at the highest power of four <= the argument.
    while bit > rem:
        bit >>= 2

    while bit != 0:
        if rem >= res + bit:
            rem -= res + bit
            # Halve first, then add the new digit: adding ``bit << 1`` and
            # halving afterwards would transiently need 1 << 15, which does
            # not fit in a signed 16-bit value.
            res = (res >> 1) + bit
        else:
            res >>= 1
        bit >>= 2

    return <short>res

Loading

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment