@ecgill
Created February 17, 2018 19:18
Useful Spark snippets
# Starting spark session
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('Name').getOrCreate()
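# A minimal sketch of optional builder settings; the master URL and config value below
# are assumed examples, not part of the original snippet:
spark = (SparkSession.builder
         .master('local[*]')                           # run locally on all available cores
         .appName('Name')
         .config('spark.sql.shuffle.partitions', '8')  # tune shuffle parallelism
         .getOrCreate())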
# read in data from various formats:
df = spark.read.csv('filename.csv', inferSchema=True, header=True)  # header/inferSchema are csv options
df = spark.read.json('filename.json')                               # json infers the schema automatically
# change schema and read in data:
from pyspark.sql.types import StructField, StringType, IntegerType, StructType
data_schema = [StructField('name1', IntegerType(), True),
               StructField('name2', StringType(), True)]
final_struct = StructType(fields=data_schema)
df = spark.read.json('filename', schema=final_struct)
# look at data in various ways:
df.printSchema()
df.describe().show() # summary stats
df.show() # prints the first 20 rows by default; pass a number to show more
df.head()
df.head(10)
df.select(['col1', 'col2']).show(5)
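# A few more quick inspection calls (sketch, nothing assumed beyond the DataFrame itself):
df.columns  # list of column names
df.dtypes   # list of (column, type) pairs
df.count()  # number of rows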
# using SQL commands:
df.createOrReplaceTempView('temp_name')
results = spark.sql('SELECT * FROM temp_name') # insert any SQL command
results.show()
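# Sketch of a more involved query against the temp view; 'col1' is a placeholder column name:
spark.sql('SELECT col1, COUNT(*) AS n FROM temp_name GROUP BY col1 ORDER BY n DESC').show()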
# add a new column:
df.withColumn('new_col_name', df['old_col']*2).show() # NOT an "in place" operation
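# Related sketch: renaming a column (also not in place); both names are placeholders:
df.withColumnRenamed('old_col', 'renamed_col').show()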
# filtering (& = and, | = or, ~ = not)
df.filter('col1 < #').show() # SQL-style string condition; '#' is a placeholder for a number
df.filter('col1 < #').select(['col2', 'col3']).show()
df.filter(df['col1'] < 5).select(['col2', 'col3']).show()
df.filter((df['col1'] < 5) & (df['col2'] > 20)).select(['col1', 'col2']).show()
df.filter((df['col1'] < 5) | (df['col2'] > 20)).select(['col1', 'col2']).show()
df.filter((df['col1'] < 5) | ~(df['col2'] > 20)).select(['col1', 'col2']).show()
result = df.filter(df['col1'] == 10).collect()  # returns a list of Row objects
row = result[0]                                 # a single Row object with many methods
row.asDict()                                    # Row as a dictionary keyed by column name
row.asDict()['col3']                            # grab a specific column's value
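# Sketch of other handy column predicates; column names are placeholders:
df.filter(df['col1'].between(1, 5)).show()   # inclusive range check
df.filter(df['col2'].isin('a', 'b')).show()  # membership test
df.filter(df['col2'].isNull()).show()        # null check (isNotNull() also exists)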
# groupby and aggregation
df.groupBy('col1')                    # returns a GroupedData object
df.groupBy('col1').count().show()     # also mean(), max(), min(), sum()
df.agg({'col1': 'operation'}).show()  # 'operation' = 'sum', 'max', 'min', etc.
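# Sketch of named aggregations with aliases; 'company' and 'sales' are assumed column names:
from pyspark.sql.functions import avg, stddev, countDistinct
df.groupBy('company').agg(avg('sales').alias('avg_sales'),
                          stddev('sales').alias('std_sales'),
                          countDistinct('sales').alias('n_distinct_sales')).show()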
# order by, sorting
df.orderBy('col1').show() # sorts ascending by default
df.orderBy(df['col1'].desc()).show()
# dropping null values (** as written, these are not 'in place' operations!)
df.na.drop(how='any')        # the default, same as df.na.drop(); drops rows containing any null
df.na.drop(how='all')        # drops a row only if ALL of its values are null
df.na.drop(thresh=2)         # keeps rows with at least 2 non-null values, drops the rest
df.na.drop(subset=['col1'])  # drops rows where 'col1' is null
# filling null values
df.na.fill('new_str', subset=['col_to_fill']) # fill value only applies to columns whose type matches (here, strings)
from pyspark.sql.functions import mean
mean_val = df.select(mean(df['col1'])).collect()[0][0]  # collect() gives [Row(...)]; index out the scalar mean
df.na.fill(mean_val, subset=['col_to_fill']).show()     # fills nulls with that mean value
# dealing with date times
from pyspark.sql.functions import (dayofmonth, year, format_number)
df.select(dayofmonth(df['date_col'])).show() # swap dayofmonth for hour, month, year, etc.
new_df = df.withColumn('year', year(df['date_col'])) # appends year column
result = new_df.groupBy('year').mean().select('year', 'avg(col1)') # group by year and average each column
result.select(['year', format_number('avg(col1)', 2).alias('Avg ColName')]).show() # round to 2 decimals and alias the column
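# Sketch using date_format to pull an arbitrary pattern from the date column;
# the 'yyyy-MM' pattern is an assumed example:
from pyspark.sql.functions import date_format
df.select(date_format(df['date_col'], 'yyyy-MM').alias('year_month')).show()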