Last active
November 19, 2018 14:30
-
-
Save eric-maynard/9e857e8b4fc62886bd9b4c7025e606a0 to your computer and use it in GitHub Desktop.
Manipulating nested Spark DataFrames
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.cloudera.example | |
import org.apache.spark.sql._ | |
import org.apache.spark.sql.types._ | |
import org.apache.spark.sql.functions._ | |
import scala.collection.mutable._ | |
import scala.collection.mutable.ListBuffer | |
object Example { | |
implicit class NestedDataframe(df: DataFrame) {
  // Largely ported from StackOverflow: https://stackoverflow.com/a/45349745

  /** Returns `c`, forced to null wherever `parentCol` is null. */
  private def nullableCol(parentCol: Column, c: Column): Column =
    when(parentCol.isNotNull, c)

  /** Returns `c`, forced to null wherever `c` itself is null. */
  private def nullableCol(c: Column): Column = nullableCol(c, c)

  /** Wraps `newCol` in one single-field struct per name in `splitted`, outermost first. */
  private def createNestedStructs(splitted: Seq[String], newCol: Column): Column =
    splitted.foldRight(newCol) { case (colName, nested) =>
      nullableCol(struct(nested as colName))
    }

  /**
   * Rebuilds `col` (whose schema is `colType`) so that the field path `splitted`
   * resolves to `newCol`, preserving all sibling fields. A non-struct column or
   * an exhausted path falls through to plain nested-struct creation.
   */
  private def recursiveAddNestedColumn(splitted: Seq[String], col: Column, colType: DataType,
                                       nullable: Boolean, newCol: Column): Column = {
    colType match {
      case structType: StructType if splitted.nonEmpty =>
        // Copy every existing field, rewriting the one matching the head of the path.
        val modifiedFields: ListBuffer[(String, Column)] = ListBuffer(structType.fields.map { f =>
          var curCol = col.getField(f.name)
          if (f.name == splitted.head) {
            curCol = recursiveAddNestedColumn(splitted.tail, curCol, f.dataType, f.nullable, newCol)
          }
          (f.name, curCol as f.name)
        }: _*)
        // The path component does not exist yet at this level: append it.
        if (!modifiedFields.exists(_._1 == splitted.head)) {
          modifiedFields.append(
            (splitted.head, nullableCol(col, createNestedStructs(splitted.tail, newCol)) as splitted.head))
        }
        val rebuilt = struct(modifiedFields.map(_._2): _*)
        // Preserve the parent struct's nullability semantics.
        if (nullable) rebuilt else nullableCol(col, rebuilt)
      case _ =>
        createNestedStructs(splitted, newCol)
    }
  }

  /**
   * Adds `newCol` under the (possibly dot-separated) path `newColName`,
   * creating intermediate structs as needed. Top-level names delegate to
   * Spark's `withColumn` unchanged.
   */
  def addNestedColumn(newCol: Column, newColName: String): DataFrame = {
    if (newColName.contains('.')) {
      val splitted = newColName.split('.')
      val modifiedOrAdded: (String, Column) = df.schema.fields
        .find(_.name == splitted.head)
        .map(f => (f.name, recursiveAddNestedColumn(splitted.tail, col(f.name), f.dataType, f.nullable, newCol)))
        .getOrElse((splitted.head, createNestedStructs(splitted.tail, newCol) as splitted.head))
      df.withColumn(modifiedOrAdded._1, modifiedOrAdded._2)
    } else {
      // Top-level addition: use the Spark method as-is.
      df.withColumn(newColName, newCol)
    }
  }

  // Largely ported from StackOverflow: https://stackoverflow.com/a/39943812

  /**
   * Drops the (possibly nested) column at the dot-separated path `colName`.
   * `arrayDepths` optionally overrides the inferred array-nesting depth for a
   * given path prefix (e.g. Map("json_data.foo" -> 2) for arrays nested inside
   * arrays); an empty map preserves the previous single-argument behavior exactly.
   */
  def dropNested(colName: String, arrayDepths: Map[String, Int] = Map.empty): DataFrame = {
    df.schema.fields
      .flatMap { f =>
        if (colName.startsWith(s"${f.name}.")) {
          dropSubColumn(col(f.name), f.dataType, f.name, colName, 0, arrayDepths).map(c => (f.name, c))
        } else {
          None
        }
      }
      // Renamed accumulator/tuple bindings to avoid shadowing `df`/`colName`.
      .foldLeft(df.drop(colName)) { case (acc, (name, column)) => acc.withColumn(name, column) }
  }

  /**
   * Recursively rebuilds `col` without the field at `dropColName`; returns None
   * when this column itself is the one being dropped. `arrayDepth` counts the
   * array levels traversed via `getField`, so the final cast can restore the
   * array wrapping; `arrayDepths` may override that count per path prefix.
   * NOTE(review): the override semantics are an inferred generalization — the
   * default path is byte-for-byte the previous behavior.
   */
  private def dropSubColumn(col: Column, colType: DataType, fullColName: String, dropColName: String,
                            arrayDepth: Int = 0, arrayDepths: Map[String, Int] = Map.empty): Option[Column] = {
    if (fullColName.equals(dropColName)) {
      None
    } else if (dropColName.startsWith(s"$fullColName.")) {
      colType match {
        case structType: StructType =>
          Some(struct(structType.fields.flatMap { f =>
            dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName,
              arrayDepth, arrayDepths).map(_.alias(f.name))
          }: _*))
        case arrayType: ArrayType =>
          arrayType.elementType match {
            case innerType: StructType =>
              val depth = arrayDepths.getOrElse(fullColName, arrayDepth + 1)
              Some(array(struct(innerType.fields.flatMap { f =>
                dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName,
                  depth, arrayDepths).map(_.alias(f.name))
              }: _*)))
            case _ =>
              // Fix: the original match was non-exhaustive here — an array of
              // non-struct elements threw MatchError. The drop path cannot
              // continue below such an array, so keep the column unchanged.
              Some(col)
          }
        case _ => Some(col)
      }
    } else {
      // Untouched leaf: cast back through any array levels stripped by getField.
      val restoredType = (1 to arrayDepth).foldLeft(colType)((t, _) => ArrayType(t, containsNull = true))
      Some(col.cast(restoredType))
    }
  }
}
def main(args: Array[String]): Unit = {
  // Set up SparkSession:
  val spark = SparkSession.builder().master("local[2]").getOrCreate()
  import spark.implicits._

  // Create a nested DataFrame:
  val stringDf = spark.createDataset(Seq("""{"a": "b", "foo": {"bar": "baz", "wing": "ding"}}""")).toDF("string_data")
  val originalSchema = new StructType()
    .add($"a".string)
    .add("foo", new StructType().add($"bar".string).add($"wing".string))
  val nestedDf = stringDf.select(from_json($"string_data", originalSchema)).toDF("json_data")

  // Modify the schema:
  val droppedDf = nestedDf.dropNested("json_data.foo.wing")
  val addedDf = nestedDf.addNestedColumn($"json_data.a", "x.y.z")

  // Original data:
  nestedDf.printSchema()
  // root
  //  |-- json_data: struct (nullable = true)
  //  |    |-- a: string (nullable = true)
  //  |    |-- foo: struct (nullable = true)
  //  |    |    |-- bar: string (nullable = true)
  //  |    |    |-- wing: string (nullable = true)
  nestedDf.show()
  // +----------------+
  // |       json_data|
  // +----------------+
  // |[b, [baz, ding]]|
  // +----------------+

  // Dropped data:
  droppedDf.printSchema()
  // root
  //  |-- json_data: struct (nullable = false)
  //  |    |-- a: string (nullable = true)
  //  |    |-- foo: struct (nullable = false)
  //  |    |    |-- bar: string (nullable = true)
  droppedDf.show()
  // +----------+
  // | json_data|
  // +----------+
  // |[b, [baz]]|
  // +----------+

  // Added data:
  addedDf.printSchema()
  // root
  //  |-- json_data: struct (nullable = true)
  //  |    |-- a: string (nullable = true)
  //  |    |-- foo: struct (nullable = true)
  //  |    |    |-- bar: string (nullable = true)
  //  |    |    |-- wing: string (nullable = true)
  //  |-- x: struct (nullable = true)
  //  |    |-- y: struct (nullable = true)
  //  |    |    |-- z: string (nullable = true)
  addedDf.show()
  // +----------------+-----+
  // |       json_data|    x|
  // +----------------+-----+
  // |[b, [baz, ding]]|[[b]]|
  // +----------------+-----+

  // Arrays are supported as well:
  val stringArrayDf = spark.createDataset(
    Seq("""{"a": "b", "foo": [{"bar": ["baz"], "wing": "1"}, {"bar": ["car"], "wing": "2"}]}""")).toDF("string_data")
  val originalArraySchema = new StructType()
    .add($"a".string)
    .add("foo", new ArrayType(new StructType().add("bar", new ArrayType(StringType, true)).add($"wing".string), true))
  val nestedArrayDf = stringArrayDf.select(from_json($"string_data", originalArraySchema)).toDF("json_data")
  nestedArrayDf.printSchema()
  // root
  //  |-- json_data: struct (nullable = true)
  //  |    |-- a: string (nullable = true)
  //  |    |-- foo: array (nullable = true)
  //  |    |    |-- element: struct (containsNull = true)
  //  |    |    |    |-- bar: array (nullable = true)
  //  |    |    |    |    |-- element: string (containsNull = true)
  //  |    |    |    |-- wing: string (nullable = true)   <- declared string (comment said integer)

  // Fix: the original called dropNested("json_data.foo.bar", Map("json_data.foo" -> 2)),
  // an overload not defined in this file; the single-argument form infers array depth itself.
  nestedArrayDf.dropNested("json_data.foo.bar").printSchema()
  // root
  //  |-- json_data: struct (nullable = false)
  //  |    |-- a: string (nullable = true)
  //  |    |-- foo: array (nullable = false)
  //  |    |    |-- element: struct (containsNull = false)
  //  |    |    |    |-- wing: string (nullable = true)

  // Release local Spark resources before exiting.
  spark.stop()
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment