Forked from nguyenvulebinh/flatten_all_spark_schema.py
Last active
March 21, 2022 20:46
-
-
Save AxREki/70988ae40d1d82db15832575c175c41c to your computer and use it in GitHub Desktop.
Flatten a Spark DataFrame schema (include struct and array type)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing as T | |
import cytoolz.curried as tz | |
import pyspark | |
from pyspark.sql.functions import explode | |
def schema_to_columns(schema: pyspark.sql.types.StructType) -> T.List[T.List[str]]:
    """Return the list of leaf-column paths in *schema*.

    Each path is a list of field names leading from the top level down to a
    leaf (non-struct) field, collected by a depth-first walk of nested structs.
    """
    paths: T.List[T.List[str]] = []

    def walk(struct: pyspark.sql.types.StructType, trail: T.Optional[list] = None) -> None:
        trail = trail if trail is not None else []
        for field in struct.fields:
            child = trail + [field.name]
            if isinstance(field.dataType, pyspark.sql.types.StructType):
                # Struct field: descend one level deeper.
                walk(field.dataType, child)
            else:
                # Leaf field: record the full path to it.
                paths.append(child)

    walk(schema)
    return paths
def flatten_array(frame: pyspark.sql.DataFrame) -> T.Tuple[pyspark.sql.DataFrame, bool]:
    """Explode the first array-typed column of *frame*, if any.

    Only one array column is exploded per call (exploding several at once
    would produce a cross product); callers repeat until none remain.

    Args:
        frame: the DataFrame to scan for array columns.

    Returns:
        A ``(new_frame, have_array)`` pair: the (possibly) exploded frame and
        a flag telling whether an array column was found.

    NOTE(review): ``explode`` drops rows whose array is NULL or empty; use
    ``explode_outer`` instead if those rows must be preserved.
    """
    # BUG FIX: the original annotation referenced `BooleanType`, which was
    # never imported, raising NameError when the function was defined.
    have_array = False
    aliased_columns = []
    for column, col_type in frame.dtypes:
        if col_type.startswith('array<') and not have_array:
            # Explode only the first array column found in this pass.
            have_array = True
            aliased_columns.append(explode(frame[column]).alias(column))
        else:
            aliased_columns.append(tz.get_in([column], frame))
    return frame.select(aliased_columns), have_array
def flatten_frame(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Flatten nested struct columns of *frame* into top-level columns.

    Nested fields are selected by their full path and renamed to the path
    components joined with ':'; top-level columns keep their original name.
    """
    selected = []
    for path in schema_to_columns(frame.schema):
        col = tz.get_in(path, frame)
        selected.append(col if len(path) == 1 else col.alias(':'.join(path)))
    return frame.select(selected)
def flatten_all(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Fully flatten *frame*: flatten struct columns and explode array
    columns, recursing until no array column remains."""
    flat = flatten_frame(frame)
    flat, found_array = flatten_array(flat)
    # Exploding an array may expose new structs/arrays, so recurse.
    return flatten_all(flat) if found_array else flat
I am able to run this successfully, but I am getting duplicate values — the number of rows after flattening is getting doubled. Does anyone know why?
Rows with NULL values are also disappearing for me.
This one https://gist.github.com/nmukerje/e65cde41be85470e4b8dfd9a2d6aed50 has the explode_outer fix.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Yes, the entire row; cf. the spark.sql module's documentation.