Skip to content

Instantly share code, notes, and snippets.

@juliantaylor
Last active December 23, 2015 13:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save juliantaylor/6642848 to your computer and use it in GitHub Desktop.
Save juliantaylor/6642848 to your computer and use it in GitHub Desktop.
Generalized axis support for arbitrary functions
import numpy as np
def ureduce(a, func=np.median, **kwargs):
    """Apply the reduction ``func`` over one or several axes of ``a``.

    ``func`` is only ever called with a single ``axis`` value. When ``axis``
    is a sequence, the axes that are kept are moved to the front (in their
    original order), the reduced axes are moved to the back (also in their
    original order) and merged into one trailing axis, and ``func`` is then
    called with ``axis=-1``.

    Parameters
    ----------
    a : array_like
        Input data, converted with ``np.asarray``.
    func : callable, optional
        Reduction, called as ``func(arr, **kwargs)``. Defaults to
        ``np.median``.
    **kwargs
        Forwarded to ``func``. ``axis`` may be None, an int, or a sequence
        of ints. ``keepdims`` is consumed here and not forwarded: when true,
        the reduced axes are kept in the result with length 1.

    Returns
    -------
    The result of ``func``. When ``keepdims`` is true it is reshaped so that
    reduced axes have length 1. Any extra leading dimensions in ``func``'s
    result (for example the ``q`` dimension ``np.percentile`` adds for a
    sequence ``q``) are preserved in front.

    Raises
    ------
    IndexError
        If an axis is out of bounds for ``a``.
    ValueError
        If ``axis`` names the same axis twice, including when it is written
        once as a positive and once as a negative index.
    """
    a = np.asarray(a)
    axis = kwargs.get('axis', None)
    keepdims = kwargs.pop('keepdims', False)
    nd = a.ndim

    if axis is None:
        # func reduces over everything; all axes collapse for keepdims.
        keepdim = [1] * nd
        nreduced = nd
    elif np.isscalar(axis):
        if axis >= nd or axis < -nd:
            raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
        # A single axis is handed to func unchanged.
        keepdim = list(a.shape)
        keepdim[axis] = 1
        nreduced = 1
    else:
        keepdim = list(a.shape)
        reduced = set()
        for x in axis:
            if x >= nd or x < -nd:
                raise IndexError("axis %d out of bounds (%d)" % (x, nd))
            # Compare normalized indices so that e.g. 1 and 1 - nd clash.
            if x % nd in reduced:
                raise ValueError("duplicate value in axis")
            reduced.add(x % nd)
            keepdim[x] = 1
        kept = [i for i in range(nd) if i not in reduced]
        nreduced = len(reduced)
        nkeep = len(kept)
        # Kept axes first, then reduced axes in ascending order, so the
        # merged trailing axis is a C-order flattening of the reduced axes.
        a = a.transpose(kept + sorted(reduced))
        # dtype=intp keeps the size integral even for an empty axis tuple.
        merged = int(np.prod(a.shape[nkeep:], dtype=np.intp))
        a = a.reshape(a.shape[:nkeep] + (merged,))
        kwargs['axis'] = -1

    r = func(a, **kwargs)
    if not keepdims:
        return r
    # Leading dimensions added by func (beyond the nd - nreduced surviving
    # data axes) are kept in front of the keepdims shape.
    nextra = max(np.ndim(r) - (nd - nreduced), 0)
    return np.asarray(r).reshape(np.shape(r)[:nextra] + tuple(keepdim))
# Demo: shape of a two-q percentile over axes (1, 3) of a 4-d array.
x = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)
print(ureduce(x, func=np.percentile, q=(25, 50), axis=(1, 3)).shape)
# x stacks 10 copies of d, so every entry of r should equal np.median(d).
d = np.random.normal(size=(70, 30))
x = np.array([d] * 10)
print(x.shape)
r = ureduce(x, func=np.median, axis=(-2, -1))
# Single-argument print() gives identical output on Python 2 and 3.
print("%s %s" % (r, r[0]))
print(np.median(d))
def median2(a, **kwargs):
    """Return the median of *a*, computed as its 50th percentile.

    Every keyword argument (such as ``axis``) is passed straight through to
    ``np.percentile``. This script uses it as an alternative ``func`` for
    ``ureduce``.
    """
    median_q = 50  # the median is the 50th percentile
    result = np.percentile(a, median_q, **kwargs)
    return result
# Same check as before, with median2 (the 50th percentile) as the reduction.
r = ureduce(x, func=median2, axis=(-2, -1))
# Single-argument print() gives identical output on Python 2 and 3.
print("%s %s" % (r, r[0]))
print(np.median(d))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment