Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save bjornjorgensen/460eb8afb996b5bc3c8d6d2b6494123a to your computer and use it in GitHub Desktop.
Flatten a Spark DataFrame schema (include struct and array type)
import typing as T
import cytoolz.curried as tz
import pyspark
from pyspark.sql.functions import explode_outer
from pyspark.sql.types import BooleanType
def schema_to_columns(schema: pyspark.sql.types.StructType) -> T.List[T.List[str]]:
    """Return the paths to all leaf fields of *schema*.

    Each path is a list of field names, e.g. ``["a", "b"]`` for a leaf
    field ``b`` nested inside a struct column ``a``.  Non-struct fields
    (including array columns) are treated as leaves; arrays are handled
    separately by ``flatten_array``.

    :param schema: schema of a Spark DataFrame.
    :return: list of column-name paths, one per leaf field.
    """
    columns: T.List[T.List[str]] = []

    def helper(schm: pyspark.sql.types.StructType,
               prefix: T.Optional[T.List[str]] = None) -> None:
        # Default is None (not []) to avoid the mutable-default-argument
        # pitfall; the original annotation `prefix: list = None` was wrong
        # because None is not a list.
        if prefix is None:
            prefix = []
        for item in schm.fields:
            if isinstance(item.dataType, pyspark.sql.types.StructType):
                # Recurse into nested structs, extending the current path.
                helper(item.dataType, prefix + [item.name])
            else:
                columns.append(prefix + [item.name])

    helper(schema)
    return columns
def flatten_array(frame: pyspark.sql.DataFrame) -> T.Tuple[pyspark.sql.DataFrame, bool]:
    """Explode the first array-typed column of *frame*, if any.

    Only one array column is exploded per call: exploding several arrays
    in a single ``select`` would multiply rows into a cross product.
    ``flatten_all`` calls this repeatedly until no array columns remain.

    :param frame: input DataFrame.
    :return: ``(new_frame, have_array)`` where ``have_array`` is True when
        an array column was found (and exploded).

    Note: the original return annotation ``(DataFrame, BooleanType)`` was
    invalid — a tuple literal, and ``BooleanType`` is a Spark SQL schema
    type, not the Python ``bool`` actually returned.
    """
    have_array = False
    exploded_one = False  # clearer than the original integer counter `i`
    aliased_columns = []
    for column, t_column in frame.dtypes:
        if t_column.startswith('array<') and not exploded_one:
            have_array = True
            exploded_one = True
            # explode_outer keeps rows whose array is null or empty.
            c = explode_outer(frame[column]).alias(column)
        else:
            c = tz.get_in([column], frame)
        aliased_columns.append(c)
    return (frame.select(aliased_columns), have_array)
def flatten_frame(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Lift all nested struct fields of *frame* to top-level columns.

    Flattened columns are named by joining their field path with ``":"``;
    columns that were already top-level keep their original name.
    """
    selected = []
    for path in schema_to_columns(frame.schema):
        # Walk down the field path: frame[a][b]... selects the nested field
        # (equivalent to the original tz.get_in(path, frame)).
        col = frame[path[0]]
        for part in path[1:]:
            col = col[part]
        if len(path) > 1:
            col = col.alias(":".join(path))
        selected.append(col)
    return frame.select(selected)
def flatten_all(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Fully flatten *frame*: nested structs become top-level columns and
    array columns are exploded, one per pass, until none remain.

    Iterative form of the original tail recursion — each pass flattens
    structs, then explodes at most one array column; the loop ends once
    ``flatten_array`` reports no array column was found.
    """
    while True:
        frame = flatten_frame(frame)
        frame, have_array = flatten_array(frame)
        if not have_array:
            return frame
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment