AutoFlatten Complex JSON
from pyspark.sql.functions import col, explode_outer
from pyspark.sql.types import StructType, ArrayType
from collections import Counter

from autoflatten import AutoFlatten
# Read the source data; each row is expected to carry the raw JSON payload
# in a string column named `json`.
s3_path = 's3://mybucket/orders/'
df = spark.read.orc(s3_path)

# Let Spark infer the full nested schema from the JSON strings.
json_df = spark.read.json(df.rdd.map(lambda row: row.json))
json_schema = json_df.schema
af = AutoFlatten(json_schema)
af.compute()
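
# What AutoFlatten is assumed to expose after compute(), inferred from how its
# attributes are used below (illustrative description, not the documented API):
#   af.all_fields     -- dict mapping each full path (e.g. '.order.items.sku')
#                        to its flattened target column name (e.g. 'sku')
#   af.rest           -- paths reachable without exploding any array
#   af.order          -- the order in which nested columns must be opened
#   af.bottom_to_top  -- for each path in af.order, the leaf paths to select
#                        after exploding it; empty for struct types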
df1 = json_df

# Track fully-qualified paths already selected; root columns are prefixed with '.'.
visited = set([f'.{column}' for column in df1.columns])
# Count how many paths flatten to the same target name, so collisions can be aliased.
duplicate_target_counter = Counter(af.all_fields.values())
cols_to_select = df1.columns
# Select every field that can be reached without exploding an array. A field keeps
# its short name when that name is unique; otherwise the full path is used as the
# alias, with '.' replaced by '>'.
for rest_col in af.rest:
    if rest_col not in visited:
        cols_to_select += [rest_col[1:]] if (duplicate_target_counter[af.all_fields[rest_col]] == 1 and af.all_fields[rest_col] not in df1.columns) else [col(rest_col[1:]).alias(f"{rest_col[1:].replace('.', '>')}")]
        visited.add(rest_col)
df1 = df1.select(cols_to_select)
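
# For illustration: if both '.id' and '.order.id' flatten to the target name 'id',
# the aliasing above keeps them apart as 'id' and 'order>id' respectively.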
if af.order:
    for key in af.order:
        column = key.split('.')[-1]
        if af.bottom_to_top[key]:
            #########
            # values for the column exist in the bottom_to_top dict only if it is an array type
            #########
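            # e.g. (assumed shape) key = '.order.items', an array of structs: the
            # array is exploded to one row per element, then the element's fields
            # are pulled up to the top level in the branch below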
            # explode the array into one row per element, dropping the original column
            df1 = df1.select('*', explode_outer(col(column)).alias(f"{column}_exploded")).drop(column)
            data_type = df1.select(f"{column}_exploded").schema.fields[0].dataType
            if not (isinstance(data_type, StructType) or isinstance(data_type, ArrayType)):
                # exploded elements are primitives: just rename, resolving name collisions
                df1 = df1.withColumnRenamed(f"{column}_exploded", column if duplicate_target_counter[af.all_fields[key]] <= 1 else key[1:].replace('.', '>'))
                visited.add(key)
            else:
                # grabbing all paths to columns after explode
                cols_in_array_col = set(map(lambda x: f'{key}.{x}', df1.select(f'{column}_exploded.*').columns))
                # retrieving unvisited columns
                cols_to_select_set = cols_in_array_col.difference(visited)
                all_cols_to_select_set = set(af.bottom_to_top[key])
                # check for duplicate column name & path before deciding whether to alias
                cols_to_select_list = list(map(lambda x: f"{column}_exploded{'.'.join(x.split(key)[1:])}" if (duplicate_target_counter[af.all_fields[x]] <= 1 and x.split('.')[-1] not in df1.columns) else col(f"{column}_exploded{'.'.join(x.split(key)[1:])}").alias(f"{x[1:].replace('.', '>')}"), list(all_cols_to_select_set)))
                # updating visited set
                visited.update(cols_to_select_set)
                # columns present in the exploded data but absent from the computed paths
                rem = list(map(lambda x: f"{column}_exploded{'.'.join(x.split(key)[1:])}", list(cols_to_select_set.difference(all_cols_to_select_set))))
                df1 = df1.select(df1.columns + cols_to_select_list + rem).drop(f"{column}_exploded")
        else:
            #########
            # values for the column do not exist in the bottom_to_top dict if it is
            # a struct type / array type containing a string type
            #########
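            # e.g. (assumed shape) key = '.order.customer', a plain struct: its
            # fields are selected directly via 'customer.*'-style paths, no explode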
            # grabbing all paths to columns after opening the struct
            cols_in_array_col = set(map(lambda x: f'{key}.{x}', df1.selectExpr(f'{column}.*').columns))
            # retrieving unvisited columns
            cols_to_select_set = cols_in_array_col.difference(visited)
            # check for duplicate column name & path before deciding whether to alias
            cols_to_select_list = list(map(lambda x: f"{column}.{x.split('.')[-1]}" if (duplicate_target_counter[x.split('.')[-1]] <= 1 and x.split('.')[-1] not in df1.columns) else col(f"{column}.{x.split('.')[-1]}").alias(f"{x[1:].replace('.', '>')}"), list(cols_to_select_set)))
            # updating visited set
            visited.update(cols_to_select_set)
            df1 = df1.select(df1.columns + cols_to_select_list).drop(f"{column}")

# Final projection: unique target names keep their short form; duplicated names
# keep the full path with '.' replaced by '>'.
final_df = df1.select([field[1:].replace('.', '>') if duplicate_target_counter[af.all_fields[field]] > 1 else af.all_fields[field] for field in af.all_fields])
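
# Sanity check (added for illustration): the flattened schema should now contain
# only top-level, non-struct, non-array columns.
final_df.printSchema()

# Possible sink (hypothetical path and format, not part of the original gist):
# final_df.write.mode('overwrite').parquet('s3://mybucket/orders_flattened/')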