Parse nested JSON into your ideal, customizable Spark schema (StructType)

Forked by @venkyvb from PawaritL/README.md (January 13, 2023)

Is Spark's JSON schema inference too inflexible for your liking?

Common Scenarios:

  • Spark's automatic schema inference doesn't apply the type casting you want
  • You want to drop irrelevant fields entirely when parsing
  • You want to avoid parsing some deeply nested fields by simply casting outer fields as strings
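
Each of these scenarios maps to a keyword argument of the generate_schema helper implemented at the bottom of this page. A minimal sketch of the mapping (field names come from the reference example in Step 1; illustrative only, you would typically use just the options you need):

custom_struct = generate_schema(
    REFERENCE_EXAMPLE,
    max_level=2,                          # fields at depth 2 or deeper are cast as strings (top level is depth 1)
    stringify_fields=["physicalStats"],   # cast these fields directly as strings
    skip_fields=["lastName"]              # omit these fields from the schema entirely
)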

Step 1: Provide your (ideal) JSON data example

REFERENCE_EXAMPLE = {
  "firstName": "Will",
  "lastName": "Ferrell",
  "physicalStats": {
    "height": 123,
    "weight": 456
  },
  "brothers": [
    {"name": "John", "age": 7, "awards": {"nominated": 69, "won": 420}},
    {"name": "C.", "age": 8, "awards": {"nominated": 420, "won": 69}},
    {"name": "Reilly", "age": 9, "awards": {"nominated": 69420, "won": 42069}}
  ]
}

Step 2: Generate the Spark DataFrame schema programmatically

REFERENCE_STRUCT = generate_schema(REFERENCE_EXAMPLE)

# --- Example Workflow
import json

import pyspark.sql.functions as F
from pyspark import SparkContext
from pyspark.sql import SparkSession

sc = SparkContext.getOrCreate()
spark = SparkSession.builder.getOrCreate()

# replicate the reference example 1000 times as single-column rows of raw JSON strings
example_dataset = sc.parallelize([[json.dumps(REFERENCE_EXAMPLE)]] * 1000)
raw_df = spark.createDataFrame(example_dataset, schema=["body"])

# If you're happy with the generated schema
parsed_df = raw_df.select(F.from_json(F.col("body"), schema=REFERENCE_STRUCT).alias("parsedBody"))

# Alternatively, if you're unhappy with the generated schema
# e.g. if height should be floats rather than bigints
unhappy_string = REFERENCE_STRUCT.simpleString()
print(unhappy_string)
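# Expected output for the reference example, assuming Python 3.7+ dict ordering
# (one line in practice; wrapped here for readability):
# struct<firstName:string,lastName:string,physicalStats:struct<height:bigint,weight:bigint>,
#        brothers:array<struct<name:string,age:bigint,awards:struct<nominated:bigint,won:bigint>>>>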

happy_string = unhappy_string.replace(
    "physicalStats:struct<height:bigint,weight:bigint>",
    "physicalStats:struct<height:float,weight:float>"
)
parsed_df = raw_df.select(F.from_json(F.col("body"), schema=happy_string).alias("parsedBody"))
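
String surgery on simpleString() works, but you can also patch the generated StructType directly. A sketch (not part of the original gist) that swaps the two bigint stats fields for floats:

from pyspark.sql.types import StructType, StructField, FloatType

# rebuild physicalStats with float fields, keep every other field as generated
happy_stats = StructField("physicalStats", StructType([
    StructField("height", FloatType(), True),
    StructField("weight", FloatType(), True)
]), True)
happy_struct = StructType([
    happy_stats if field.name == "physicalStats" else field
    for field in REFERENCE_STRUCT.fields
])
parsed_df = raw_df.select(F.from_json(F.col("body"), schema=happy_struct).alias("parsedBody"))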

Full implementation (generate_schema and its helpers):

import json
from typing import Optional, List, Dict

import pyspark.sql.functions as F
from pyspark.sql.types import *

# map Python primitive types to their Spark SQL counterparts
TYPE_MAPPER = {
    bool: BooleanType(),
    str: StringType(),
    int: LongType(),
    float: DoubleType()
}


def generate_schema(input_json: Dict,
                    max_level: Optional[int] = None,
                    stringify_fields: Optional[List[str]] = None,
                    skip_fields: Optional[List[str]] = None
                    ) -> StructType:
    """
    User-friendly wrapper for _populate_struct.
    Given an input JSON (as a Python dictionary), returns the corresponding PySpark schema

    :param input_json: example of the input JSON data (represented as a Python dictionary)
    :param max_level: maximum levels of nested JSON to parse, beyond which values will be cast as strings
    :param stringify_fields: list of field names to be directly cast as strings
    :param skip_fields: list of field names to completely ignore parsing and omit from the schema
    :return: pyspark.sql.types.StructType
    """
    level = 1
    return _populate_struct(input_json, level, max_level, stringify_fields, skip_fields)


def _populate_struct(input_json: Dict,
                     level: int = 1,
                     max_level: Optional[int] = None,
                     stringify_fields: Optional[List[str]] = None,
                     skip_fields: Optional[List[str]] = None
                     ) -> StructType:
    """
    Given an input JSON (as a Python dictionary), returns the corresponding PySpark StructType

    :param input_json: example of the input JSON data (represented as a Python dictionary)
    :param level: current level within the (nested) JSON. level=1 corresponds to the top level
    :param max_level: maximum levels of nested JSON to parse, beyond which values will be cast as strings
    :param stringify_fields: list of field names to be directly cast as strings
    :param skip_fields: list of field names to completely ignore parsing and omit from the schema
    :return: pyspark.sql.types.StructType
    """
    if not isinstance(input_json, dict):
        raise ValueError("invalid input JSON")
    if not (isinstance(level, int) and (level > 0)):
        raise ValueError("level must be greater than zero")
    if max_level and not (isinstance(max_level, int) and (max_level >= level)):
        raise ValueError("max_level must be greater than or equal to level (by default, level = 1)")

    filled_struct = StructType()
    nullable = True
    for key in input_json.keys():
        if skip_fields and (key in skip_fields):
            continue
        elif (stringify_fields and (key in stringify_fields)) or (max_level and (level >= max_level)):
            filled_struct.add(StructField(key, StringType(), nullable))
        elif isinstance(input_json[key], dict):
            inner_level = level + 1
            # propagate skip_fields as well, so nested structs honor it too
            inner_struct = _populate_struct(input_json[key], inner_level, max_level, stringify_fields, skip_fields)
            filled_struct.add(StructField(key, inner_struct, nullable))
        elif isinstance(input_json[key], list):
            inner_level = level + 1
            inner_array = _populate_array(input_json[key], inner_level, stringify_fields, skip_fields)
            filled_struct.add(StructField(key, inner_array, nullable))
        elif input_json[key] is not None:
            inner_type = TYPE_MAPPER[type(input_json[key])]
            filled_struct.add(StructField(key, inner_type, nullable))
    return filled_struct


def _populate_array(input_array: List,
                    level: int = 1,
                    stringify_fields: Optional[List[str]] = None,
                    skip_fields: Optional[List[str]] = None
                    ) -> ArrayType:
    """
    Given an input Python list, returns the corresponding PySpark ArrayType

    :param input_array: input array data (represented as a Python list)
    :param level: current level within the (nested) JSON
    :param stringify_fields: list of field names to be directly cast as strings
    :param skip_fields: list of field names to completely ignore parsing and omit from the schema
    :return: pyspark.sql.types.ArrayType
    """
    if not isinstance(input_array, list):
        raise ValueError("invalid input array")
    if not (isinstance(level, int) and (level > 0)):
        raise ValueError("level must be greater than zero")

    if len(input_array):
        # element type is inferred from the first element only,
        # so arrays in the reference example should be homogeneous
        head = input_array[0]
        inner_level = level + 1
        if isinstance(head, list):
            filled_array = ArrayType(_populate_array(head, inner_level, stringify_fields, skip_fields))
        elif isinstance(head, dict):
            # note: max_level is not applied inside arrays
            inner_struct = _populate_struct(head, inner_level,
                                            stringify_fields=stringify_fields,
                                            skip_fields=skip_fields)
            filled_array = ArrayType(inner_struct)
        else:
            filled_array = ArrayType(TYPE_MAPPER[type(head)])
    else:
        # an empty reference array defaults to an array of strings
        filled_array = ArrayType(StringType())
    return filled_array
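
As a minimal end-to-end check (reusing raw_df and F from the workflow above), you can parse the raw DataFrame with the generated schema and drill into a nested field:

parsed_df = raw_df.select(
    F.from_json(F.col("body"), schema=generate_schema(REFERENCE_EXAMPLE)).alias("parsedBody")
)
parsed_df.printSchema()
parsed_df.select("parsedBody.physicalStats.height").show(3)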