Pyspark Flatten json
Forked from nmukerje/Pyspark Flatten json by @lotusk (created February 20, 2020)
from pyspark.sql.types import *
from pyspark.sql.functions import *

# Flatten array of structs and structs
def flatten(df):
    # compute Complex Fields (Lists and Structs) in Schema
    complex_fields = dict([(field.name, field.dataType)
                           for field in df.schema.fields
                           if type(field.dataType) == ArrayType or type(field.dataType) == StructType])

    while len(complex_fields) != 0:
        col_name = list(complex_fields.keys())[0]
        print("Processing :" + col_name + " Type : " + str(type(complex_fields[col_name])))

        # if StructType then convert all sub elements to columns,
        # i.e. flatten structs
        if type(complex_fields[col_name]) == StructType:
            expanded = [col(col_name + '.' + k).alias(col_name + '_' + k)
                        for k in [n.name for n in complex_fields[col_name]]]
            df = df.select("*", *expanded).drop(col_name)

        # if ArrayType then add the Array Elements as Rows using the explode function,
        # i.e. explode Arrays
        elif type(complex_fields[col_name]) == ArrayType:
            df = df.withColumn(col_name, explode(col_name))

        # recompute remaining Complex Fields in Schema
        complex_fields = dict([(field.name, field.dataType)
                               for field in df.schema.fields
                               if type(field.dataType) == ArrayType or type(field.dataType) == StructType])
    return df

df = flatten(df)
df.printSchema()
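
A minimal usage sketch, assuming a local SparkSession; the nested JSON record and the column names below (id, user, orders) are illustrative and not part of the original gist.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("flatten-demo").getOrCreate()

# Illustrative nested record: one struct column ("user") and one
# array-of-structs column ("orders").
sample = [
    '{"id": 1, "user": {"name": "alice", "age": 30}, '
    '"orders": [{"sku": "a1", "qty": 2}, {"sku": "b2", "qty": 1}]}'
]
df = spark.read.json(spark.sparkContext.parallelize(sample))

flat = flatten(df)
flat.printSchema()
# The flattened schema contains (column order may differ):
#   id, orders_qty, orders_sku, user_age, user_name
# Each element of "orders" becomes its own row via explode,
# and struct fields become prefixed top-level columns.
flat.show(truncate=False)

Note that explode drops rows whose array column is null or empty; if those rows must be kept, explode_outer is the usual substitute.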