Created
October 15, 2014 21:09
-
-
Save shoyer/2e2ddf44328e2ca0f273 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Proof of concept for implementing bottleneck style aggregators with numba | |
These functions aggregate over any number of axes, just like the built-in | |
numpy functions, e.g., | |
- nansum(x) # aggregates over all axes | |
- nansum(x, axis=1) # aggregates over axis 1 | |
- nansum(x, axis=(0, 2)) # aggregtes over axes 0 and 2 | |
""" | |
import functools | |
import numba | |
import numpy as np | |
def ndreduce(arg=None, *args, **kwargs):
    """Decorator turning a numba reduction kernel into a NumbaAggregator.

    Can be used bare, or called with arguments that are forwarded to
    ``NumbaAggregator`` after the decorated function::

        @ndreduce
        def f(a, out): ...

        @ndreduce()
        def g(a, out): ...

        @ndreduce(['float32->int64', 'float64->int64'])
        def h(a, out): ...

        @ndreduce(dtype_map=['float32->float32'])
        def k(a, out): ...
    """
    if callable(arg) and not args and not kwargs:
        # Bare ``@ndreduce``: ``arg`` is the kernel itself.
        return NumbaAggregator(arg)
    # Called with arguments (possibly none, or keyword-only): ``arg`` is the
    # first positional argument for NumbaAggregator, if it was given at all.
    if arg is not None:
        args = (arg,) + args

    def decorator(func):
        return NumbaAggregator(func, *args, **kwargs)

    return decorator
def _axis_wrap(axis, ndim): | |
if axis < 0: | |
axis += ndim | |
if axis < 0 or axis >= ndim: | |
raise ValueError('invalid axis %s' % axis) | |
return axis | |
class NumbaAggregator(object):
    """Wrap a numba kernel as a reduction over arbitrary numpy axes.

    ``func`` follows the gufunc calling convention ``func(a, out)``: it reads
    the elements of the input block ``a`` and writes the reduced value to
    ``out[0]``. A ``numba.guvectorize`` gufunc is compiled lazily for each
    number of reduced axes and cached on the instance.

    Parameters
    ----------
    func : callable
        Kernel to compile with numba.
    dtype_map : str or sequence of str, optional
        Loop signatures of the form ``'input_dtype->output_dtype'``, e.g.
        ``'float64->float64'``; one compiled loop per entry.
    """

    # Dimension names used in the gufunc layout string, one per reduced axis.
    _DIM_NAMES = 'abcdefghijklmnopqrstuvwxyz'

    def __init__(self, func, dtype_map=('float64->float64',)):
        # Tuple default + per-instance list copy: no shared mutable default.
        if isinstance(dtype_map, str):
            dtype_map = [dtype_map]
        self.func = func
        self.dtype_map = list(dtype_map)
        self.cache = {}
        # Make the aggregator carry the kernel's __name__, __doc__, etc.
        functools.update_wrapper(self, func)

    def _signatures(self, ndim):
        """Return ``(type_signatures, layout)`` for reducing ``ndim`` axes."""
        if not 1 <= ndim <= len(self._DIM_NAMES):
            raise ValueError('cannot reduce over %s axes (supported: 1 to %s)'
                             % (ndim, len(self._DIM_NAMES)))
        colons = ','.join([':'] * ndim)
        type_sigs = []
        for entry in self.dtype_map:
            in_dtype, out_dtype = entry.split('->')
            type_sigs.append('void(%s[%s], %s[:])'
                             % (in_dtype.strip(), colons, out_dtype.strip()))
        layout = '(%s)->()' % ','.join(self._DIM_NAMES[:ndim])
        return type_sigs, layout

    def _create_gufunc(self, ndim):
        # Compiling a gufunc has significant overhead (~130ms per function and
        # number of dimensions to aggregate), so this is done lazily.
        type_sigs, layout = self._signatures(ndim)
        return numba.guvectorize(type_sigs, layout)(self.func)

    def _get_gufunc(self, ndim):
        """Return the cached gufunc for ``ndim`` reduced axes, compiling it once."""
        if ndim not in self.cache:
            self.cache[ndim] = self._create_gufunc(ndim)
        return self.cache[ndim]

    def __call__(self, arr, axis=None):
        """Reduce ``arr`` over ``axis``.

        Parameters
        ----------
        arr : array_like
            Input data; converted with ``np.asanyarray``.
        axis : None, int or iterable of int, optional
            Axes to reduce over (negative values allowed). ``None`` reduces
            over all axes; an empty tuple reduces over none.

        Raises
        ------
        ValueError
            If an axis is out of range or repeated.
        """
        arr = np.asanyarray(arr)
        if axis is None:
            # note: could also just use @jit instead of guvectorize
            # for this case
            axes = list(range(arr.ndim))
        else:
            if np.isscalar(axis):
                axis = [axis]
            axes = [_axis_wrap(a, arr.ndim) for a in axis]
            if len(set(axes)) != len(axes):
                raise ValueError('duplicate value in axis')
        if not axes:
            # Nothing to reduce (axis=() or a 0-d input): reduce over a new
            # length-1 trailing axis so each element is aggregated on its own.
            arr = arr[..., np.newaxis]
            axes = [arr.ndim - 1]
        # The gufunc reduces over the trailing axes, so move them to the end.
        kept = [n for n in range(arr.ndim) if n not in axes]
        arr = arr.transpose(kept + axes)
        return self._get_gufunc(len(axes))(arr)
@ndreduce
def nansum(a, out):
    """Sum the non-NaN elements of ``a`` into ``out[0]``.

    NaN elements are skipped; if there are none left, the sum is 0.0.
    """
    total = 0.0
    for value in a.flat:
        if np.isnan(value):
            continue
        total += value
    out[0] = total
# note: it would be nice to have some sort of wrapper so we could write | |
# these functions in more standard form, that is, returning out instead | |
# of using it as a parameter: | |
# | |
# @ndreduce | |
# def nansum(a): | |
# asum = 0.0 | |
# for ai in a.flat: | |
# if not np.isnan(ai): | |
# asum += ai | |
# return asum | |
@ndreduce
def nanmean(a, out):
    """Average the non-NaN elements of ``a`` into ``out[0]``.

    The result is NaN when ``a`` has no non-NaN elements.
    """
    total = 0.0
    n_valid = 0
    for value in a.flat:
        if np.isnan(value):
            continue
        total += value
        n_valid += 1
    out[0] = total / n_valid if n_valid > 0 else np.nan
@ndreduce
def nanmin(a, out):
    """Write the minimum of the non-NaN elements of ``a`` to ``out[0]``.

    NaNs are skipped because every comparison with NaN is False. If ``a`` has
    no non-NaN element (including when it is empty), the result is NaN.
    """
    amin = np.inf  # ``np.infty`` alias was removed in NumPy 2.0
    all_nan = True
    for ai in a.flat:
        # ``<=`` rather than ``<`` so an input of only +inf is not mistaken
        # for all-NaN; the comparison is False for NaN, which skips it.
        if ai <= amin:
            amin = ai
            all_nan = False
    if all_nan:
        amin = np.nan
    out[0] = amin
@ndreduce(['float32->int64', 'float64->int64'])
def count(a, out):
    """Write the number of non-NaN elements of ``a`` to ``out[0]``."""
    n_valid = 0
    for value in a.flat:
        if np.isnan(value):
            continue
        n_valid += 1
    out[0] = n_valid
# The MIT License (MIT) | |
# | |
# Copyright (c) 2014 Stephan Hoyer | |
# | |
# Permission is hereby granted, free of charge, to any person obtaining a copy | |
# of this software and associated documentation files (the "Software"), to deal | |
# in the Software without restriction, including without limitation the rights | |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
# copies of the Software, and to permit persons to whom the Software is | |
# furnished to do so, subject to the following conditions: | |
# | |
# The above copyright notice and this permission notice shall be included in all | |
# copies or substantial portions of the Software. | |
# | |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
# SOFTWARE. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment