@rainsunny
Last active October 19, 2017 07:39
Derive multiple columns from a single column in a Spark DataFrame

A UDF can return only a single column at a time. There are two ways to overcome this limitation:

Return a column of complex type.

The most general solution is a StructType, but you can consider ArrayType or MapType as well.

import org.apache.spark.sql.functions.udf

val df = Seq(
  (1L, 3.0, "a"), (2L, -1.0, "b"), (3L, 0.0, "c")
).toDF("x", "y", "z")

case class Foobar(foo: Double, bar: Double)

val foobarUdf = udf((x: Long, y: Double, z: String) => 
  Foobar(x * y, z.head.toInt * y))

val df1 = df.withColumn("foobar", foobarUdf($"x", $"y", $"z"))
df1.show
// +---+----+---+------------+
// |  x|   y|  z|      foobar|
// +---+----+---+------------+
// |  1| 3.0|  a| [3.0,291.0]|
// |  2|-1.0|  b|[-2.0,-98.0]|
// |  3| 0.0|  c|   [0.0,0.0]|
// +---+----+---+------------+

df1.printSchema
// root
//  |-- x: long (nullable = false)
//  |-- y: double (nullable = false)
//  |-- z: string (nullable = true)
//  |-- foobar: struct (nullable = true)
//  |    |-- foo: double (nullable = false)
//  |    |-- bar: double (nullable = false)
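When all derived values share one type, the ArrayType alternative mentioned above might look like the following sketch. It assumes Spark 2.x or later, where a SparkSession is available (in spark-shell, getOrCreate() just returns the existing session); the names foobarArrayUdf, dfArr and dfArrFlat are made up for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

// Assumption: Spark 2.x+; a local session is created only if none exists yet.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("array-udf-sketch")
  .getOrCreate()
import spark.implicits._

val df = Seq(
  (1L, 3.0, "a"), (2L, -1.0, "b"), (3L, 0.0, "c")
).toDF("x", "y", "z")

// Returning a Seq[Double] produces a single column of ArrayType(DoubleType).
val foobarArrayUdf = udf((x: Long, y: Double, z: String) =>
  Seq(x * y, z.head.toInt * y))

val dfArr = df.withColumn("foobar", foobarArrayUdf($"x", $"y", $"z"))

// Array elements are addressed by position rather than by field name.
val dfArrFlat = dfArr.select(
  $"x", $"y", $"z",
  $"foobar".getItem(0).as("foo"),
  $"foobar".getItem(1).as("bar"))
```

The trade-off compared with a struct is that the array carries no field names and forces one element type, so the meaning of each position has to be tracked by the caller.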

The struct column can easily be flattened later, but usually there is no need for that.
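If top-level columns are needed after all, one way to flatten the struct is to select its fields by their dotted path. This is a minimal sketch that repeats the setup so it can run on its own; it assumes Spark 2.x or later for SparkSession, and the name flat is hypothetical.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

// Assumption: Spark 2.x+; in spark-shell getOrCreate() returns the existing session.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("flatten-struct-sketch")
  .getOrCreate()
import spark.implicits._

case class Foobar(foo: Double, bar: Double)

val df = Seq(
  (1L, 3.0, "a"), (2L, -1.0, "b"), (3L, 0.0, "c")
).toDF("x", "y", "z")

// Same UDF as above: a case class result becomes a struct column.
val foobarUdf = udf((x: Long, y: Double, z: String) =>
  Foobar(x * y, z.head.toInt * y))

val df1 = df.withColumn("foobar", foobarUdf($"x", $"y", $"z"))

// Pull each struct field up to a top-level column and drop the struct itself.
val flat = df1.select(
  $"x", $"y", $"z",
  $"foobar.foo".as("foo"),
  $"foobar.bar".as("bar"))
```

The resulting columns (x, y, z, foo, bar) match the layout that the RDD-based approach produces.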

Switch to an RDD, reshape the rows, and rebuild the DataFrame:

import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

def foobarFunc(x: Long, y: Double, z: String): Seq[Any] = 
  Seq(x * y, z.head.toInt * y)

val schema = StructType(df.schema.fields ++
  Array(StructField("foo", DoubleType), StructField("bar", DoubleType)))

val rows = df.rdd.map(r => Row.fromSeq(
  r.toSeq ++
  foobarFunc(r.getAs[Long]("x"), r.getAs[Double]("y"), r.getAs[String]("z"))))

val df2 = sqlContext.createDataFrame(rows, schema)

df2.show
// +---+----+---+----+-----+
// |  x|   y|  z| foo|  bar|
// +---+----+---+----+-----+
// |  1| 3.0|  a| 3.0|291.0|
// |  2|-1.0|  b|-2.0|-98.0|
// |  3| 0.0|  c| 0.0|  0.0|
// +---+----+---+----+-----+

Reference: https://stackoverflow.com/questions/32196207/derive-multiple-columns-from-a-single-column-in-a-spark-dataframe
