@nguyenvulebinh
Last active August 8, 2023 15:08
Flatten a Spark DataFrame schema (include struct and array type)
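import typing as T
import cytoolz.curried as tz
import pyspark
import pyspark.sql.types
from pyspark.sql.functions import explode


def schema_to_columns(schema: pyspark.sql.types.StructType) -> T.List[T.List[str]]:
    # Return the path (list of field names) to every non-struct field.
    # Array fields are leaves here; flatten_array explodes them afterwards.
    columns = list()

    def helper(schm: pyspark.sql.types.StructType, prefix: list = None):
        if prefix is None:
            prefix = list()
        for item in schm.fields:
            if isinstance(item.dataType, pyspark.sql.types.StructType):
                # Nested struct: recurse with this field's name added to the path.
                helper(item.dataType, prefix + [item.name])
            else:
                columns.append(prefix + [item.name])

    helper(schema)
    return columns


def flatten_array(frame: pyspark.sql.DataFrame) -> T.Tuple[pyspark.sql.DataFrame, bool]:
    # Explode every top-level array column and report whether any was found.
    # Note: Spark allows only one generator (explode) per select clause, so
    # this select fails when the frame has two or more array columns.
    have_array = False
    aliased_columns = list()
    for column, t_column in frame.dtypes:
        if t_column.startswith('array<'):
            have_array = True
            c = explode(frame[column]).alias(column)
        else:
            c = tz.get_in([column], frame)
        aliased_columns.append(c)
    return (frame.select(aliased_columns), have_array)


def flatten_frame(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    # Select every leaf field of the (possibly nested) structs as a top-level column.
    aliased_columns = list()
    for col_spec in schema_to_columns(frame.schema):
        # tz.get_in evaluates frame[a][b]..., i.e. the nested field as a Column.
        c = tz.get_in(col_spec, frame)
        if len(col_spec) == 1:
            aliased_columns.append(c)
        else:
            # Nested fields are named after their full path, e.g. "a:b:c".
            aliased_columns.append(c.alias(':'.join(col_spec)))
    return frame.select(aliased_columns)


def flatten_all(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    # Alternate struct flattening and array exploding until nothing is nested.
    frame = flatten_frame(frame)
    (frame, have_array) = flatten_array(frame)
    if have_array:
        return flatten_all(frame)
    else:
        return frame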
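To see which column names `schema_to_columns` plus the `':'` aliasing in `flatten_frame` produce, without starting a Spark session, here is a plain-Python sketch of the same recursion. It assumes a schema modelled as nested dicts, where a dict stands in for a `StructType` and any other value for a leaf type; `leaf_paths` is a hypothetical helper, not part of the gist.

```python
from typing import Any, Dict, List, Optional


def leaf_paths(schema: Dict[str, Any], prefix: Optional[List[str]] = None) -> List[List[str]]:
    """Return the path to every non-dict (leaf) field, in key order, depth first."""
    prefix = prefix or []
    paths: List[List[str]] = []
    for name, dtype in schema.items():
        if isinstance(dtype, dict):
            # A nested dict plays the role of a StructType: recurse with a longer prefix.
            paths.extend(leaf_paths(dtype, prefix + [name]))
        else:
            paths.append(prefix + [name])
    return paths


schema = {
    "id": "string",
    "Portfolio": {"Holding": {"Country": {"_VALUE": "string", "__Id": "string"}}},
}
# Same naming rule as flatten_frame: nested paths are joined with ':'.
print([":".join(p) for p in leaf_paths(schema)])
# ['id', 'Portfolio:Holding:Country:_VALUE', 'Portfolio:Holding:Country:__Id']
```

As in the gist, arrays are treated as leaves at this stage; once flatten_array explodes an array of structs, the next flatten_frame pass sees the struct fields.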
@YoussefEssa

YoussefEssa commented Sep 25, 2018

Hey,
could you please give an example of how to add this to a project and how to use it in Spark?

I tried, but I got this error:
def schema_to_columns(schema: pyspark.sql.types.StructType) -> T.List[T.List[str]]:
^
SyntaxError: invalid syntax

I'm using Python 3.7.0.

@Helpfulpaw

It has problems when there are two explodes (two array columns to explode in the same pass).

@AxREki

AxREki commented Aug 6, 2019

@Helpfulpaw, try this from my fork:
https://gist.github.com/AxREki/70988ae40d1d82db15832575c175c41c

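import typing as T
import cytoolz.curried as tz
import pyspark
from pyspark.sql.functions import explode

def flatten_array(frame: pyspark.sql.DataFrame) -> T.Tuple[pyspark.sql.DataFrame, bool]:
    have_array = False
    aliased_columns = list()
    i = 0
    for column, t_column in frame.dtypes:
        # Explode only the first array column in this pass: Spark allows a single
        # generator per select, and flatten_all recurses to handle the rest.
        if t_column.startswith('array<') and i == 0:
            have_array = True
            c = explode(frame[column]).alias(column)
            i += 1
        else:
            c = tz.get_in([column], frame)
        aliased_columns.append(c)
    return (frame.select(aliased_columns), have_array)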
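The fix works because each pass explodes a single array column and the recursion in flatten_all picks up the remaining ones. A plain-Python sketch of that loop, where rows are dicts and a list value stands in for an array column. It assumes every column holds either lists or None in all rows; `array_columns` and `explode_all` are hypothetical helpers, not Spark or gist functions.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]


def array_columns(rows: List[Row]) -> List[str]:
    """Columns (in first-row key order) that hold a list in at least one row."""
    if not rows:
        return []
    return [c for c in rows[0] if any(isinstance(r[c], list) for r in rows)]


def explode_all(rows: List[Row]) -> List[Row]:
    """Explode one array column per pass until no list-valued column remains."""
    while True:
        cols = array_columns(rows)
        if not cols:
            return rows
        col = cols[0]  # only the first array column, like the `i == 0` guard
        # Plain explode semantics: a row whose list is None or empty yields nothing.
        rows = [{**r, col: item} for r in rows for item in (r[col] or [])]


rows = [{"id": 1, "a": [1, 2], "b": ["x", "y", "z"]}]
result = explode_all(rows)
print(len(result))  # 6
print(result[0])    # {'id': 1, 'a': 1, 'b': 'x'}
```

Exploding two sibling arrays this way yields one row per combination (2 * 3 = 6 rows here), which is worth keeping in mind when the output looks duplicated.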

@reddysk18

Hi,
I am getting the error below when using this script in Azure Databricks. Can someone suggest a fix?
"No module named 'cytoolz'"

@AxREki

AxREki commented Aug 20, 2020

@reddysk18 You should install cytoolz on the cluster first
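If installing cytoolz on the cluster is not an option: the gist only uses `tz.get_in` to walk a path of keys, which the standard library can do with `functools.reduce` and `operator.getitem`. A sketch, assuming the path always exists (cytoolz's `get_in` returns a default, None, for a missing key instead of raising); `get_path` is a hypothetical name. On a DataFrame, `get_path(frame, ["a", "b"])` evaluates `frame["a"]["b"]`, the same indexing the original relies on.

```python
import functools
import operator
from typing import Any, Sequence


def get_path(obj: Any, path: Sequence[Any]) -> Any:
    """Return obj[path[0]][path[1]]...; raises KeyError/IndexError if a step is missing."""
    return functools.reduce(operator.getitem, path, obj)


data = {"Portfolio": {"Holding": {"Country": "France"}}}
print(get_path(data, ["Portfolio", "Holding", "Country"]))  # France
```

In the gist, `tz.get_in(col_spec, frame)` would then become `get_path(frame, col_spec)`.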

@reddysk18

reddysk18 commented Aug 21, 2020 via email

@AxREki

AxREki commented Aug 21, 2020

@reddysk18 have you tried https://gist.github.com/AxREki/70988ae40d1d82db15832575c175c41c ?
I solved a couple of issues there.

@reddysk18

reddysk18 commented Aug 21, 2020 via email

@AxREki

AxREki commented Aug 21, 2020

@reddysk18 Can you share the schema before and after flattening the DataFrame?


ghost commented Nov 6, 2020

Hello @AxREki, I am seeing the same issue as @reddysk18. The DataFrame is flattened successfully, but no data is displayed.
DF before:
 |-- Portfolio: struct (nullable = true)
 |    |-- HoldingBreakdown: struct (nullable = true)
 |    |    |-- Holding: struct (nullable = true)
 |    |    |    |-- HoldingDetail: array (nullable = true)
 |    |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |    |-- HoldingDetailId: string (nullable = true)
 |    |    |    |    |    |-- DetailHoldingTypeId: string (nullable = true)
 |    |    |    |    |    |-- Country: struct (nullable = true)
 |    |    |    |    |    |    |-- _VALUE: string (nullable = true)
 |    |    |    |    |    |    |-- __Id: string (nullable = false)

DF after:

|-- Portfolio_HoldingBreakdown_Holding_HoldingDetail_HoldingDetailId: string (nullable = true)
|-- Portfolio_HoldingBreakdown_Holding_HoldingDetail_DetailHoldingTypeId: string (nullable = true)
|-- Portfolio_HoldingBreakdown_Holding_HoldingDetail_Country__VALUE: string (nullable = true)
|-- Portfolio_HoldingBreakdown_Holding_HoldingDetail_Country___Id: string (nullable = true)

flattened_df.count()
0
df.count()
1025

@AxREki

AxREki commented Nov 6, 2020

@pingilis
Have you tried this: https://gist.github.com/AxREki/70988ae40d1d82db15832575c175c41c

I solved some issues there related to empty DataFrames.


ghost commented Nov 6, 2020

Hi @AxREki, yes, I have tried the updated gist. I still see no data.

if t_column.startswith('array<') and i == 0:

I have also tried another way of flattening, which works, but I still do not see any data in the DataFrame after flattening.
I used the code below. I am new to Python, so I could not follow how it breaks the data down.

from pyspark.sql.functions import explode
from pyspark.sql.types import StructType

def flatten(schema, prefix=None):
    # Collect dotted paths ("a.b.c") to every non-struct field.
    fields = []
    for field in schema.fields:
        name = prefix + '.' + field.name if prefix else field.name
        dtype = field.dataType
        if isinstance(dtype, StructType):
            fields += flatten(dtype, prefix=name)
        else:
            fields.append(name)
    return fields

def explodeDF(df):
    # Explode each array column, one withColumn (one generator) at a time.
    # explode drops rows whose array is null or empty.
    for (name, dtype) in df.dtypes:
        if "array" in dtype:
            df = df.withColumn(name, explode(name))
    return df

def df_is_flat(df):
    # True when no column is an array or a struct.
    for (_, dtype) in df.dtypes:
        if ("array" in dtype) or ("struct" in dtype):
            return False
    return True

def flatJson(jdf):
    keepGoing = True
    while keepGoing:
        fields = flatten(jdf.schema)
        # Select the leaf fields and rename "a.b.c" to "a_b_c".
        new_fields = [item.replace(".", "_") for item in fields]
        jdf = jdf.select(fields).toDF(*new_fields)
        jdf = explodeDF(jdf)
        if df_is_flat(jdf):
            keepGoing = False
    return jdf

This also flattens the DataFrame but returns no data. Do you know why?

@AxREki

AxREki commented Nov 9, 2020

@pingilis
I tried the code from this gist with the following data:

data = {
    "HoldinBreakdown": {
        "Holding": {
            "HoldingDetail": [{
                "HoldingDetailId": "Id_test",
                "DetailHoldingTypeId": "Type_test",
                "Country": {
                    "_VALUE": "France",
                    "ID": "un"
                }
            }]
        }
    }
}

And got this:
[screenshot]

And the data seems flattened:
[screenshot]

@AxREki

AxREki commented Nov 9, 2020

Your code seems to work too:

[screenshot]

Do you have a sample of data you can share?


ghost commented Nov 18, 2020

Hello @AxREki, yes, the code works when all the datasets have data in them. If there is no data, the result is 0 rows altogether.
For example, it does not work with data like the following, where values are missing:
data = {
    "HoldinBreakdown": {
        "Holding": {
            "HoldingDetail": [{
                "HoldingDetailId":
                "DetailHoldingTypeId":
                "Country": {
                    "_VALUE": "France",
                    "ID": "un"
                }
            }]
        }
    }
}

Anyway, I have used explode_outer instead of plain explode, which works well, but exploding the arrays creates duplicate records.
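Duplicates of this kind are expected when several arrays in the same record are exploded one after another: each explode multiplies the rows, so every element of one array gets paired with every element of the other. If the arrays are positionally aligned (element i of one belongs with element i of the other), one option is to zip them and explode once; in PySpark the building block for that is `pyspark.sql.functions.arrays_zip` (Spark 2.4+). The plain-Python sketch below only illustrates the difference in row counts. `explode_independently` and `explode_zipped` are hypothetical names, and the None padding is an assumption about how arrays_zip treats arrays of unequal length.

```python
from itertools import product, zip_longest
from typing import Any, List, Tuple


def explode_independently(a: List[Any], b: List[Any]) -> List[Tuple[Any, Any]]:
    """Two separate explodes: every element of a is paired with every element of b."""
    return list(product(a, b))


def explode_zipped(a: List[Any], b: List[Any]) -> List[Tuple[Any, Any]]:
    """Zip first, explode once: element i of a is paired only with element i of b.

    zip_longest pads the shorter list with None (assumed to mirror arrays_zip).
    """
    return list(zip_longest(a, b))


ids = ["h1", "h2"]
countries = ["France", "Spain"]
print(len(explode_independently(ids, countries)))  # 4
print(explode_zipped(ids, countries))  # [('h1', 'France'), ('h2', 'Spain')]
```

If the arrays are genuinely independent, the cross product may be exactly the result you want, so check which case applies before zipping.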

@EatZeBaby

@ghost Thanks for your feedback. Glad to see you found a way.

@NataliaLaurova

@ghost
I faced a similar issue: I got the right structure and a flattened schema, but no data. The root cause is the explode function: it drops a record when its array column is null or empty. I fixed that by using explode_outer, which keeps such records (with null in the exploded column) and behaves like a left outer join in SQL. Hope this helps.
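The difference between explode and explode_outer can be modelled in plain Python, with rows as dicts and a list value standing in for an array column. This is a sketch of the semantics described above, not Spark code; `explode_rows` is a hypothetical helper. In PySpark itself the change is just swapping `explode(...)` for `explode_outer(...)`, both from `pyspark.sql.functions`.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]


def explode_rows(rows: List[Row], col: str, outer: bool = False) -> List[Row]:
    """Mimic explode (outer=False) or explode_outer (outer=True) on one column."""
    out: List[Row] = []
    for row in rows:
        items = row[col]
        if items:
            # Non-empty list: one output row per element.
            out.extend({**row, col: item} for item in items)
        elif outer:
            # Null or empty array: explode_outer keeps the row with None.
            out.append({**row, col: None})
        # Otherwise plain explode drops the row entirely.
    return out


rows = [{"id": 1, "d": ["a", "b"]}, {"id": 2, "d": None}, {"id": 3, "d": []}]
print([r["id"] for r in explode_rows(rows, "d")])              # [1, 1]
print([r["id"] for r in explode_rows(rows, "d", outer=True)])  # [1, 1, 2, 3]
```

If every record has a null or empty array somewhere in the chain of explodes, plain explode removes all of them, which matches the 0-row results reported above.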
