
@ahaldane
Last active February 13, 2016 16:40

This PR adds a new indexing function, "split_classes", alongside the existing ones. It is something I have occasionally wished existed. It splits the elements of one array into groups according to the 'classification' provided by a second array. In its simplest form, it does this:

def split_classes(c, v):
    # one boolean mask per unique class value; groups follow np.unique's sorted order
    return [v[c == u] for u in np.unique(c)]

This implementation has nagged me because of its performance: if c contains n unique values, it loops through the entire c and v arrays n times each and creates n intermediate boolean arrays. For large v, c, and n, that cost has been a real bottleneck for me.

This PR improves performance by computing everything in a single pass with no intermediate boolean arrays and, for convenience, also allows a choice of axis.
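To illustrate the idea of avoiding the n boolean masks, here is a minimal pure-NumPy sketch. It is not the implementation in this PR: the name split_classes_sorted is hypothetical, and it swaps the single pass for one stable sort followed by one split, so it costs O(N log N) rather than a true single pass. It assumes c is one-dimensional and non-empty, with the same length as v along the chosen axis.

```python
import numpy as np

def split_classes_sorted(c, v, axis=0):
    """Hypothetical sort-based variant; NOT the implementation in this PR."""
    c = np.asarray(c)
    v = np.asarray(v)
    # A stable sort keeps the elements of each class in their original order.
    order = np.argsort(c, kind='stable')
    sorted_c = c[order]
    # Positions where the sorted class label changes mark the group boundaries.
    boundaries = np.flatnonzero(sorted_c[1:] != sorted_c[:-1]) + 1
    # Reorder v once along the chosen axis, then cut it at the boundaries.
    return np.split(np.take(v, order, axis=axis), boundaries, axis=axis)

c = np.array([2, 0, 1, 0, 2, 2])
v = np.arange(6) * 10
groups = split_classes_sorted(c, v)
```

The groups come back in sorted class order, matching the order np.unique gives in the simple version above.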

split_classes might be (roughly) thought of as a generalization of compress, which itself is a generalization of extract, which is a generalization of boolean indexing. They often give the same result:

a = np.random.rand(100)
a[a > 0.5]
np.extract(a > 0.5, a)
np.compress(a > 0.5, a)
split_classes(a > 0.5, a)[1]  # classes sort as [False, True], so [1] is the True group
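The equivalence can be checked with a short script. This sketch uses the simple reference version of split_classes from the top of this description, not the PR's optimized one, and the fixed seed is only there to make the run repeatable. Note that indexing the result with [1] assumes both False and True actually occur in the mask.

```python
import numpy as np

def split_classes(c, v):
    # Simple reference version from the description, not the optimized one.
    return [v[c == u] for u in np.unique(c)]

rng = np.random.RandomState(0)  # fixed seed, chosen arbitrarily
a = rng.rand(100)
mask = a > 0.5

selected = a[mask]
assert np.array_equal(selected, np.extract(mask, a))
assert np.array_equal(selected, np.compress(mask, a))
# np.unique orders the classes as [False, True], so index 1 is the True group.
assert np.array_equal(selected, split_classes(mask, a)[1])
```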

A few example uses:

import numpy as np
from numpy.random import rand, choice, randint

# Example 1
data = rand(100,2)
lo, hi = split_classes(data[:,0] > 0.5, data)

# Example 2
classes = (data[:,0] < 0.5) + 2*(data[:,1] < 0.5)
group1, group2, group3, group4 = split_classes(classes, data)

# Example 3
years = [2010, 2011, 2012, 2013, 2014]
data = np.array([(choice(years), rand()) for i in range(100)],
                dtype=[('year', 'i4'), ('x', 'f4')])
for cat_data in split_classes(data['year'], data):
    print(sum(cat_data['x']))  # total of x within each year

# Example 4
L = 100
seqs = randint(0, 4, size=(1000, L)) # represents a DNA multiple sequence alignment
phenotype = rand(len(seqs))
signal = [[np.mean(c) for c in split_classes(seqs[:,i], phenotype)] for i in range(L)]

A few related stackoverflow questions:
