Forked from nguyenvulebinh/flatten_all_spark_schema.py
Last active: March 21, 2022 20:46
Flatten a Spark DataFrame schema (including struct and array types)
import typing as T

import cytoolz.curried as tz
import pyspark
from pyspark.sql.functions import explode

def schema_to_columns(schema: pyspark.sql.types.StructType) -> T.List[T.List[str]]:
    """Return one path (list of field names) per leaf column, recursing into structs."""
    columns = list()

    def helper(schm: pyspark.sql.types.StructType, prefix: list = None):
        if prefix is None:
            prefix = list()
        for item in schm.fields:
            if isinstance(item.dataType, pyspark.sql.types.StructType):
                helper(item.dataType, prefix + [item.name])
            else:
                columns.append(prefix + [item.name])

    helper(schema)
    return columns

def flatten_array(frame: pyspark.sql.DataFrame) -> T.Tuple[pyspark.sql.DataFrame, bool]:
    """Explode the first array column found (one per pass, to avoid a cross join
    between arrays) and report whether the frame contained any array column.
    Note: explode drops rows whose array is null or empty; see the comments
    below about switching to explode_outer."""
    have_array = False
    aliased_columns = list()
    i = 0
    for column, t_column in frame.dtypes:
        if t_column.startswith('array<') and i == 0:
            have_array = True
            c = explode(frame[column]).alias(column)
            i = i + 1
        else:
            c = tz.get_in([column], frame)
        aliased_columns.append(c)
    return (frame.select(aliased_columns), have_array)

def flatten_frame(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Flatten struct columns, aliasing nested fields as 'parent:child'."""
    aliased_columns = list()
    for col_spec in schema_to_columns(frame.schema):
        c = tz.get_in(col_spec, frame)
        if len(col_spec) == 1:
            aliased_columns.append(c)
        else:
            aliased_columns.append(c.alias(':'.join(col_spec)))
    return frame.select(aliased_columns)

def flatten_all(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Repeatedly flatten structs and explode arrays until no array column remains."""
    frame = flatten_frame(frame)
    (frame, have_array) = flatten_array(frame)
    if have_array:
        return flatten_all(frame)
    else:
        return frame
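For reference, a minimal usage sketch, not part of the original gist: the sample data, column names, and the active SparkSession named `spark` are assumptions.

from pyspark.sql import Row

# Hypothetical nested input: a struct column containing an array field.
df = spark.createDataFrame([Row(id=1, info=Row(city="Paris", tags=["a", "b"]))])

flatten_all(df).show()
# Expected shape: columns `id`, `info:city`, `info:tags`, with one output row
# per element of the exploded `tags` array.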
Thanks for your feedback @vchalmel.
The entire row is being deleted? Even if there was data in other columns?
Do you have a sample I could try it on to improve the gist?
Yes, the entire row; cf. the pyspark.sql.functions documentation:
pyspark.sql.functions.explode_outer(col)
Returns a new row for each element in the given array or map. Unlike explode, if the array/map is null or empty then null is produced.
from pyspark.sql.functions import explode, explode_outer
# `spark` is an active SparkSession
df = spark.createDataFrame(
    [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
    ("id", "an_array", "a_map"))
df.select("id", "an_array", explode("a_map")).show()
df.select("id", "an_array", explode_outer("a_map")).show()
df.select("id", "a_map", explode("an_array")).show()
df.select("id", "a_map", explode_outer("an_array")).show()
I am able to run this successfully, but I am getting duplicate values: the number of rows after flattening is doubled. Does anyone know why?
Rows with NULL values are disappearing for me as well.
This one https://gist.github.com/nmukerje/e65cde41be85470e4b8dfd9a2d6aed50 has the explode_outer fix.
Hi! I tested your script and some rows are lost in the process. I tested on a small dataframe (16 rows as input) with a unique identifier column, and of the 16 different ids, only 3 are found in the output dataframe.
I observed that when a struct or array column in the input dataframe has null values, the rows containing these nulls are deleted.
Edit: it's the use of explode that drops rows whose array column is null or empty; replace it with explode_outer to handle such cases.
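A minimal sketch of that change inside flatten_array above (everything else stays as in the gist):

from pyspark.sql.functions import explode_outer

# keep rows whose array column is null or empty instead of dropping them
c = explode_outer(frame[column]).alias(column)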