Gist: alexnastetsky/581af2672328c4b8b023 (last active November 18, 2015)
// An element paired with an occurrence count; annotated so Spark SQL can
// store it via the UDT defined below.
import org.apache.spark.sql.types.SQLUserDefinedType

@SQLUserDefinedType(udt = classOf[ElementWithCountUDT])
case class ElementWithCount(element: String, count: Int) extends Serializable {
  override def toString: String = s"$element $count"
}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

class ElementWithCountUDT() extends UserDefinedType[ElementWithCount] {

  // Stored as a two-field struct; the element needs a concrete type for storage.
  override def sqlType: DataType = StructType(Seq(
    StructField("element", StringType, nullable = false),
    StructField("count", IntegerType, nullable = false)
  ))

  override def serialize(obj: Any): InternalRow = obj match {
    case e: ElementWithCount =>
      // VerveGenericMutableRow (defined below) works around the UTF8String
      // ClassCastException when getString is later called on this row.
      val row = new VerveGenericMutableRow(2)
      row.update(0, e.element)
      row.setInt(1, e.count)
      row
  }

  override def userClass: Class[ElementWithCount] = classOf[ElementWithCount]

  override def deserialize(datum: Any): ElementWithCount = datum match {
    case row: InternalRow =>
      require(row.numFields == 2,
        s"ElementWithCountUDT.deserialize given row with ${row.numFields} fields but requires exactly 2")
      ElementWithCount(row.getString(0), row.getInt(1))
  }
}
object ElementWithCountUDT {
  // elementDataType is currently ignored; only string-typed elements are stored.
  def apply(elementDataType: DataType) = new ElementWithCountUDT()
}
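// A minimal round-trip sketch (not part of the original gist, e.g. for a
// spark-shell session): serialize an ElementWithCount through the UDT and
// read it back. The VerveGenericMutableRow override below is what makes
// getString(0) succeed on the raw String stored by serialize.
//
//   val udt = new ElementWithCountUDT()
//   val original = ElementWithCount("foo", 3)
//   val row = udt.serialize(original)   // InternalRow backed by VerveGenericMutableRow
//   val restored = udt.deserialize(row) // ElementWithCount("foo", 3)
//   assert(original == restored)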
import org.apache.commons.logging.LogFactory
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object MergeArraysOfElementWithCountUDAF {
  val logger = LogFactory.getLog(classOf[MergeArraysOfElementWithCountUDAF])
}

// Aggregates a column of elements (or arrays of elements) into an array of
// ElementWithCount, counting how many times each distinct element was seen.
class MergeArraysOfElementWithCountUDAF(elementDataType: DataType = StringType, isInputArray: Boolean = false)
  extends UserDefinedAggregateFunction {

  val udt = ElementWithCountUDT(elementDataType)

  override def inputSchema: StructType = {
    if (isInputArray) {
      StructType(Seq(StructField("inputValue", DataTypes.createArrayType(elementDataType))))
    } else {
      // Scalar input uses the same element type as the array variant.
      StructType(Seq(StructField("inputValue", elementDataType)))
    }
  }

  // Intermediate state: the element's string form mapped to its running count.
  override def bufferSchema: StructType =
    StructType(StructField("aggMap", DataTypes.createMapType(StringType, udt)) :: Nil)

  override def dataType: DataType = DataTypes.createArrayType(udt)

  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, ElementWithCount]()
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    MergeArraysOfElementWithCountUDAF.logger.debug("update")
    val map = buffer.getAs[Map[String, ElementWithCount]](0)
    if (isInputArray) {
      // Fold so each element sees the map updated by the previous ones;
      // reassigning from the stale `map` inside a foreach would keep only
      // the last element's update.
      val inputArray = input.getAs[Seq[Any]](0)
      buffer(0) = inputArray.foldLeft(map)((acc, element) => updateMapWithElement(acc, element.toString))
    } else {
      buffer(0) = updateMapWithElement(map, input.get(0).toString)
    }
  }

  def updateMapWithElement(map: Map[String, ElementWithCount], element: String): Map[String, ElementWithCount] = {
    val count = map.get(element).map(_.count).getOrElse(0) + 1
    map + (element -> ElementWithCount(element, count))
  }
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, ElementWithCount]](0)
    val map2 = buffer2.getAs[Map[String, ElementWithCount]](0)
    // Fold map2 into map1, summing counts per element; assigning
    // `map1 + entry` per key inside a foreach would lose all but the
    // last key's update, since map1 is captured once before the loop.
    buffer1(0) = map2.foldLeft(map1) { case (acc, (element, ewc2)) =>
      val count = acc.get(element).map(_.count).getOrElse(0) + ewc2.count
      acc + (element -> ElementWithCount(element, count))
    }
  }

  override def evaluate(buffer: Row): Any =
    buffer.getAs[Map[String, ElementWithCount]](0).values.toArray
}
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

class VerveGenericMutableRow(values: Array[Any]) extends GenericMutableRow(values) {

  def this(size: Int) = this(new Array[Any](size))

  // Fixes "java.lang.ClassCastException: java.lang.String cannot be cast to
  // org.apache.spark.unsafe.types.UTF8String" thrown by InternalRow.getString(ordinal)
  // when a plain String was stored.
  // This may be fixed in Spark 1.6.0, https://issues.apache.org/jira/browse/SPARK-9735
  override def getString(ordinal: Int): String = genericGet(ordinal).asInstanceOf[String]
}
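// Usage sketch (illustrative, not from the gist): applying the UDAF to a
// DataFrame in a Spark 1.5-era shell. The sqlContext and the "group"/"word"
// column names are assumptions made for this example.
import org.apache.spark.sql.functions.col

val mergeCounts = new MergeArraysOfElementWithCountUDAF()

val df = sqlContext.createDataFrame(Seq(
  ("a", "x"), ("a", "x"), ("a", "y"), ("b", "z")
)).toDF("group", "word")

// Each group yields an array of ElementWithCount, e.g. "a" -> [x 2, y 1].
df.groupBy(col("group"))
  .agg(mergeCounts(col("word")).as("wordCounts"))
  .show(false)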