@DevStarSJ
Created April 10, 2018 02:15
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.Row

// Sample DataFrame with two double columns.
val df = sc.parallelize(Seq(
  (1.0, 2.0), (0.0, -1.0),
  (3.0, 4.0), (6.0, -2.3))).toDF("x", "y")

// Append two constant integer values to the end of each row.
def transformRow(row: Row): Row = Row.fromSeq(row.toSeq ++ Array[Any](-1, 1))
def transformRows(iter: Iterator[Row]): Iterator[Row] = iter.map(transformRow)

// Extend the original schema with the two new non-nullable integer columns.
val newSchema = StructType(df.schema.fields ++ Array(
  StructField("z", IntegerType, false), StructField("v", IntegerType, false)))

// Rebuild a DataFrame from the transformed RDD and the extended schema.
sqlContext.createDataFrame(df.rdd.mapPartitions(transformRows), newSchema).show()
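When the new values are constants, the same result can be obtained without the RDD round-trip or manual schema construction. A minimal sketch, assuming Spark 2.x+ where `lit` from `org.apache.spark.sql.functions` is available (on 1.x, `sqlContext` and the same functions apply):

```scala
import org.apache.spark.sql.functions.lit

// Append the constant columns "z" and "v" directly via withColumn,
// letting Spark derive the extended schema automatically.
val df2 = df.withColumn("z", lit(-1)).withColumn("v", lit(1))
df2.show()
```

The `mapPartitions` approach above remains useful when the appended values must be computed per row or per partition rather than being constants.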