Skip to content

Instantly share code, notes, and snippets.

@aminnj
Last active August 12, 2018 01:21
Show Gist options
  • Save aminnj/0cc86c5519f238b101bf30fa9f2c69ce to your computer and use it in GitHub Desktop.
Save aminnj/0cc86c5519f238b101bf30fa9f2c69ce to your computer and use it in GitHub Desktop.
import sys
import time
import numba
import numpy as np
import awkward
from awkward import JaggedArray
@numba.jit(nopython=True, cache=True)
def numba_min(content, offsets):
    """Return the minimum of each sublist of a jagged array.

    The jagged array is given as a flat ``content`` array plus ``offsets``;
    sublist ``k`` is ``content[offsets[k]:offsets[k + 1]]``. The result has
    ``len(offsets) - 1`` entries and the same dtype as ``content``. Empty
    sublists yield 0, like the pure-Python reference in this script.
    """
    result = np.zeros(len(offsets) - 1, dtype=content.dtype)
    for k in range(len(offsets) - 1):
        start = offsets[k]
        stop = offsets[k + 1]
        if stop <= start:
            # Empty sublist: keep the pre-filled 0.
            continue
        # Seed with the first element rather than a magic sentinel, so
        # sublists whose minimum is large (e.g. >= 1e6) are not zeroed out.
        best = content[start]
        for i in range(start + 1, stop):
            val = content[i]
            if val < best:
                best = val
        result[k] = best
    return result
@numba.jit(nopython=True, cache=True)
def numba_sum(content, offsets):
    """Sum each sublist of a jagged array given as flat ``content`` + ``offsets``.

    Sublist ``j`` spans ``content[offsets[j]:offsets[j + 1]]``. The output has
    one entry per sublist (``len(offsets) - 1``) with ``content``'s dtype; an
    empty sublist sums to 0.
    """
    n_lists = len(offsets) - 1
    out = np.zeros(n_lists, dtype=content.dtype)
    for j in range(n_lists):
        lo, hi = offsets[j], offsets[j + 1]
        total = 0.0  # accumulate in floating point
        for idx in range(lo, hi):
            total += content[idx]
        out[j] = total
    return out
def compare(vals, funcs, name, reps=5):
    """Time each function in ``funcs`` on ``vals`` and print its throughput.

    Each function is called ``reps`` times with ``vals``. Its rate is reported
    in millions of rows (``vals.shape[0]``) per second ("MHz"). The printed
    line labels ``funcs[0]``, ``funcs[1]`` and ``funcs[2]`` as the numba-jit,
    numpy/awkward and pure-Python implementations, so at least three
    functions are required.

    Args:
        vals: input passed unchanged to every function; must expose ``shape``.
        funcs: sequence of (at least three) callables taking ``vals``.
        name: label printed at the start of the result line.
        reps: number of timed calls per function (default 5).

    Returns:
        List of rates in MHz, one per function, in ``funcs`` order.
    """
    n_rows = vals.shape[0]
    rates = []
    for func in funcs:
        # perf_counter is the high-resolution clock intended for measuring
        # short durations; time.time() is wall-clock and may be coarse.
        t0 = time.perf_counter()
        for _ in range(reps):
            func(vals)
        elapsed = time.perf_counter() - t0
        # Guard against a zero interval for extremely fast functions.
        rate = 1.0e-6 * n_rows * reps / elapsed if elapsed > 0 else float("inf")
        rates.append(rate)
    print("{} -- numba jit: {:.2f}MHz, numpy/awkward: {:.2f}MHz, python: {:.2f}MHz".format(name, rates[0], rates[1], rates[2]))
    return rates
if __name__ == "__main__":
    # Report library versions so benchmark numbers can be tied to an environment.
    print("awkward", awkward.__version__)
    print("numpy", np.__version__)
    print("numba", numba.__version__)
    print("python", sys.version_info)

    # Candidate implementations; each takes the JaggedArray and computes
    # one value per sublist.
    jit_min = lambda x: numba_min(x.content, x.offsets)
    jit_sum = lambda x: numba_sum(x.content, x.offsets)
    np_min = lambda x: x.min()
    np_sum = lambda x: x.sum()
    # Pure-Python reference: an empty sublist maps to 0.
    mymin = lambda x: 0. if not len(x) else x.min()
    py_min = lambda x: list(map(mymin, x))
    py_sum = lambda x: list(map(sum, x))

    # Synthetic data: 100000 sublists with 0-19 entries each, values
    # uniform in [5, 105).
    njets = np.random.randint(0, 20, 100000)
    content = 5. + 100. * np.random.random(njets.sum())
    # Offsets need a leading 0 and len(njets) + 1 entries (sublist k is
    # content[offsets[k]:offsets[k+1]]); a bare cumsum drops the first sublist.
    offsets = np.concatenate(([0], np.cumsum(njets)))
    pts = JaggedArray.fromoffsets(offsets, content)
    print(pts)
    print(pts.shape)

    # Trigger JIT compilation of both kernels outside the timed region so
    # compile time is not counted in either benchmark.
    _ = jit_sum(pts)
    _ = jit_min(pts)
    compare(pts, [jit_sum, np_sum, py_sum], "sum")
    compare(pts, [jit_min, np_min, py_min], "min")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment