@nmukerje
Last active February 20, 2024 07:26
from pyspark.sql.types import *
from pyspark.sql.functions import *

# Flatten array of structs and structs
def flatten(df):
    # compute Complex Fields (Lists and Structs) in Schema
    complex_fields = dict([(field.name, field.dataType)
                           for field in df.schema.fields
                           if type(field.dataType) == ArrayType or type(field.dataType) == StructType])
    while len(complex_fields) != 0:
        col_name = list(complex_fields.keys())[0]
        print("Processing :" + col_name + " Type : " + str(type(complex_fields[col_name])))

        # if StructType then convert all sub elements to columns,
        # i.e. flatten structs
        if type(complex_fields[col_name]) == StructType:
            expanded = [col(col_name + '.' + k).alias(col_name + '_' + k)
                        for k in [n.name for n in complex_fields[col_name]]]
            df = df.select("*", *expanded).drop(col_name)

        # if ArrayType then add the array elements as rows using the explode function,
        # i.e. explode arrays
        elif type(complex_fields[col_name]) == ArrayType:
            df = df.withColumn(col_name, explode_outer(col_name))

        # recompute remaining Complex Fields in Schema
        complex_fields = dict([(field.name, field.dataType)
                               for field in df.schema.fields
                               if type(field.dataType) == ArrayType or type(field.dataType) == StructType])
    return df

df = flatten(df)
df.printSchema()
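
For quick reference, a minimal usage sketch (the SparkSession setup and the sample nested schema below are made up for illustration):

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()

# One row with a struct column that itself contains an array.
sample = spark.createDataFrame(
    [Row(location=Row(city="Austin", zips=["73301", "78701"]))])

flat = flatten(sample)
flat.printSchema()
# The struct is expanded into location_city / location_zips, and the array is
# exploded into one row per zip, leaving only simple (string) columns.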
@wallacemreis

Hi, thanks for your gist; I was looking for exactly this on Google.

I made some changes and would like to share them:

from pyspark.sql import types as T
import pyspark.sql.functions as F


def flatten(df):
    complex_fields = dict([
        (field.name, field.dataType) 
        for field in df.schema.fields 
        if isinstance(field.dataType, T.ArrayType) or isinstance(field.dataType, T.StructType)
    ])
    
    qualify = list(complex_fields.keys())[0] + "_"

    while len(complex_fields) != 0:
        col_name = list(complex_fields.keys())[0]
        
        if isinstance(complex_fields[col_name], T.StructType):
            expanded = [F.col(col_name + '.' + k).alias(col_name + '_' + k) 
                        for k in [ n.name for n in  complex_fields[col_name]]
                       ]
            
            df = df.select("*", *expanded).drop(col_name)
    
        elif isinstance(complex_fields[col_name], T.ArrayType): 
            df = df.withColumn(col_name, F.explode(col_name))
    
      
        complex_fields = dict([
            (field.name, field.dataType)
            for field in df.schema.fields
            if isinstance(field.dataType, T.ArrayType) or isinstance(field.dataType, T.StructType)
        ])
        
        
    for df_col_name in df.columns:
        df = df.withColumnRenamed(df_col_name, df_col_name.replace(qualify, ""))

    return df
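
As a quick illustration of the extra rename step at the end (the column names here are hypothetical): the name of the first complex field becomes the qualify prefix and is stripped from every flattened column.

# Sketch of the final rename loop, with made-up column names and
# qualify == "Results_" (i.e. the first complex field was "Results").
qualify = "Results_"
flattened = ["Results_series_seriesID", "Results_series_data_value", "status"]
print([c.replace(qualify, "") for c in flattened])
# ['series_seriesID', 'series_data_value', 'status']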

@lkhaskin

I have tested it with this schema:
root
|-- Results: struct (nullable = true)
| |-- series: array (nullable = true)
| | |-- element: struct (containsNull = true)
| | | |-- data: array (nullable = true)
| | | | |-- element: struct (containsNull = true)
| | | | | |-- period: string (nullable = true)
| | | | | |-- periodName: string (nullable = true)
| | | | | |-- value: string (nullable = true)
| | | | | |-- year: string (nullable = true)
| | | |-- seriesID: string (nullable = true)
|-- message: array (nullable = true)
| |-- element: string (containsNull = true)
|-- responseTime: long (nullable = true)
|-- status: string (nullable = true)
It did flatten the schema, but returned no data.

@alanflanders

alanflanders commented May 23, 2020

Some of my arrays were null which caused rows to be dropped. Use explode_outer() to include rows that have a null value in an array field.

    # if ArrayType then add the Array Elements as Rows using the explode function
    # i.e. explode Arrays
    elif (type(complex_fields[col_name]) == ArrayType):
        df=df.withColumn(col_name,F.explode_outer(col_name))
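
For instance, with a toy DataFrame (made-up data, assuming a SparkSession named spark), the difference looks like this:

from pyspark.sql import Row, functions as F

toy = spark.createDataFrame([Row(id=1, tags=["a", "b"]), Row(id=2, tags=None)])
print(toy.select("id", F.explode("tags")).count())        # 2: the id=2 row is dropped
print(toy.select("id", F.explode_outer("tags")).count())  # 3: id=2 is kept with a null element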

Thanks for this script!

@richban

richban commented Oct 29, 2020

Thanks for this wonderful function! 👍

@NiteshTripurana

Thanks for the superb code.
But I have a requirement where I have a complex JSON with 130 nested columns. For each of the nested columns, I need to create a separate DataFrame. Using these separate DataFrames, I can write them out to different files. Can you please help?
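
One rough way to approach that (a sketch only; the join-key column "id" and the output path are hypothetical) is to select each top-level complex column into its own DataFrame, flatten it, and write it separately:

from pyspark.sql import types as T

complex_cols = [f.name for f in df.schema.fields
                if isinstance(f.dataType, (T.StructType, T.ArrayType))]

for c in complex_cols:
    part = flatten(df.select("id", c))                     # keep a key column alongside the nested one
    part.write.mode("overwrite").parquet("/tmp/out/" + c)  # hypothetical output path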

@shivakumarchintala

shivakumarchintala commented Jan 2, 2021

Works like magic. But can we use this code when running on a cluster? Please clarify:
will there be any issue with updating the dataframe in a loop as it runs over a cluster?

@patirahardik

Works like a charm... thanks for sharing!

@Atif8Ted

Thanks for the superb code.
But I have a requirement where I have a complex JSON with 130 nested columns. For each of the nested columns, I need to create a separate DataFrame. Using these separate DataFrames, I can write them out to different files. Can you please help?

I have done something like that ... in my previous org.

@ThiagoPositeli

Guys, when I try to flatten my JSON by passing in the dataframe, I get this error.
This code was working beautifully for me before, but since a few days ago I get the following when I run it. Can anyone help me?


TypeError Traceback (most recent call last)
/tmp/ipykernel_2459/1833926046.py in
29 return df
30
---> 31 df=flatten(df3)
32 df.printSchema()

/tmp/ipykernel_2459/1833926046.py in flatten(df)
9 if type(field.dataType) == ArrayType or type(field.dataType) == StructType])
10 while len(complex_fields)!=0:
---> 11 col_name=list(complex_fields.keys())[0]
12 print ("Processing :"+col_name+" Type : "+str(type(complex_fields[col_name])))
13

TypeError: 'str' object is not callable

@bjornjorgensen

Your column is not a string.
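
One common way to hit this exact TypeError in a long-lived notebook session is accidentally rebinding the list builtin in an earlier cell (a hypothetical reproduction, not necessarily what happened here), which would also explain why a fresh kernel or cluster makes the same code work again:

# Hypothetical reproduction: once the name list is rebound to a string,
# the later call list(complex_fields.keys()) raises this TypeError.
complex_fields = {"meta": "StructType(...)"}
list = "oops"                                 # shadows the builtin
col_name = list(complex_fields.keys())[0]
# TypeError: 'str' object is not callable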

@ThiagoPositeli

ThiagoPositeli commented Feb 1, 2022

Thanks, @bjornjorgensen. I had been running this code for 15 days on the same file in a Jupyter notebook and it worked just fine, but a few days ago I got this error on the same file.
And the weird thing is that after a few hours or days the same code on the same file starts working again.

This is my code before passing the dataframe to the function:
path_parquet = 'gs://bucket-raw-ge/raw-ge-files/part-00000-5c473dc6-fa7c-465d-a8c2-0a6ea2793a58-c000.snappy.parquet'
df3 = spark.read.parquet(path_parquet)
df3.printSchema()

and this is my json schema:

root
|-- meta: struct (nullable = true)
| |-- active_batch_definition: struct (nullable = true)
| | |-- batch_identifiers: struct (nullable = true)
| | | |-- batch_id: string (nullable = true)
| | |-- data_asset_name: string (nullable = true)
| | |-- data_connector_name: string (nullable = true)
| | |-- datasource_name: string (nullable = true)
| |-- batch_markers: struct (nullable = true)
| | |-- ge_load_time: string (nullable = true)
| |-- batch_spec: struct (nullable = true)
| | |-- batch_data: string (nullable = true)
| | |-- data_asset_name: string (nullable = true)
| |-- expectation_suite_name: string (nullable = true)
| |-- great_expectations_version: string (nullable = true)
| |-- run_id: struct (nullable = true)
| | |-- run_name: string (nullable = true)
| | |-- run_time: string (nullable = true)
| |-- validation_time: string (nullable = true)
|-- results: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- exception_info: struct (nullable = true)
| | | |-- exception_message: string (nullable = true)
| | | |-- exception_traceback: string (nullable = true)
| | | |-- raised_exception: boolean (nullable = true)
| | |-- expectation_config: struct (nullable = true)
| | | |-- expectation_context: struct (nullable = true)
| | | | |-- description: string (nullable = true)
| | | |-- expectation_type: string (nullable = true)
| | | |-- kwargs: struct (nullable = true)
| | | | |-- batch_id: string (nullable = true)
| | | | |-- column: string (nullable = true)
| | | | |-- column_list: array (nullable = true)
| | | | | |-- element: string (containsNull = true)
| | | | |-- column_set: array (nullable = true)
| | | | | |-- element: string (containsNull = true)
| | | | |-- parse_strings_as_datetimes: boolean (nullable = true)
| | | | |-- strictly: boolean (nullable = true)
| | | | |-- value: long (nullable = true)
| | | | |-- value_set: array (nullable = true)
| | | | | |-- element: string (containsNull = true)
| | |-- result: struct (nullable = true)
| | | |-- details: struct (nullable = true)
| | | | |-- mismatched: string (nullable = true)
| | | | |-- value_counts: array (nullable = true)
| | | | | |-- element: struct (containsNull = true)
| | | | | | |-- count: long (nullable = true)
| | | | | | |-- value: string (nullable = true)
| | | |-- element_count: long (nullable = true)
| | | |-- missing_count: long (nullable = true)
| | | |-- missing_percent: double (nullable = true)
| | | |-- observed_value: string (nullable = true)
| | | |-- partial_unexpected_counts: array (nullable = true)
| | | | |-- element: struct (containsNull = true)
| | | | | |-- count: long (nullable = true)
| | | | | |-- value: string (nullable = true)
| | | |-- partial_unexpected_index_list: string (nullable = true)
| | | |-- partial_unexpected_list: array (nullable = true)
| | | | |-- element: string (containsNull = true)
| | | |-- unexpected_count: long (nullable = true)
| | | |-- unexpected_percent: double (nullable = true)
| | | |-- unexpected_percent_nonmissing: double (nullable = true)
| | | |-- unexpected_percent_total: double (nullable = true)
| | |-- success: boolean (nullable = true)
|-- statistics: struct (nullable = true)
| |-- evaluated_expectations: long (nullable = true)
| |-- success_percent: double (nullable = true)
| |-- successful_expectations: long (nullable = true)
| |-- unsuccessful_expectations: long (nullable = true)
|-- success: boolean (nullable = true)

@bjornjorgensen

Can you try gs://bucket-raw-ge/raw-ge-files/* ? Spark sometimes splits the dataframe up into several part files, and you are compressing these files with snappy. I don't think Spark sees the whole dataframe you are working on right now.
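
In other words, something along these lines (a sketch, reusing the path from above):

# Sketch: read every part file under the prefix rather than a single part file.
df3 = spark.read.parquet("gs://bucket-raw-ge/raw-ge-files/*")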

@ThiagoPositeli

Same error, @bjornjorgensen, and I already tried passing the JSON file instead of the Parquet file as well,
like:

json_file_path = 'gs://bucket-dataproc-ge/great_expectations/validations/teste_tabela_ge/none/20220201T131502.600412Z/b82d2b0bc0d5deaf8922db55075f898b.json'

df2 = spark.read.option("multiLine", "true").option("mode", "PERMISSIVE").option("inferSchema", "true").json(json_file_path)
df2.printSchema()

and passing df2 to the function.

I'm running this on a Dataproc cluster with a Jupyter notebook, I don't know if that matters.
Cluster image: 2.0.27-debian10

@bjornjorgensen

Oh, when I use this function I disable line 12 (the print statement) by putting a # in front of it.

There is an explanation of this TypeError at https://itsmycode.com/python-typeerror-str-object-is-not-callable-solution/

@ThiagoPositeli

Thanks for the help, @bjornjorgensen.
I just killed the cluster and created another one, and the function started running again with the same files 🤣

@Mavericks334

Mavericks334 commented Aug 23, 2022

Hi,
I get the error below. Any suggestions on how to avoid it? I have nodes that contain multiple nested nodes; it can go up to 3 or 4 levels.

root
|-- code: string (nullable = true)
|-- rule_id: string (nullable = true)
|-- from: date (nullable = true)
|-- _to: date (nullable = true)
|-- type: string (nullable = true)
|-- definition: string (nullable = true)
|-- description: string (nullable = true)
|-- created_on: timestamp (nullable = true)
|-- creator: string (nullable = true)
|-- modified_on: timestamp (nullable = true)
|-- modifier: string (nullable = true)

IndexError                                Traceback (most recent call last)
<command-4235716148475136> in <module>
----> 1 flatten_df = flatten(ar)
      2 flatten_df.show()

<command-4320286398733830> in flatten(df)
     10     ])
     11 
---> 12     qualify = list(complex_fields.keys())[0] + "_"
     13 
     14     while len(complex_fields) != 0:

IndexError: list index out of range
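
For reference, this IndexError comes from taking [0] of an empty dict of complex fields: the schema above has no struct or array columns, so there is nothing to flatten. A small guard at the start of the modified version above avoids it (a sketch):

    # Sketch: return the dataframe unchanged when there are no complex fields,
    # so the [0] lookup for the qualify prefix never runs on an empty list.
    if not complex_fields:
        return df
    qualify = list(complex_fields.keys())[0] + "_"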

@ttdidier

Line 21: # if ArrayType then add the Array Elements as Rows using the explode function
Is there a way we can add the array elements as columns rather than rows? I tried to use posexplode to later pivot the table, without success.

I tried df = df.selectExpr("*", posexplode_outer(col_name).alias("position", col_name)) but I get the error
"TypeError: Column is not iterable".

@bjornjorgensen

I have updated this function and added a fix for MapType.
I have also created a JIRA for this.
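
For reference, one way to cover MapType in the same loop (a sketch only, not necessarily the exact fix in the updated version) is to explode the map into key/value columns:

        # Sketch of a MapType branch: exploding a map yields one row per entry
        # with generated "key" and "value" columns.
        elif isinstance(complex_fields[col_name], T.MapType):
            df = (df
                  .select("*", F.explode_outer(col_name))
                  .withColumnRenamed("key", col_name + "_key")
                  .withColumnRenamed("value", col_name + "_value")
                  .drop(col_name))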

@franciscodara

Help me, please. When running the code above everything worked, but I need to convert the data that comes in timestamp format to a date. Can someone help me?

Output:
_v | 0
action | To create
data_updatedAt | 2023-01-26T15:10:...
date
$date | 1674745838876 <<<<<<<

I need:
_v | 0
action | To create
data_updatedAt | 2023-01-26T15:10:...
date
$date | 2023-01-26T15:10:... <<<<<<<
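
The $date value looks like epoch milliseconds, so one option (a sketch; "date_$date" is the column name the flatten step would produce here) is to divide by 1000 and cast to a timestamp:

from pyspark.sql import functions as F

# Sketch: interpret the long value as epoch milliseconds.
df = df.withColumn("date_$date", (F.col("date_$date") / 1000).cast("timestamp"))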

@CMonte2

CMonte2 commented Mar 9, 2023

Thanks a lot for your work, it works great.

@anayyar82

Hello, I tried to use the MapType handling in Spark Structured Streaming but it's not working due to an issue in the code.

Below is the line that causes the problem in Spark Structured Streaming:

        keys = list(map(lambda row: row[0], keys_df.collect()))

Please let me know the best option to resolve it in Spark Structured Streaming.
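
Actions like collect() are not allowed on a streaming DataFrame, which is most likely why that key-discovery step fails. A common workaround, sketched here with hypothetical names and an invented sink path, is to do the flattening inside foreachBatch, where each micro-batch is a plain batch DataFrame:

# Sketch (hypothetical names and path): flatten each micro-batch instead of
# the streaming DataFrame itself; collect() is allowed inside foreachBatch.
def process_batch(batch_df, batch_id):
    flatten(batch_df).write.mode("append").parquet("/tmp/flattened")

query = (stream_df.writeStream            # stream_df: the streaming DataFrame
         .foreachBatch(process_batch)
         .start())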

@prafulacharya

This function flatten() fails when there is a nested array inside an array. It failed to flatten this:

"user_mentions": [
    {
        "screen_name": "AshJone15461246",
        "name": "Ash Jones",
        "id": 1589564629369462800,
        "id_str": "1589564629369462784",
        "indices": [0, 16]
    },
    {
        "screen_name": "BariAWilliams",
        "name": "Bärí A. Williams, Esq.",
        "id": 4639656854,
        "id_str": "4639656854",
        "indices": [17, 31]
    },
    {
        "screen_name": "bjorn_hefnoll",
        "name": "Björn",
        "id": 1374096417954881500,
        "id_str": "1374096417954881548",
        "indices": [32, 46]
    },
    {
        "screen_name": "SpencerAlthouse",
        "name": "Spencer Althouse",
        "id": 38307346,
        "id_str": "38307346",
        "indices": [47, 63]
    }
]
