Skip to content

Instantly share code, notes, and snippets.

@shoyer
Created October 15, 2014 21:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shoyer/2e2ddf44328e2ca0f273 to your computer and use it in GitHub Desktop.
Save shoyer/2e2ddf44328e2ca0f273 to your computer and use it in GitHub Desktop.
"""Proof of concept for implementing bottleneck style aggregators with numba
These functions aggregate over any number of axes, just like the built-in
numpy functions, e.g.,
- nansum(x) # aggregates over all axes
- nansum(x, axis=1) # aggregates over axis 1
- nansum(x, axis=(0, 2)) # aggregates over axes 0 and 2
"""
import functools
import numba
import numpy as np
def ndreduce(arg=None, *args, **kwargs):
    """Decorator turning a numba gufunc kernel into a ``NumbaAggregator``.

    Can be used bare (``@ndreduce``) or called with options that are
    forwarded to ``NumbaAggregator``: ``@ndreduce(['float32->int64'])``,
    ``@ndreduce(dtype_map=[...])`` or ``@ndreduce()``.

    Parameters
    ----------
    arg : callable or object, optional
        The kernel itself when used bare; otherwise the first positional
        option for ``NumbaAggregator`` (its ``dtype_map``).
    *args, **kwargs
        Further options forwarded to ``NumbaAggregator``.

    Returns
    -------
    A ``NumbaAggregator`` (bare use) or a decorator producing one.
    """
    if callable(arg) and not args and not kwargs:
        # used bare: ``arg`` is the kernel to wrap
        return NumbaAggregator(arg)
    if arg is None and not args:
        # ``@ndreduce()`` / ``@ndreduce(dtype_map=...)``: there is no
        # positional option, so do not forward the ``None`` placeholder
        return lambda func: NumbaAggregator(func, **kwargs)
    return lambda func: NumbaAggregator(func, arg, *args, **kwargs)
def _axis_wrap(axis, ndim):
if axis < 0:
axis += ndim
if axis < 0 or axis >= ndim:
raise ValueError('invalid axis %s' % axis)
return axis
class NumbaAggregator(object):
    """Callable reduction over arbitrary axes built from a numba gufunc kernel.

    ``func`` is a kernel with the signature ``func(a, out)``: it reduces the
    core array ``a`` and stores the scalar result in ``out[0]``. Calling the
    aggregator reduces an array over any subset of its axes, like the
    ``axis`` argument of numpy's reductions. One gufunc is compiled per
    number of reduced axes, on first use, and kept in ``self.cache``.

    Parameters
    ----------
    func : callable
        The reduction kernel; it is compiled with ``numba.guvectorize``.
    dtype_map : str or sequence of str, optional
        Type mappings of the form ``'input->output'`` (numba type names)
        to compile the kernel for. Defaults to ``('float64->float64',)``.
    """

    def __init__(self, func, dtype_map=('float64->float64',)):
        # copy __name__, __doc__, ... from the kernel so the aggregator
        # introspects like the function it replaces
        functools.update_wrapper(self, func)
        if isinstance(dtype_map, str):
            dtype_map = [dtype_map]
        self.func = func
        # keep a private copy so later edits to the caller's list cannot
        # disagree with gufuncs that are already cached
        self.dtype_map = list(dtype_map)
        # number of reduced axes -> compiled gufunc
        self.cache = {}

    def _gufunc_signatures(self, ndim):
        """Return ``(type_signatures, layout)`` for ``ndim`` core dimensions.

        The output is declared as a 1-d array with the scalar layout ``()``;
        the kernel writes its result to ``out[0]``.
        """
        colons = ','.join([':'] * ndim)
        type_signatures = []
        for mapping in self.dtype_map:
            in_type, out_type = mapping.split('->')
            type_signatures.append('void(%s[%s], %s[:])'
                                   % (in_type.strip(), colons,
                                      out_type.strip()))
        # generated names (d0, d1, ...) never run out, so the layout always
        # has exactly ``ndim`` core dimensions, matching the type signatures
        dims = ','.join('d%d' % i for i in range(ndim))
        return type_signatures, '(%s)->()' % dims

    def _create_gufunc(self, ndim):
        # compiling a gufunc has significant overhead (~130ms per function
        # and number of reduced dimensions), so it is done lazily
        type_signatures, layout = self._gufunc_signatures(ndim)
        return numba.guvectorize(type_signatures, layout)(self.func)

    def _get_gufunc(self, ndim):
        """Return the gufunc reducing ``ndim`` axes, compiling it on first use."""
        if ndim not in self.cache:
            self.cache[ndim] = self._create_gufunc(ndim)
        return self.cache[ndim]

    def __call__(self, arr, axis=None):
        """Reduce ``arr`` over ``axis``.

        Parameters
        ----------
        arr : array_like
            Input data.
        axis : None, int or sequence of int, optional
            Axes to reduce over. None (the default) reduces over all axes;
            an empty sequence applies the kernel to each element on its own.

        Returns
        -------
        The gufunc result, with the reduced axes removed and the remaining
        axes in their original order.

        Raises
        ------
        ValueError
            If an axis is out of range or given more than once.
        """
        arr = np.asanyarray(arr)
        if axis is None:
            axes = list(range(arr.ndim))
        elif np.isscalar(axis):
            axes = [_axis_wrap(axis, arr.ndim)]
        else:
            axes = [_axis_wrap(a, arr.ndim) for a in axis]
        if len(set(axes)) != len(axes):
            raise ValueError('duplicate value in axis')
        if not axes:
            # nothing to reduce (axis=() or a 0-d input): a gufunc needs at
            # least one core dimension, so append a length-1 axis and reduce
            # over it; the result then keeps arr's original shape
            arr = arr[..., np.newaxis]
            axes = [arr.ndim - 1]
        # gufunc core dimensions are the trailing ones: move reduced axes last
        kept = [n for n in range(arr.ndim) if n not in axes]
        arr = arr.transpose(kept + axes)
        return self._get_gufunc(len(axes))(arr)
