Skip to content

Instantly share code, notes, and snippets.

@stillmatic
Last active February 6, 2020 05:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save stillmatic/c230b5c21918893d8b4f489cb919bfcf to your computer and use it in GitHub Desktop.
Save stillmatic/c230b5c21918893d8b4f489cb919bfcf to your computer and use it in GitHub Desktop.
fast numpy groupby + listagg
import pandas as pd
import numpy as np
def fast_groupby_listagg(df, grouping_idx, presorted):
    """
    Split the non-grouping columns of ``df`` into one 2-d block per group.

    The frame is ordered by the grouping column(s) so that every group
    occupies one contiguous run of rows; the per-group row counts are then
    turned into slice boundaries over the underlying numpy array.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding both the grouping column(s) and the value columns.
    grouping_idx : label or list of labels
        Column(s) to group by. They are dropped from the returned values.
    presorted : bool
        If true, no sort is performed and ``df`` is trusted to already store
        the rows of each group contiguously.

    Returns
    -------
    numpy.ndarray
        When every group has the same number of rows, a 3-d array shaped
        ``(n_groups, rows_per_group, n_value_columns)``. Otherwise a 1-d
        object array whose elements are the per-group 2-d arrays. A frame
        with no groups yields an empty array.

    Notes
    -----
    Values keep whatever dtype ``DataFrame.values`` produces; no cast is made.

    complexity analysis: O(nlogn + n + k)
        n = # of rows in array
        k = # of unique grouping sets (assume k << n)
    The only additional bookkeeping is the slice boundaries, which are linear
    in the number of unique grouping sets.

    NOTE(review): pandas' groupby skips rows whose key is null by default, so
    such rows are not counted here. With a single grouping column they sort
    to the end and are simply left out; with several grouping columns they
    may sit between groups and shift the slices -- confirm inputs have no
    null keys in that case.
    """
    # O(nlogn)
    if not presorted:
        df = df.sort_values(grouping_idx)

    # O(n). size() counts rows, whereas count() skips nulls and therefore
    # under-reports any group that has missing values. sort=False keeps the
    # groups in order of first appearance so the counts line up with the
    # physical row order even if the caller's "presorted" data is merely
    # contiguous rather than ascending.
    sizes = df.groupby(grouping_idx, sort=False).size().to_numpy()

    # not sure, depends on pandas <> numpy efficiency
    arr = df.drop(grouping_idx, axis=1).values

    # O(k) - half-open [start, stop) row range for each group
    stops = np.cumsum(sizes)
    starts = stops - sizes

    # O(k) - slicing produces views, not copies
    blocks = [arr[start:stop] for start, stop in zip(starts, stops)]

    # Equal-sized groups stack cleanly into a regular 3-d array.
    if sizes.size == 0 or (sizes == sizes[0]).all():
        return np.asarray(blocks)

    # Ragged groups: recent numpy refuses to infer an object array from
    # unequal shapes, so build the container explicitly.
    ragged = np.empty(len(blocks), dtype=object)
    for position, block in enumerate(blocks):
        ragged[position] = block
    return ragged
def slow_groupby_listagg(df, grouping_idx):
    """Reference implementation: collect each value column into a list per group.

    Groups ``df`` by ``grouping_idx``, aggregates every remaining column with
    ``list`` and returns the resulting table as a numpy array (one row per
    group, one cell per value column, each cell holding a Python list).
    Intended as the slower, more memory-hungry baseline for the fast variant.
    """
    per_group = df.groupby(grouping_idx)
    listed = per_group.agg(list)
    return listed.to_numpy()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment