Skip to content

Instantly share code, notes, and snippets.

@eric-maynard
Last active November 19, 2018 14:30
Show Gist options
  • Save eric-maynard/9e857e8b4fc62886bd9b4c7025e606a0 to your computer and use it in GitHub Desktop.
Manipulating nested Spark DataFrames
package com.cloudera.example
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import scala.collection.mutable._
import scala.collection.mutable.ListBuffer
object Example {
// Enriches DataFrame with helpers for adding and dropping columns addressed by a
// dot-separated path (e.g. "json_data.foo.wing"), rebuilding the intermediate
// StructType columns as needed so the rest of the schema is preserved.
implicit class NestedDataframe(df: DataFrame) {
// largely ported from StackOverflow: https://stackoverflow.com/a/45349745
// Guards `c` so it evaluates to null whenever `parentCol` is null; used when
// rebuilding a struct so a missing parent does not fabricate non-null children.
private def nullableCol(parentCol: Column, c: Column): Column = {
when(parentCol.isNotNull, c)
}
// Self-guarding variant: null stays null, any other value passes through.
private def nullableCol(c: Column): Column = {
nullableCol(c, c)
}
// Wraps `newCol` in one single-field struct per remaining path segment.
// foldRight builds from the innermost segment outward, so
// createNestedStructs(Seq("y", "z"), c) yields struct(y = struct(z = c)).
private def createNestedStructs(splitted: Seq[String], newCol: Column): Column = {
splitted.foldRight(newCol) {
case (colName, nestedStruct) => nullableCol(struct(nestedStruct as colName))
}
}
// Recursively descends the remaining path segments (`splitted`) inside the
// struct column `col` (whose Catalyst type is `colType`), returning a rebuilt
// column in which the addressed leaf has been replaced — or appended — with
// `newCol`. `nullable` is the nullability flag of the field being rebuilt.
private def recursiveAddNestedColumn(splitted: Seq[String], col: Column, colType: DataType, nullable: Boolean, newCol: Column): Column = {
colType match {
case colType: StructType if splitted.nonEmpty => {
// Rebuild every existing field, substituting the one matching the next
// path segment (if any) with the recursive result.
val modifiedFields: ListBuffer[(String, Column)] = ListBuffer(colType.fields.map(f => {
var curCol = col.getField(f.name)
if (f.name == splitted.head) {
curCol = recursiveAddNestedColumn(splitted.tail, curCol, f.dataType, f.nullable, newCol)
}
(f.name, curCol as f.name)
}): _*)
// Path segment not present yet: append it as a fresh chain of nested
// structs, guarded by the parent struct's nullability.
if (!modifiedFields.exists(_._1 == splitted.head)) {
modifiedFields.append((splitted.head, nullableCol(col, createNestedStructs(splitted.tail, newCol)) as splitted.head))
}
// NOTE(review): these two arms look inverted relative to intuition (the
// nullable field is returned unguarded, the non-nullable one guarded by
// the parent); kept as in the SO original — confirm before changing.
(struct(modifiedFields.map(_._2): _*), nullable) match {
case (struct, true) => struct
case (struct, false) => nullableCol(col, struct)
}
}
// Ran out of struct levels (or a non-struct type was reached): materialize
// the rest of the path as brand-new nested structs around `newCol`.
case _ => createNestedStructs(splitted, newCol)
}
}
// Adds `newCol` at the dot-separated path `newColName`, creating or rebuilding
// intermediate structs. A path with no '.' degenerates to plain withColumn.
def addNestedColumn(newCol: Column, newColName: String): DataFrame = {
if (newColName.contains('.')) {
var splitted = newColName.split('.')
// Either rebuild the existing top-level column, or create a brand-new one
// consisting solely of the requested nested path.
val modifiedOrAdded: (String, Column) = df.schema.fields
.find(_.name == splitted.head)
.map(f => (f.name, recursiveAddNestedColumn(splitted.tail, col(f.name), f.dataType, f.nullable, newCol)))
.getOrElse {
(splitted.head, createNestedStructs(splitted.tail, newCol) as splitted.head)
}
df.withColumn(modifiedOrAdded._1, modifiedOrAdded._2)
} else {
// Top level addition, use spark method as-is
df.withColumn(newColName, newCol)
}
}
// largely ported from StackOverflow: https://stackoverflow.com/a/39943812
// Drops the field addressed by the dot-separated `colName`, rebuilding each
// enclosing struct column without it. A plain top-level name is handled by
// the initial df.drop(colName) in the foldLeft seed.
def dropNested(colName: String): DataFrame = {
df.schema.fields
.flatMap(f => {
if (colName.startsWith(s"${f.name}.")) {
// This top-level column contains the drop target: rebuild it without it.
dropSubColumn(col(f.name), f.dataType, f.name, colName) match {
case Some(x) => Some((f.name, x))
case None => None
}
} else {
None
}
})
// df.drop on a dotted path is a no-op for nested fields, but covers the
// plain top-level-column case; rebuilt columns then overwrite the originals.
.foldLeft(df.drop(colName)) {
case (df, (colName, column)) => df.withColumn(colName, column)
}
}
// Returns the rebuilt column with `dropColName` removed, or None when this
// column IS the one being dropped. `fullColName` is the dotted path of `col`;
// `arrayDepth` counts how many ArrayType levels have been descended through,
// so the final cast can re-wrap untouched branches in the right element type.
private def dropSubColumn(col: Column, colType: DataType, fullColName: String, dropColName: String, arrayDepth: Int = 0): Option[Column] = {
if (fullColName.equals(dropColName)) {
None
} else if (dropColName.startsWith(s"${fullColName}.")) {
colType match {
case colType: StructType =>
// Keep every child field except the (possibly transitively) dropped one.
Some(struct(colType.fields.flatMap(f => {
dropSubColumn(col.getField(f.name), f.dataType, s"${fullColName}.${f.name}", dropColName, arrayDepth) match {
case Some(x) => Some(x.alias(f.name))
case None => None
}
}): _*))
case colType: ArrayType =>
colType.elementType match {
// NOTE(review): only struct elements are handled here — an array of
// non-struct elements on the drop path throws MatchError; confirm
// whether that case can occur for callers before relying on this.
case innerType: StructType =>
Some(
array(
struct(innerType.fields
.flatMap(f =>
dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName, arrayDepth + 1) match {
case Some(x) => Some(x.alias(f.name))
case None => None
} )
: _*)))
}
case other => Some(col)
}
} else {
// Untouched sibling branch: cast back to its original type, re-wrapped in
// one ArrayType per array level we descended through to reach it.
var newType = colType
(1 to arrayDepth).foreach(i => newType = new ArrayType(newType, true))
Some(col.cast(newType))
}
}
}
def main(args: Array[String]): Unit = {
  // Set up a local SparkSession (2 threads — demo only):
  val spark = SparkSession.builder().master("local[2]").getOrCreate()
  import spark.implicits._

  // Create a nested DataFrame by parsing a JSON string against an explicit schema:
  val stringDf = spark.createDataset(Seq("{\"a\": \"b\", \"foo\": {\"bar\": \"baz\", \"wing\": \"ding\"}}")).toDF("string_data")
  val originalSchema = new StructType().add($"a".string).add("foo", new StructType().add($"bar".string).add($"wing".string))
  val nestedDf = stringDf.select(from_json($"string_data", originalSchema)).toDF("json_data")

  // Modify the schema via the NestedDataframe implicit helpers:
  val droppedDf = nestedDf.dropNested("json_data.foo.wing")
  val addedDf = nestedDf.addNestedColumn($"json_data.a", "x.y.z")

  // Original data:
  nestedDf.printSchema()
  // root
  //  |-- json_data: struct (nullable = true)
  //  |    |-- a: string (nullable = true)
  //  |    |-- foo: struct (nullable = true)
  //  |    |    |-- bar: string (nullable = true)
  //  |    |    |-- wing: string (nullable = true)
  nestedDf.show()
  // +----------------+
  // |       json_data|
  // +----------------+
  // |[b, [baz, ding]]|
  // +----------------+

  // Dropped data (json_data.foo.wing removed):
  droppedDf.printSchema()
  // root
  //  |-- json_data: struct (nullable = false)
  //  |    |-- a: string (nullable = true)
  //  |    |-- foo: struct (nullable = false)
  //  |    |    |-- bar: string (nullable = true)
  droppedDf.show()
  // +----------+
  // | json_data|
  // +----------+
  // |[b, [baz]]|
  // +----------+

  // Added data (json_data.a copied to new nested path x.y.z):
  addedDf.printSchema()
  // root
  //  |-- json_data: struct (nullable = true)
  //  |    |-- a: string (nullable = true)
  //  |    |-- foo: struct (nullable = true)
  //  |    |    |-- bar: string (nullable = true)
  //  |    |    |-- wing: string (nullable = true)
  //  |-- x: struct (nullable = true)
  //  |    |-- y: struct (nullable = true)
  //  |    |    |-- z: string (nullable = true)
  addedDf.show()
  // +----------------+-----+
  // |       json_data|    x|
  // +----------------+-----+
  // |[b, [baz, ding]]|[[b]]|
  // +----------------+-----+

  // Arrays are also supported — dropping a field nested inside array elements:
  val stringArrayDf = spark.createDataset(Seq("{\"a\": \"b\", \"foo\": [{\"bar\": [\"baz\"], \"wing\": \"1\"}, {\"bar\": [\"car\"], \"wing\": \"2\"}]}")).toDF("string_data")
  val originalArraySchema = new StructType().add($"a".string).add("foo", new ArrayType(new StructType().add("bar", new ArrayType(StringType, true)).add($"wing".string), true))
  val nestedArrayDf = stringArrayDf.select(from_json($"string_data", originalArraySchema)).toDF("json_data")
  nestedArrayDf.printSchema()
  // root
  //  |-- json_data: struct (nullable = true)
  //  |    |-- a: string (nullable = true)
  //  |    |-- foo: array (nullable = true)
  //  |    |    |-- element: struct (containsNull = true)
  //  |    |    |    |-- bar: array (nullable = true)
  //  |    |    |    |    |-- element: string (containsNull = true)
  //  |    |    |    |-- wing: string (nullable = true)
  // FIX: the original called dropNested with a second Map argument
  // (Map("json_data.foo" -> 2)) that dropNested(colName) does not accept — a
  // compile error. dropSubColumn tracks array depth itself, so no hint is needed.
  nestedArrayDf.dropNested("json_data.foo.bar").printSchema()
  // root
  //  |-- json_data: struct (nullable = false)
  //  |    |-- a: string (nullable = true)
  //  |    |-- foo: array (nullable = false)
  //  |    |    |-- element: struct (containsNull = false)
  //  |    |    |    |-- wing: string (nullable = true)
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment