Skip to content

Instantly share code, notes, and snippets.

@javadba javadba/numpy-groupby.py
Last active May 20, 2019

Embed
What would you like to do?
numpy groupby
import numpy as np
from typing import AnyStr, Callable
# Type-variable alias used by the annotations of ``groupby`` below.
AT = AnyStr


def groupby(
    arr: np.ndarray,
    transformfn: Callable[[AT], AT],
    selectfn: Callable[[AT], AT] = None,
):
    """Group the elements of *arr* by a key computed from each element.

    Parameters
    ----------
    arr
        Array (or array-like, e.g. a list of rows) whose first axis is
        iterated; each item along that axis is one record.
    transformfn
        Called once per record; its return value is the record's group key.
        The keys are de-duplicated with ``np.unique``.
    selectfn
        Optional projection applied to every record of a group.  When it is
        ``None`` the records are kept as they are.

    Returns
    -------
    dict
        Maps each unique key (as yielded by ``np.unique``, i.e. in sorted
        order) to the records of that group, in their original relative
        order.  The value is a list of ``selectfn`` results when *selectfn*
        is given, otherwise the sub-array of *arr* holding the group.
    """
    # Lists cannot be indexed with a mask, so coerce array-likes up front;
    # ``asanyarray`` passes existing arrays (and subclasses) through as-is.
    arr = np.asanyarray(arr)

    group_keys = [transformfn(record) for record in arr]
    # ``indx[j]`` is the position in ``keys`` of the key of record ``j``.
    keys, indx = np.unique(group_keys, return_inverse=True)

    grouped = {}
    for i, key in enumerate(keys):
        members = arr[indx == i]
        if selectfn is not None:
            members = [selectfn(record) for record in members]
        grouped[key] = members
    return grouped
# Demo: pair up two label columns, then collect the second column's values
# under the label found in the first column.
x1 = ['a', 'b', 'a', 'a', 'b']
x2 = ['c', 'd', 'c', 'f', 'f']
x = np.column_stack((x1, x2))
gout = groupby(x, transformfn=lambda row: row[0], selectfn=lambda row: row[1])
print(repr(gout))
# Resulting grouping, written with plain strings (the exact repr of the
# NumPy string scalars may differ between NumPy versions):
#   {'a': ['c', 'c', 'f'],
#    'b': ['d', 'f']}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.