Skip to content

Instantly share code, notes, and snippets.

@jlln
Last active June 2, 2018 14:29
Show Gist options
  • Save jlln/51c5bb1a1d8ec2f6956d to your computer and use it in GitHub Desktop.
Save jlln/51c5bb1a1d8ec2f6956d to your computer and use it in GitHub Desktop.
One-hot encoder for use with Spark DataFrames.
import scala.collection.JavaConverters._
import org.apache.spark.sql.types.{StructType,StructField,StringType}
import org.apache.spark.sql.Row
/** Builds an n-by-n identity matrix whose cells are the strings "1" (on the diagonal) and "0" (elsewhere). */
def identityMatrix(n:Int):Array[Array[String]]= {
  Array.tabulate(n) { rowIdx =>
    Array.tabulate(n) { colIdx =>
      // Only the diagonal cell of each row is switched on.
      if (rowIdx == colIdx) "1" else "0"
    }
  }
}
/**
 * Replaces a categorical column with a one-hot (dummy) encoding.
 *
 * For every distinct non-null value `v` found in `column`, the result gains a
 * string column named `<column>_<v>` holding "1" on rows where `column` equals
 * `v` and "0" on the other rows. The original `column` is dropped.
 *
 * Rows whose `column` value is null match no category in the left outer join,
 * so they receive null in every indicator column.
 *
 * @param table  the DataFrame to encode
 * @param column name of the column to replace with indicator columns
 * @return a new DataFrame with `column` replaced by its one-hot encoding
 */
def encodeStringOneHot(table:org.apache.spark.sql.DataFrame,column:String):org.apache.spark.sql.DataFrame = {
  // Use the contexts that own `table` rather than ambient shell variables, and
  // avoid registering a temp table under a fixed name that could clobber a caller's table.
  val sqlCtx = table.sqlContext
  val sparkCtx = sqlCtx.sparkContext

  // Distinct category labels, collected once. Nulls are skipped (calling toString
  // on them would throw) and the labels are sorted so the column order is deterministic.
  val categories = table.select(column).distinct().collect().toList
    .filterNot(_.isNullAt(0))
    .map(_.get(0).toString)
    .sorted

  // Lookup table: one row per category = the label followed by its indicator vector.
  val lookupRows = categories.zip(identityMatrix(categories.size)).map {
    case (category, indicators) => Row.fromSeq(category +: indicators.toSeq)
  }

  // Indicator columns carry their final prefixed names from the start, so a category
  // label that equals an existing column name cannot collide during the join.
  val indicatorFields = categories.map(c => StructField(s"${column}_$c", StringType, true))
  val lookupSchema = StructType(StructField(column, StringType, true) +: indicatorFields)
  val lookup = sqlCtx.createDataFrame(sparkCtx.makeRDD(lookupRows), lookupSchema)

  // Left outer join keeps every input row; the encoded source column is then removed.
  table.join(lookup, List(column), "left_outer").drop(column)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment