Last active
November 18, 2015 13:48
-
-
Save alexnastetsky/581af2672328c4b8b023 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.sql.types.SQLUserDefinedType | |
@SQLUserDefinedType(udt = classOf[ElementWithCountUDT])
case class ElementWithCount(element: String, count: Int) extends Serializable {
  /** Renders as the element and its count separated by a single space, e.g. "foo 3". */
  override def toString: String = s"$element $count"
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.types._ | |
/**
 * User-defined SQL type mapping [[ElementWithCount]] to and from a
 * two-field (string, int) struct row.
 */
class ElementWithCountUDT() extends UserDefinedType[ElementWithCount] {

  /** Storage schema: the element and its count, both non-nullable. */
  override def sqlType: DataType =
    StructType(Seq(
      StructField("element", StringType, nullable = false), //need this to be a specific type for storage
      StructField("count", DataTypes.IntegerType, nullable = false)
    ))

  override def userClass: Class[ElementWithCount] = classOf[ElementWithCount]

  /** Packs an [[ElementWithCount]] into a two-slot mutable row. */
  override def serialize(obj: Any): InternalRow = obj match {
    case ec: ElementWithCount =>
      val out = new VerveGenericMutableRow(2)
      out.update(0, ec.element)
      out.setInt(1, ec.count)
      out
  }

  /** Reconstructs an [[ElementWithCount]] from a two-field [[InternalRow]]. */
  override def deserialize(datum: Any): ElementWithCount = datum match {
    case row: InternalRow =>
      require(row.numFields == 2,
        s"ElementWithCountUDT.deserialize given row with length ${row.numFields} but requires length == 2")
      new ElementWithCount(row.getString(0), row.getInt(1))
  }
}
/**
 * Factory for [[ElementWithCountUDT]].
 *
 * NOTE(review): `elementDataType` is currently ignored — the UDT always
 * stores the element as a string — but the parameter is kept so existing
 * call sites keep compiling. Confirm before wiring it through.
 */
case object ElementWithCountUDT {
  // Explicit return type added: public API members should not rely on inference.
  def apply(elementDataType: DataType): ElementWithCountUDT = new ElementWithCountUDT()
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.commons.logging.LogFactory | |
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} | |
import org.apache.spark.sql.types._ | |
// Companion object holding the shared commons-logging logger for the UDAF class below.
object MergeArraysOfElementWithCountUDAF {
val logger = LogFactory.getLog(classOf[MergeArraysOfElementWithCountUDAF])
}
/**
 * UDAF that tallies string elements across rows into an array of
 * [[ElementWithCount]] — each distinct element paired with its occurrence count.
 *
 * @param elementDataType element type of the input array (used only when `isInputArray`)
 * @param isInputArray    true when each input row carries an array of elements,
 *                        false when it carries a single string value
 */
class MergeArraysOfElementWithCountUDAF(elementDataType: DataType = StringType, isInputArray: Boolean = false)
  extends UserDefinedAggregateFunction {

  val udt = ElementWithCountUDT(StringType)

  override def inputSchema: StructType =
    if (isInputArray) {
      StructType(Seq(StructField("inputValue", DataTypes.createArrayType(elementDataType))))
    } else {
      StructType(Seq(StructField("inputValue", StringType)))
    }

  // Intermediate aggregation state: element -> running ElementWithCount tally.
  override def bufferSchema: StructType =
    StructType(StructField("aggMap", DataTypes.createMapType(StringType, udt)) :: Nil)

  override def dataType: DataType = DataTypes.createArrayType(udt)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, ElementWithCount]()
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    MergeArraysOfElementWithCountUDAF.logger.warn("update")
    val map = buffer.getAs[Map[String, ElementWithCount]](0)
    if (isInputArray) {
      val inputArray = input.getAs[Seq[Any]](0)
      // BUG FIX: fold the array into an accumulating map. The original re-read
      // the stale pre-loop `map` on every iteration, so each assignment
      // overwrote the previous one — all but the last array element were
      // dropped and duplicate elements within one array were never counted.
      buffer(0) = inputArray.foldLeft(map) { (acc, element) =>
        updateMapWithElement(acc, element.toString)
      }
    } else {
      val element = input.getAs[String](0)
      buffer(0) = updateMapWithElement(map, element.toString)
    }
  }

  /** Returns `map` with `element`'s count incremented, inserting it at count 1 if absent. */
  def updateMapWithElement(map: Map[String, ElementWithCount], element: String): Map[String, ElementWithCount] = {
    map.get(element) match {
      case Some(existing) => map + (element -> new ElementWithCount(element, existing.count + 1))
      case None           => map + (element -> new ElementWithCount(element, 1))
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, ElementWithCount]](0)
    val map2 = buffer2.getAs[Map[String, ElementWithCount]](0)
    // BUG FIX: accumulate into a running map. The original added each of
    // map2's keys onto the stale `map1` snapshot, so every assignment to
    // buffer1(0) discarded the keys merged in earlier iterations — only
    // map2's last-visited key survived the partition merge.
    buffer1(0) = map2.foldLeft(map1) { case (acc, (element, ec2)) =>
      acc.get(element) match {
        case Some(ec1) => acc + (element -> new ElementWithCount(element, ec1.count + ec2.count))
        case None      => acc + (element -> ec2)
      }
    }
  }

  override def evaluate(buffer: Row): Any =
    buffer.getAs[Map[Any, ElementWithCount]](0).values.toArray
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow | |
/**
 * GenericMutableRow whose getString returns the stored value as a plain
 * java.lang.String instead of expecting a UTF8String.
 *
 * Fixes java.lang.ClassCastException: java.lang.String cannot be cast to
 * org.apache.spark.unsafe.types.UTF8String, which results from the
 * InternalRow.getString(ordinal) method. This may be fixed in Spark 1.6.0,
 * see https://issues.apache.org/jira/browse/SPARK-9735.
 */
class VerveGenericMutableRow(values: Array[Any]) extends GenericMutableRow(values) {
// Auxiliary constructor: an empty mutable row with `size` slots.
def this(size: Int) = this(new Array[Any](size))
// Bypass the UTF8String cast done by the inherited implementation; the values
// stored via update(...) in this codebase are plain java.lang.String.
override def getString(ordinal: Int): String = genericGet(ordinal).asInstanceOf[String]
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment