Last active
August 29, 2015 14:13
-
-
Save shoyer/9d06b7294a8d06981ec7 to your computer and use it in GitHub Desktop.
Numba performance with indexing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numba import guvectorize, jit | |
import numpy as np | |
import pandas as pd | |
@guvectorize(['void(float64[:], int64[:], float64[:], float64[:])'],
             '(x),(x),(y)->()')
def _grouped_sum_guvec_simple(values, labels, target, out):
    """Add each ``values[i]`` into ``target[labels[i]]`` in place (gufunc variant).

    ``values`` and ``labels`` must have the same length (layout ``(x),(x)``);
    ``target`` has its own length ``y``. Labels are used as indices with no
    sign or bounds check here, so negative labels are NOT skipped (compare
    ``_grouped_sum_guvec_conditional``).

    The sums are accumulated by mutating the *input* ``target``. Numba only
    propagates such writes back to the caller's array when no type conversion
    is needed, i.e. ``target`` should already be a float64 array.

    The declared scalar output ``out`` carries no result; it is set to 0.0 so
    the gufunc returns a defined value instead of uninitialized memory.
    """
    for i in range(len(values)):
        idx = labels[i]
        target[idx] += values[i]
    # A `()` output is passed as a 1-element array and must be written via out[0].
    out[0] = 0.0
@guvectorize(['void(float64[:], int64[:], float64[:], float64[:])'],
             '(x),(x),(y)->()')
def _grouped_sum_guvec_conditional(values, labels, target, out):
    """Add ``values[i]`` into ``target[labels[i]]`` in place, skipping negative labels.

    ``values`` and ``labels`` must have the same length (layout ``(x),(x)``).
    Entries whose label is negative are ignored; non-negative labels are used
    as indices into ``target`` without a bounds check here.

    The sums are accumulated by mutating the *input* ``target``. Numba only
    propagates such writes back to the caller's array when no type conversion
    is needed, i.e. ``target`` should already be a float64 array.

    The declared scalar output ``out`` carries no result; it is set to 0.0 so
    the gufunc returns a defined value instead of uninitialized memory.
    """
    for i in range(len(values)):
        idx = labels[i]
        if idx >= 0:
            target[idx] += values[i]
    # A `()` output is passed as a 1-element array and must be written via out[0].
    out[0] = 0.0
@jit
def _grouped_sum_jit_simple(values, labels, target):
    """Scatter-add ``values`` into ``target`` at the positions given by ``labels``.

    ``target`` is updated in place and nothing is returned. Every label is used
    directly as an index with no sign check, so negative labels are not skipped
    (compare ``_grouped_sum_jit_conditional``).
    """
    n_items = len(values)
    for pos in range(n_items):
        target[labels[pos]] += values[pos]
@jit
def _grouped_sum_jit_conditional(values, labels, target):
    """Scatter-add ``values`` into ``target`` by ``labels``, ignoring negative labels.

    ``target`` is updated in place and nothing is returned. Only entries whose
    label is ``>= 0`` contribute; the rest are skipped.
    """
    n_items = len(values)
    for pos in range(n_items):
        group = labels[pos]
        if group < 0:
            continue  # entries with a negative label do not contribute
        target[group] += values[pos]
def call_grouped_sum(agg_func, values, labels, n_unique):
    """Run a grouped-sum kernel on a fresh zeroed buffer and return that buffer.

    Parameters
    ----------
    agg_func : callable
        Kernel invoked as ``agg_func(values, labels, target)``. It must write
        its results into ``target`` in place; its return value is discarded.
    values : array with a ``dtype``
        Values to aggregate; the result buffer uses ``values.dtype``.
    labels : array-like
        Per-entry group indices, forwarded to ``agg_func`` unchanged.
    n_unique : int
        Length of the result buffer (presumably the number of groups).

    Returns
    -------
    numpy.ndarray
        The zero-initialized buffer of length ``n_unique`` after ``agg_func``
        has accumulated into it.

    Notes
    -----
    NOTE(review): for the gufunc kernels the buffer is only updated in place
    when ``values`` is float64 (so the buffer is float64 too); confirm before
    passing other dtypes.
    """
    totals = np.zeros(n_unique, dtype=values.dtype)
    agg_func(values, labels, totals)  # results arrive via `totals`, not the return value
    return totals
import timeit

# Benchmark data: 1e8 float64 values (~800 MB) assigned to 10 groups.
# The label arrays have the same length, so expect several GB of memory use.
values = np.random.randn(int(1e8))
labels = np.random.randint(10, size=int(1e8))
# Unsigned copy of the labels, to compare signed vs. unsigned indexing.
ulabels = labels.astype(np.uint)
n_unique = 10


def _report_best(name, agg_func, group_labels, repeat=10):
    """Time ``call_grouped_sum`` with ``agg_func`` and print the best of ``repeat`` runs.

    Plain-Python stand-in for IPython's ``%timeit -r 10`` magic, which is a
    SyntaxError outside IPython/Jupyter. Each run performs a single call.
    Taking the minimum keeps one-off costs, such as the lazy compilation of
    ``@jit`` kernels on their first call, out of the reported figure.
    """
    timings = timeit.repeat(
        lambda: call_grouped_sum(agg_func, values, group_labels, n_unique),
        repeat=repeat,
        number=1,
    )
    print("%s: best of %d: %.0f ms per loop" % (name, repeat, min(timings) * 1e3))


_report_best("jit_simple, int labels", _grouped_sum_jit_simple, labels)
_report_best("jit_simple, uint labels", _grouped_sum_jit_simple, ulabels)
_report_best("jit_conditional, int labels", _grouped_sum_jit_conditional, labels)
_report_best("guvec_simple, int labels", _grouped_sum_guvec_simple, labels)
_report_best("guvec_conditional, int labels", _grouped_sum_guvec_conditional, labels)

# Timings previously recorded with IPython `%timeit -r 10`
# (presumably in the same order as the runs above):
# 1 loops, best of 10: 141 ms per loop
# 1 loops, best of 10: 128 ms per loop
# 1 loops, best of 10: 112 ms per loop
# 10 loops, best of 10: 165 ms per loop
# 10 loops, best of 10: 139 ms per loop
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.