Skip to content

Instantly share code, notes, and snippets.

@sbcd90
Last active July 14, 2017 00:57
Show Gist options
  • Save sbcd90/91063761a3950348cea6576d6f0ae3a0 to your computer and use it in GitHub Desktop.
Save sbcd90/91063761a3950348cea6576d6f0ae3a0 to your computer and use it in GitHub Desktop.
A Spark app showing how user-defined types (UDTs) can be made generic using byte-array serialization/deserialization and UTF8String.
package org.apache.spark.sql
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@SQLUserDefinedType(udt = classOf[EmbeddedListUDT])
class EmbeddedList(val elements: Array[Any]) extends Serializable {

  /**
   * Order-sensitive hash over the elements, following the same
   * `31 * acc + elem.hashCode` scheme as `java.util.Arrays.hashCode`
   * (null elements contribute 0).
   */
  override def hashCode(): Int =
    elements.foldLeft(1) { (acc, elem) =>
      31 * acc + (if (elem == null) 0 else elem.hashCode())
    }

  /** Two lists are equal when their element sequences match pairwise. */
  override def equals(other: scala.Any): Boolean = other match {
    case that: EmbeddedList => this.elements.sameElements(that.elements)
    case _                  => false
  }

  /** Comma-separated rendering of the elements. */
  override def toString: String = elements.mkString(", ")
}
class EmbeddedListUDT extends UserDefinedType[EmbeddedList] {

  // Elements are Java-serialized to bytes and carried as string values in the
  // underlying Catalyst array.
  // NOTE(review): arbitrary Java-serialized bytes are not guaranteed to be
  // valid UTF-8; ArrayType(BinaryType) would be a safer sqlType — confirm
  // before relying on string round-trips of the stored values.
  override def sqlType: DataType = ArrayType(StringType)

  /**
   * Serializes each element with Java serialization and wraps the resulting
   * bytes in a UTF8String inside a GenericArrayData.
   */
  override def serialize(obj: EmbeddedList): Any = {
    new GenericArrayData(obj.elements.map { elem =>
      val out = new ByteArrayOutputStream()
      val os = new ObjectOutputStream(out)
      try {
        os.writeObject(elem)
      } finally {
        // BUG FIX: ObjectOutputStream block-buffers internally; without
        // close() (or flush()) before toByteArray, the serialized stream
        // can be truncated and fail to deserialize.
        os.close()
      }
      UTF8String.fromBytes(out.toByteArray)
    })
  }

  /**
   * Inverse of [[serialize]]: reads each element's bytes back through an
   * ObjectInputStream and rebuilds the EmbeddedList.
   */
  override def deserialize(datum: Any): EmbeddedList = {
    datum match {
      case values: ArrayData =>
        new EmbeddedList(values.toArray[UTF8String](StringType).map { elem =>
          val is = new ObjectInputStream(new ByteArrayInputStream(elem.getBytes))
          try is.readObject() finally is.close()
        })
      case other => sys.error(s"Cannot deserialize $other")
    }
  }

  override def userClass: Class[EmbeddedList] = classOf[EmbeddedList]

  // This UDT is its own nullable form; no separate wrapper type is needed.
  private[spark] override def asNullable = this
}
object EmbeddedListTestApp {

  /**
   * Entry point: round-trips EmbeddedList values through a DataFrame via the
   * UDT, prints the frame and its schema, then filters on the first element.
   *
   * Uses an explicit `main` instead of the `App` trait (which relies on
   * delayed initialization and has well-known init-order pitfalls), and
   * stops the SparkSession on exit so local resources are released.
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("TestApp29").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    try {
      val schema = StructType(Array(StructField("id", new EmbeddedListUDT, false)))
      // createDataFrame is available on SparkSession directly; no need to go
      // through the legacy sqlContext accessor.
      val df = spark.createDataFrame(
        spark.sparkContext.parallelize(List(
          Row(new EmbeddedList(Array(1, 2))),
          Row(new EmbeddedList(Array(2, 3))))), schema)
      df.show()
      df.printSchema()
      // Keep only rows whose first embedded element equals 1.
      df.filter(row => row.getAs[EmbeddedList]("id").elements.apply(0) == 1).show()
    } finally {
      spark.stop()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment