# numpy groupby: group the elements (rows) of a NumPy array by a computed key
from typing import AnyStr, Callable, Optional

import numpy as np
AT = AnyStr


def groupby(
    arr: np.ndarray,
    transformfn: Callable[[AT], AT],
    selectfn: Optional[Callable[[AT], AT]] = None,
):
    """Group the elements of ``arr`` by the key ``transformfn`` computes for each.

    ``arr`` is iterated along its first axis, so for a 2-D array each element
    is a row. Keys are deduplicated and sorted with ``np.unique``, so they must
    be mutually comparable, and the returned dict is ordered by key.

    Args:
        arr: Array to group; array-likes such as lists are converted with
            ``np.asanyarray`` (ndarray subclasses pass through unchanged).
        transformfn: Maps one element of ``arr`` to its group key.
        selectfn: Optional projection applied to every element of a group.

    Returns:
        A dict mapping each distinct key (the NumPy scalars yielded by
        ``np.unique``) to either a list of ``selectfn(element)`` values or,
        when ``selectfn`` is None, the sub-array of ``arr`` holding that
        group's elements. Within each group, elements keep their original order.
    """
    arr = np.asanyarray(arr)
    keys, inverse = np.unique(list(map(transformfn, arr)), return_inverse=True)

    # One stable sort of the group ids replaces a full scan per key
    # (O(N log N) instead of O(N * K)); "stable" keeps each group's elements
    # in their original order, as the per-key boolean-mask selection did.
    order = np.argsort(inverse, kind="stable")
    counts = np.bincount(inverse, minlength=len(keys))
    index_groups = np.split(order, np.cumsum(counts)[:-1])

    grouped = {}
    for key, idx in zip(keys, index_groups):
        members = arr[idx]  # fancy indexing yields a copy, not a view of arr
        if selectfn is not None:
            grouped[key] = [selectfn(x) for x in members]
        else:
            grouped[key] = members
    return grouped
# --- Demo / smoke test ------------------------------------------------------
# Two parallel columns: the first supplies the grouping key, the second the
# value collected for each row.
x1 = list("abaab")
x2 = list("cdcff")
x = np.column_stack((x1, x2))
gout = groupby(x, transformfn=lambda row: row[0], selectfn=lambda row: row[1])
print(repr(gout))
# Expected grouping (keys and values are NumPy string scalars; on NumPy >= 2
# the repr wraps each one as np.str_('...')):
# {'a': ['c', 'c', 'f'],
#  'b': ['d', 'f']}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment