Skip to content

Instantly share code, notes, and snippets.

@EricSchles
Created March 28, 2022 13:38
Show Gist options
  • Save EricSchles/51796e7b834c936551235257651df8d4 to your computer and use it in GitHub Desktop.
Save EricSchles/51796e7b834c936551235257651df8d4 to your computer and use it in GitHub Desktop.
from pyspark.sql.functions import col
def groupby(df, columns):
    """Iterate over the groups of a pandas-on-Spark DataFrame.

    Emulates ``pandas.DataFrame.groupby`` iteration: for each distinct
    combination of values in *columns*, yield that combination together
    with the matching rows.

    Parameters
    ----------
    df : pyspark.pandas.DataFrame
        Pandas-on-Spark frame to group (must support ``.to_spark()``).
    columns : sequence of str
        Column names to group on; must be non-empty.

    Yields
    ------
    tuple[list, pyspark.pandas.DataFrame]
        ``(group_values, sub_frame)`` where ``group_values`` lists the
        grouping-column values (in *columns* order) and ``sub_frame``
        holds the rows of *df* matching those values.
    """
    sdf = df.to_spark()
    # Each distinct combination of the grouping columns defines one group.
    distinct_rows = [row.asDict()
                     for row in sdf.select(*columns).distinct().collect()]
    groups = [[row[column] for column in columns] for row in distinct_rows]
    for group in groups:
        # Conjunction of equality filters, one per grouping column.
        # (fix: original called `enunerate`, a NameError at runtime)
        mask = (col(columns[0]) == group[0])
        for index, column in enumerate(columns[1:], start=1):
            mask &= (col(column) == group[index])
        yield (group, sdf.filter(mask).to_pandas_on_spark())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment