Skip to content

Instantly share code, notes, and snippets.

@javadba
Last active May 20, 2019 06:17
Show Gist options
  • Save javadba/9e46c1e00274cd9c7cb422fa01cac9bd to your computer and use it in GitHub Desktop.
Save javadba/9e46c1e00274cd9c7cb422fa01cac9bd to your computer and use it in GitHub Desktop.
numpy groupby
from typing import Any, AnyStr, Callable, Optional

import numpy as np
AT = AnyStr  # retained for backward compatibility; not used by groupby's annotations


def groupby(
    arr: np.ndarray,
    transformfn: Callable[[Any], Any],
    selectfn: Optional[Callable[[Any], Any]] = None,
) -> dict:
    """Group the elements (rows, for 2-D input) of *arr* by a computed key.

    Args:
        arr: Array, or any sequence accepted by ``np.asarray``, whose
            elements along axis 0 are grouped.
        transformfn: Called once per element of ``arr``; its return value is
            that element's group key. Keys must be sortable by ``np.unique``.
        selectfn: Optional projection applied to every element of a group.
            When given, each group is a list of ``selectfn(element)``
            results; when omitted, each group is the sub-array of ``arr``
            holding that group's elements.

    Returns:
        Dict mapping each distinct key (NumPy scalars from ``np.unique``,
        inserted in sorted key order) to its group. Within a group the
        elements keep their original relative order. Empty input yields ``{}``.
    """
    arr = np.asarray(arr)
    if len(arr) == 0:
        return {}

    keys, inverse = np.unique([transformfn(elem) for elem in arr], return_inverse=True)

    # A single *stable* sort by group index puts each group's elements next to
    # each other while preserving their original order. This replaces the old
    # per-key `np.where(indx == i)` scans, which cost O(n * number_of_keys).
    order = np.argsort(inverse, kind="stable")
    counts = np.bincount(inverse, minlength=len(keys))
    # Cut points between consecutive groups in the sorted array.
    groups = np.split(arr[order], np.cumsum(counts)[:-1])

    if selectfn is not None:
        groups = [[selectfn(elem) for elem in group] for group in groups]

    return dict(zip(keys, groups))
# Demo / smoke test: group the rows of a two-column string array by their
# first column, keeping only each row's second column in the result.
x1 = ['a', 'b', 'a', 'a', 'b']
x2 = ['c', 'd', 'c', 'f', 'f']
x = np.column_stack((x1, x2))  # one row per (x1[i], x2[i]) pair
gout = groupby(x, transformfn=lambda row: row[0], selectfn=lambda row: row[1])
print(gout)  # a dict's str() is its repr(), so this prints repr(gout)
# Expected output (printed on one line; wrapped here for readability):
# {'a': ['c', 'c', 'f'],
#  'b': ['d', 'f']}
# NOTE: the keys and list elements are NumPy scalars, so NumPy >= 2 prints
# them as np.str_('a') etc.; the grouping itself is unchanged.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment