Skip to content

Instantly share code, notes, and snippets.

@Dapid
Created March 17, 2015 00:23
Show Gist options
  • Save Dapid/3acda96119f95fd7bdab to your computer and use it in GitHub Desktop.
Save Dapid/3acda96119f95fd7bdab to your computer and use it in GitHub Desktop.
Probability trees
import numpy as np
cimport numpy as np
cimport cython
import random
from math import ceil
from libc.math cimport log2
@cython.wraparound(False)
@cython.boundscheck(False)
def PTCreate(double[::1] L):
    """Build a probability (sum) tree over the weights in ``L``.

    Row 0 holds the leaves: ``L`` copied into the first ``n`` slots, with the
    remaining slots up to the next power of two left at zero. Each row ``k``
    (``k >= 1``) holds the pairwise sums of row ``k - 1``, so the last row's
    first entry ``PT[m, 0]`` is the total weight.

    Parameters
    ----------
    L : contiguous 1-D buffer of doubles
        Leaf weights (presumably non-negative -- not validated here).

    Returns
    -------
    A 2-D typed memoryview of shape ``(m + 1, 2**m)``, where ``m`` is the
    smallest integer with ``2**m >= len(L)``.

    Raises
    ------
    ValueError
        If ``L`` is empty (a tree needs at least one leaf).
    """
    cdef Py_ssize_t n = L.shape[0]
    cdef Py_ssize_t i, width
    cdef int k
    cdef int m = 0
    cdef double[:, ::1] PT
    if n == 0:
        raise ValueError("PTCreate needs at least one weight")
    # Smallest m with 2**m >= n, computed exactly with integer shifts rather
    # than ceil(log2(n)). Shifts also keep the width a C integer: under
    # Cython 3's default cpow=False, ``2**m`` with a C-int exponent is typed
    # as a C double.
    while ((<Py_ssize_t> 1) << m) < n:
        m += 1
    width = (<Py_ssize_t> 1) << m
    PT = np.zeros((m + 1, width))
    with nogil:
        for i in range(n):
            PT[0, i] = L[i]
        for k in range(1, m + 1):
            # Row k has half as many meaningful entries as row k - 1.
            width >>= 1
            for i in range(width):
                PT[k, i] = PT[k - 1, 2 * i] + PT[k - 1, 2 * i + 1]
    return PT
@cython.wraparound(False)
@cython.boundscheck(False)
cpdef int PTSample(double[:, :] PT):
    """Draw one leaf index from a probability tree.

    The root is ``PT[levels - 1, 0]``, where ``levels = PT.shape[0]``. At each
    level a fresh ``random.random()`` draw in ``[0, 1)`` is scaled by the
    current node's value. The walk moves to the left child when the scaled
    draw is below the left child's value, and to the right child otherwise.
    When every internal node equals the sum of its two children (as built by
    ``PTCreate``/``PTUpdate``) and weights are non-negative, leaf ``j`` is
    returned with probability ``PT[0, j] / PT[levels - 1, 0]``.

    Returns
    -------
    int
        Column index of the chosen leaf in row 0.
    """
    cdef Py_ssize_t levels = PT.shape[0]
    cdef Py_ssize_t k
    cdef int node = 0
    cdef double threshold
    # Start at the root row (levels - 1). Starting at ``levels`` would read
    # one row past the buffer, which goes undetected with boundscheck off.
    for k in range(levels - 1, 0, -1):
        threshold = PT[k, node] * random.random()
        # Go left (add 0) while the draw lands inside the left child's mass;
        # otherwise go right (add 1).
        node = 2 * node + (threshold >= PT[k - 1, 2 * node])
    return node
@cython.wraparound(False)
@cython.boundscheck(False)
def PTSample_many(double[:, :] PT, int N):
    """Return ``N`` leaf indices drawn independently from ``PT``.

    Each entry comes from its own ``PTSample(PT)`` call, made in order, and
    is stored as a double in a new 1-D contiguous buffer from ``np.empty(N)``.
    That buffer is returned.
    """
    cdef int draw
    cdef double[::1] samples = np.empty(N)
    for draw in range(N):
        samples[draw] = PTSample(PT)
    return samples
@cython.wraparound(False)
@cython.boundscheck(False)
def PTUpdate(double[:, :] PT, int ind, double newValue):
    """Set leaf ``ind`` to ``newValue`` and recompute the sums above it.

    Only the leaf's ancestors are touched, one node per row. ``PT`` is
    modified in place and also returned.

    Negative ``ind`` counts from the end of row 0, as in Python indexing.

    Raises
    ------
    IndexError
        If ``PT`` has no rows, or ``ind`` (after normalising a negative
        value) lies outside ``[0, PT.shape[1])``. Bounds checking and
        wraparound are disabled here, so an unchecked bad index would
        silently write outside the buffer.
    """
    cdef int k
    cdef int levels = PT.shape[0]
    cdef Py_ssize_t width = PT.shape[1]
    if ind < 0:
        ind += width
    if levels == 0 or ind < 0 or ind >= width:
        raise IndexError("leaf index out of range for probability tree")
    PT[0, ind] = newValue
    with nogil:
        for k in range(1, levels):
            # Parent of node ``ind`` one row up; an integer shift avoids the
            # float round trip of ``<int>(ind / 2)``.
            ind >>= 1
            PT[k, ind] = PT[k - 1, 2 * ind] + PT[k - 1, 2 * ind + 1]
    return PT
@cython.boundscheck(False)
def PTSample_many(double[:, :] PT, int N):
    """Return ``N`` leaf indices drawn independently from ``PT``.

    Each entry is produced by its own ``PTSample(PT)`` call, made in order,
    and stored as a double in a freshly allocated contiguous buffer, which is
    returned.

    NOTE(review): a function of this name is already defined earlier in this
    file. This later ``def`` presumably rebinds the module-level name, as in
    Python, and appears to duplicate it; consider keeping only one copy.
    """
    cdef int draw
    cdef double[::1] samples = np.empty(N)
    for draw in range(N):
        samples[draw] = PTSample(PT)
    return samples
@cython.boundscheck(False)
def PTUpdate(double[:, :] PT, int ind, double newValue):
    """Set leaf ``ind`` to ``newValue`` and recompute the sums above it.

    Only the leaf's ancestors are touched, one node per row. ``PT`` is
    modified in place and also returned.

    Negative ``ind`` counts from the end of row 0, as in Python indexing.
    The index is normalised before the ancestors are walked, so their sums
    stay consistent with the leaf that was written.

    NOTE(review): a function of this name is already defined earlier in this
    file. This later ``def`` presumably rebinds the module-level name, as in
    Python; consider keeping only one copy.

    Raises
    ------
    IndexError
        If ``PT`` has no rows, or ``ind`` (after normalising a negative
        value) lies outside ``[0, PT.shape[1])``. Bounds checking is disabled
        here, so an unchecked bad index would silently write outside the
        buffer.
    """
    cdef int k
    cdef int levels = PT.shape[0]
    cdef Py_ssize_t width = PT.shape[1]
    if ind < 0:
        ind += width
    if levels == 0 or ind < 0 or ind >= width:
        raise IndexError("leaf index out of range for probability tree")
    PT[0, ind] = newValue
    with nogil:
        for k in range(1, levels):
            # Parent of node ``ind`` one row up; an integer shift avoids the
            # float round trip of ``<int>(ind / 2)``.
            ind >>= 1
            PT[k, ind] = PT[k - 1, 2 * ind] + PT[k - 1, 2 * ind + 1]
    return PT
import os
import numpy
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
from Cython.Build import cythonize
# Build script for the ``prob_trees`` Cython extension.
module = 'prob_trees'
# Host-specific, aggressive optimisation flags. -march=native/-mtune=native
# target the build machine's CPU, so the resulting binary may not run on
# other machines.
opt_flags = ['-O2', '-march=native', '-mtune=native',
             '-funroll-loops', '-fpic', '-flto']
ext = Extension(module, [module + '.pyx'],
                include_dirs=[numpy.get_include()],
                extra_compile_args=opt_flags,
                # GCC performs link-time optimisation only when -flto (and the
                # optimisation options) are also given at the final link.
                extra_link_args=list(opt_flags))
setup(name=module, ext_modules=cythonize(ext))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment