
@rainsunny
Created December 4, 2017 12:17
Spark/Scala repeated calls to withColumn() using the same function on multiple columns [foldLeft]

Suppose you need to apply the same function to multiple columns in one DataFrame. One straightforward way is to chain withColumn calls:

val newDF = oldDF.withColumn("colA", func("colA")).withColumn("colB", func("colB")).withColumn("colC", func("colC"))

If you want to save some typing, you can try one of these two approaches:

  1. Use select with varargs including *:
import spark.implicits._

df.select($"*" +: Seq("A", "B", "C").map( c => func(c) ): _*)

Here:

  • Seq("A", ...).map(...) maps each column name to func
  • $"*" +: ... prepends all pre-existing columns
  • ... : _* unpacks the combined sequence into varargs

and can be generalized as:

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col

/**
 * @param cols a sequence of columns to transform
 * @param df an input DataFrame
 * @param f a function to be applied to each col in cols
 */
def withColumns(cols: Seq[String], df: DataFrame, f: String => Column): DataFrame =
  df.select(col("*") +: cols.map(c => f(c)): _*)

Note: If you want to change the result column names, you can use Column.as(...) or Column.alias(...); but unlike withColumn, select generally cannot replace the original columns in place.
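The +: and : _* mechanics used above can be sketched in plain Scala, without Spark. Here select is a hypothetical varargs function standing in for df.select:

```scala
// A plain-Scala sketch (no Spark) of the `+:` / `: _*` mechanics used above.
// `select` here is a hypothetical varargs function standing in for df.select.
def select(cols: String*): Seq[String] = cols.toSeq

// Prepend the keep-everything marker, then unpack the combined Seq as varargs
val combined = "*" +: Seq("A", "B", "C").map(c => s"func($c)")
val result = select(combined: _*)
// result: Seq("*", "func(A)", "func(B)", "func(C)")
```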

  2. Use withColumn with foldLeft:
Seq("A", "B", "C").foldLeft(df)((df, c) => df.withColumn(c, func(c)))

which can be generalized to:

/**
 * @param cols a sequence of columns to transform
 * @param df an input DataFrame
 * @param f a function to be applied on each col in cols
 * @param name a function mapping from input to output name.
 */
def withColumns(cols: Seq[String], df: DataFrame,
    f: String => Column, name: String => String = identity): DataFrame =
  cols.foldLeft(df)((df, c) => df.withColumn(name(c), f(c)))

Note that this version can replace the original columns in place.
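The foldLeft pattern itself is plain Scala: each step threads the accumulator through an update and returns the new value. A minimal sketch, using a Map as a stand-in for a DataFrame (the names table and withColumn below are hypothetical):

```scala
// A Map standing in for a DataFrame: column name -> value
val table = Map("A" -> 1, "B" -> 2, "C" -> 3, "D" -> 4)

// Stand-in for DataFrame.withColumn: returns an updated copy
def withColumn(t: Map[String, Int], name: String, value: Int): Map[String, Int] =
  t.updated(name, value)

// Replace columns A, B, C in place, exactly as foldLeft threads a DataFrame
val updated = Seq("A", "B", "C").foldLeft(table) { (t, c) =>
  withColumn(t, c, t(c) * 10)
}
// updated: Map("A" -> 10, "B" -> 20, "C" -> 30, "D" -> 4)
```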

One example of func:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, from_unixtime}

// Convert an epoch-milliseconds column to a formatted timestamp string
def datefmt(c: String): Column = from_unixtime(col(c) / 1000, "yyyy-MM-dd'T'HH:mm:ss.SSSXXX")
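To see what this format pattern produces, here is a plain-JVM sketch of the same conversion: from_unixtime(col(c) / 1000, ...) treats the column as epoch milliseconds. The millis value below is a made-up example, and UTC is assumed (Spark uses the session time zone):

```scala
import java.time.{Instant, ZoneOffset}
import java.time.format.DateTimeFormatter

// Same pattern as datefmt; zone fixed to UTC for a deterministic result
val fmt = DateTimeFormatter
  .ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSSXXX")
  .withZone(ZoneOffset.UTC)

val millis = 1512345600000L  // hypothetical epoch-milliseconds value
val formatted = fmt.format(Instant.ofEpochMilli(millis))
// formatted: "2017-12-04T00:00:00.000Z"
```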

Another example:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.DataType

// Cast all columns of a given source type to a target type, idiomatically, via foldLeft
def castAllTypedColumnsTo(df: DataFrame, sourceType: DataType, targetType: DataType): DataFrame =
  df.schema.filter(_.dataType == sourceType).foldLeft(df) {
    case (acc, col) => acc.withColumn(col.name, df(col.name).cast(targetType))
  }
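The filter-then-foldLeft shape of this function can also be sketched without Spark, using a Map row in place of a DataFrame (the names row and casted are hypothetical):

```scala
// A Map standing in for one row of a DataFrame with mixed column types
val row: Map[String, Any] = Map("a" -> 1, "b" -> "x", "c" -> 2)

// "Filter the schema" for Int columns, then fold the casts over the accumulator
val casted = row.collect { case (k, v: Int) => k -> v }.foldLeft(row) {
  case (acc, (k, v)) => acc.updated(k, v.toDouble)
}
// casted: Map("a" -> 1.0, "b" -> "x", "c" -> 2.0)
```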

References:

  1. https://stackoverflow.com/questions/41400504/spark-scala-repeated-calls-to-withcolumn-using-the-same-function-on-multiple-c
  2. https://stackoverflow.com/questions/41997462/scala-spark-cast-multiple-columns-at-once
adima commented Dec 24, 2018

Thanks a lot! That's just what I needed.

kymtwyf commented Dec 24, 2020

👍