@ndreduce
def nansum(a, out):
    """Kernel: sum of the non-NaN entries of ``a``, stored in ``out[0]``.

    An input with no non-NaN entries sums to 0.0.
    """
    total = 0.0
    for value in a.flat:
        if np.isnan(value):
            continue
        total += value
    out[0] = total
# note: it would be nice to have a wrapper that lets these kernels be written
# in a more standard form, i.e. returning the result instead of writing it
# into an ``out`` parameter:
#
# @ndreduce
# def nansum(a):
# asum = 0.0
# for ai in a.flat:
# if not np.isnan(ai):
# asum += ai
# return asum
@ndreduce
def nanmean(a, out):
    """Kernel: mean of the non-NaN entries of ``a``, stored in ``out[0]``.

    ``out[0]`` is NaN when ``a`` has no non-NaN entries.
    """
    total = 0.0
    n_valid = 0
    for value in a.flat:
        if np.isnan(value):
            continue
        total += value
        n_valid += 1
    out[0] = total / n_valid if n_valid > 0 else np.nan
@ndreduce
def nanmin(a, out):
    """Kernel: minimum of the non-NaN entries of ``a``, stored in ``out[0]``.

    ``out[0]`` is NaN when ``a`` is empty or contains only NaNs.
    """
    # np.inf rather than the np.infty alias, which NumPy 2.0 removed
    amin = np.inf
    allnan = True
    for ai in a.flat:
        # NaN compares False, so NaNs never update amin; ``<=`` rather than
        # ``<`` lets an input of only +inf still clear the all-NaN flag
        if ai <= amin:
            amin = ai
            allnan = False
    if allnan:
        amin = np.nan
    out[0] = amin
@ndreduce(['float32->int64', 'float64->int64'])
def count(a, out):
    """Kernel: number of non-NaN entries of ``a``, stored in ``out[0]``."""
    n_valid = 0
    for value in a.flat:
        if np.isnan(value):
            continue
        n_valid += 1
    out[0] = n_valid
# The MIT License (MIT)
#
# Copyright (c) 2014 Stephan Hoyer
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment