@zufri
Created July 1, 2016 03:24
package org.apache.spark.ml.feature

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
 * Custom Spark ML [[Transformer]] that converts a vector column
 * ([[org.apache.spark.mllib.linalg.Vector]]) into an array<double> column,
 * keeping all other columns unchanged.
 */
class VectorToArray(override val uid: String)
  extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("vectorToArray"))

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)
  override def transformSchema(schema: StructType): StructType = {
    val inputColName = $(inputCol)
    val inputDataType = schema(inputColName).dataType
    require(inputDataType.isInstanceOf[VectorUDT],
      s"The input column $inputColName must be a vector type, but got $inputDataType.")
    val inputFields = schema.fields
    val outputColName = $(outputCol)
    require(inputFields.forall(_.name != outputColName),
      s"Output column $outputColName already exists.")
    val outputFields = inputFields :+
      StructField(outputColName, ArrayType(DoubleType, containsNull = false), nullable = true)
    StructType(outputFields)
  }
  override def transform(dataset: DataFrame): DataFrame = {
    // Convert each Vector (dense or sparse) to a plain Array[Double].
    val transformer = udf { vec: Vector => vec.toArray }
    val outputColName = $(outputCol)
    dataset.select(col("*"), transformer(dataset($(inputCol))).as(outputColName))
  }
  override def copy(extra: ParamMap): VectorToArray = defaultCopy(extra)
}

object VectorToArray extends DefaultParamsReadable[VectorToArray] {
  override def load(path: String): VectorToArray = super.load(path)
}
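
The transformer then plugs into a DataFrame workflow like any other ML Transformer. The snippet below is a minimal usage sketch, not part of the original gist: the SQLContext, the sample data, and the column names "features" and "featuresArray" are illustrative assumptions.

// Usage sketch (assumes a Spark 1.6-style sqlContext is in scope).
import org.apache.spark.mllib.linalg.Vectors

val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.0, 3.0)),
  (1, Vectors.sparse(3, Array(1), Array(2.0)))
)).toDF("id", "features")

val vectorToArray = new VectorToArray()
  .setInputCol("features")
  .setOutputCol("featuresArray")

// Appends a new array<double> column next to the existing columns.
vectorToArray.transform(df).show()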