Skip to content

Instantly share code, notes, and snippets.

@dineshdharme
Created September 6, 2023 14:50
Show Gist options
  • Save dineshdharme/d0247100dd0b29034e5bc46bc883504d to your computer and use it in GitHub Desktop.
An example of PandasUDFType.GROUPED_AGG in pyspark. Clearly explained.
Here's a solution using `PandasUDFType.GROUPED_AGG` which can be used inside a groupby clause.
# Imports and Spark context setup for the GROUPED_AGG pandas UDF example.
# BUG FIX: the original snippet used SparkContext without importing it
# (only SQLContext was imported), which raises NameError at runtime.
from pyspark import SparkContext, SQLContext
from pyspark.sql.functions import *
import pyspark.sql.functions as F
from pyspark.sql.types import *
from pyspark.sql.window import Window
from typing import Iterator, Tuple
import pandas as pd

# Local-mode context; SQLContext is kept (rather than SparkSession) so the
# rest of the script, which references `sqlContext`, works unchanged.
sc = SparkContext('local')
sqlContext = SQLContext(sc)
# Sample input rows: (row_id, col1, col2, col3). row_id values repeat so the
# later groupby has multiple rows per group to aggregate.
data1 = [
    ["01001", 12, 41, 10],
    ["01004", 66, 1, 77],
    ["01003", 31, 52, 10],
    ["01004", 27, 11, 91],
    ["01001", 43, 5, 10],
    ["01003", 21, 11, 2],
    ["01003", -61, 15, 10],
    ["01001", 67, -11, -22],
    ["01004", 21, -26, -13],
    ["01001", 13, -5, 10],
    ["01003", 21, 111, -2],
    ["01003", 13, 18, 10],
    ["01001", 49, -17, -22],
]
df1Columns = ["row_id", "col1", "col2", "col3"]
# Column names alone are a valid schema; types are inferred from the data.
# (A dead, unused StructType `schema` definition was removed here — the
# script never referenced it.)
df1 = sqlContext.createDataFrame(data=data1, schema=df1Columns)
print("Given dataframe")
df1.show(n=100, truncate=False)
@pandas_udf(ArrayType(ArrayType(IntegerType())), PandasUDFType.GROUPED_AGG)
def custom_sum_udf(col1_series: pd.Series, col2_series: pd.Series, col3_series: pd.Series) -> ArrayType(ArrayType(IntegerType())):
    """Grouped aggregation: reduce each group's three columns to one value each.

    Receives the full per-group Series for col1/col2/col3 and returns a
    3x3 nested list: [[sums], [maxes], [mins]], one inner list per
    aggregation, one element per input column.
    """
    # Side-by-side frame: one column per input Series, one row per group row.
    frame = pd.concat([col1_series, col2_series, col3_series], axis=1)
    print("what is the concat df")
    print(frame)
    # Apply each column-wise reduction in turn; getattr keeps this a plain
    # call to DataFrame.sum/max/min, identical to calling them directly.
    results = [getattr(frame, agg_name)(axis=0).tolist()
               for agg_name in ("sum", "max", "min")]
    for label, values in zip(("sum_column", "max_column", "min_column"), results):
        print(label, values)
    return results
# Aggregate per row_id: the UDF collapses each group's three columns into a
# single nested-array column. Cached because it is shown twice (directly and
# after being split apart below).
df_new = (
    df1.groupby(F.col("row_id"))
    .agg(custom_sum_udf(F.col("col1"), F.col("col2"), F.col("col3")).alias("reduced_columns"))
    .cache()
)
print("Printing the column sum, max, min")
df_new.show(n=100, truncate=False)

# Unpack the 3-element outer array into one column per aggregation; select
# replaces the withColumn/withColumn/drop chain but yields the same columns
# in the same order.
df_new_sep = df_new.select(
    F.col("row_id"),
    F.col("reduced_columns").getItem(0).alias("sum_over_columns"),
    F.col("reduced_columns").getItem(1).alias("max_over_columns"),
    F.col("reduced_columns").getItem(2).alias("min_over_columns"),
)
print("Printing the column sum, max, min")
df_new_sep.show(n=100, truncate=False)
Output:
Given dataframe
+------+----+----+----+
|row_id|col1|col2|col3|
+------+----+----+----+
|01001 |12 |41 |10 |
|01004 |66 |1 |77 |
|01003 |31 |52 |10 |
|01004 |27 |11 |91 |
|01001 |43 |5 |10 |
|01003 |21 |11 |2 |
|01003 |-61 |15 |10 |
|01001 |67 |-11 |-22 |
|01004 |21 |-26 |-13 |
|01001 |13 |-5 |10 |
|01003 |21 |111 |-2 |
|01003 |13 |18 |10 |
|01001 |49 |-17 |-22 |
+------+----+----+----+
+------+-----------------------------------------------+
|row_id|reduced_columns |
+------+-----------------------------------------------+
|01001 |[[184, 13, -14], [67, 41, 10], [12, -17, -22]] |
|01003 |[[25, 207, 30], [31, 111, 10], [-61, 11, -2]] |
|01004 |[[114, -14, 155], [66, 11, 91], [21, -26, -13]]|
+------+-----------------------------------------------+
Printing the column sum, max, min
+------+----------------+----------------+----------------+
|row_id|sum_over_columns|max_over_columns|min_over_columns|
+------+----------------+----------------+----------------+
|01001 |[184, 13, -14] |[67, 41, 10] |[12, -17, -22] |
|01003 |[25, 207, 30] |[31, 111, 10] |[-61, 11, -2] |
|01004 |[114, -14, 155] |[66, 11, 91] |[21, -26, -13] |
+------+----------------+----------------+----------------+
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment