Skip to content

Instantly share code, notes, and snippets.

Created July 18, 2024 11:43
Show Gist options
  • Save thangarajan8/2adfe8c041315a9ad57e499b46b29a15 to your computer and use it in GitHub Desktop.
Save thangarajan8/2adfe8c041315a9ad57e499b46b29a15 to your computer and use it in GitHub Desktop.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as spark_sum, when
# Initialize the Spark session. getOrCreate() reuses an already-running
# session (e.g. the `spark` object a notebook provides) instead of failing,
# so this is safe both in scripts and in notebooks.
spark = SparkSession.builder \
    .appName("Sales Analysis") \
    .getOrCreate()

# Sample data; each tuple is (sales, region, country, city, stage, pincode),
# matching the order of `schema` below.
data = [
    (100, 'a', 'b', 'c', 's1', 'p1'),
    (200, 'a', 'b', 'c', 's2', 'p1'),
    (300, 'a', 'b', 'c', 's3', 'p2'),
    (100, 'd', 'e', 'f', 's4', 'p2'),
    (100, 'd', 'e', 'f', 's5', 'p3'),
    (100, 'd', 'e', 'f', 's6', 'p4'),
]

# Define schema (column names, positionally matching each tuple in `data`)
schema = ["sales", "region", "country", "city", "stage", "pincode"]

# Create DataFrame
df = spark.createDataFrame(data, schema=schema)
def create_new_column(df, filter_cols, filter_vals, agg_col, new_col,
                      threshold=500, ok_label='OK', not_ok_label='Not OK'):
    """Filter, aggregate, and create/update a label column on a PySpark DataFrame.

    Keeps only the rows of ``df`` where every column in ``filter_cols`` equals
    the value at the same position in ``filter_vals``. Those rows are grouped by
    ``filter_cols`` and ``agg_col`` is summed per group into a ``total_sales``
    column. ``new_col`` is then set to ``ok_label`` when that total is strictly
    greater than ``threshold``, and to ``not_ok_label`` otherwise (including
    when the total is null).

    Args:
        df (DataFrame): Input PySpark DataFrame.
        filter_cols (list of str): Column names to filter and group on.
        filter_vals (list or tuple): Values matched positionally against
            ``filter_cols``.
        agg_col (str): Column to aggregate (summed).
        new_col (str): Name of the new column to be created, or overwritten
            if it already exists.
        threshold: The total must exceed this value to get ``ok_label``.
            Defaults to 500.
        ok_label: Value written when the total exceeds ``threshold``.
            Defaults to 'OK'.
        not_ok_label: Value written otherwise. Defaults to 'Not OK'.

    Returns:
        DataFrame: The filtered rows, joined with their group's ``total_sales``,
        plus ``new_col``.

    Raises:
        ValueError: If ``filter_cols`` is empty, or if ``filter_cols`` and
            ``filter_vals`` have different lengths (``zip`` would otherwise
            silently drop the extra filters).
    """
    filter_cols = list(filter_cols)
    filter_vals = list(filter_vals)
    if not filter_cols:
        raise ValueError("filter_cols must contain at least one column")
    if len(filter_cols) != len(filter_vals):
        raise ValueError(
            f"filter_cols ({len(filter_cols)}) and filter_vals "
            f"({len(filter_vals)}) must have the same length"
        )

    # Construct filter conditions: AND together one equality per (column, value).
    filter_condition = None
    for col_name, col_val in zip(filter_cols, filter_vals):
        clause = col(col_name) == col_val
        filter_condition = (
            clause if filter_condition is None else filter_condition & clause
        )

    # Filter the data
    filtered_data = df.filter(filter_condition)

    # Group by filter columns and sum the aggregation column
    grouped_data = (
        filtered_data.groupBy(*filter_cols)
        .agg(spark_sum(agg_col).alias('total_sales'))
    )

    # Attach each row's group total, then set the label column from it.
    result = (
        filtered_data.join(grouped_data, filter_cols, 'left_outer')
        .withColumn(
            new_col,
            when(col('total_sales') > threshold, ok_label)
            .otherwise(not_ok_label),
        )
    )
    return result
# Example usage: filter on region=a, country=b, city=c and pincode=p1,
# aggregate sales, and create/update 'new_column' based on the condition.
fil = ['a', 'b', 'c', 'p1']
# Alternative filter values; the empty-string pincode matches no row in the
# sample data above:
# fil1 = ['d', 'e', 'f', '']
filtered_result = create_new_column(
    df, ['region', 'country', 'city', 'pincode'], fil, 'sales', 'new_column'
)

# Show the final result
filtered_result.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment