package com.example
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{StructField, StructType}
import scala.collection.mutable
import scala.reflect.runtime.universe.{TypeTag, typeTag}
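
// Wraps a Dataset plus a key extractor; agg() runs all the supplied
// calculations in a single groupByKey/mapGroups pass and returns a DataFrame
// with a "key" column followed by one column per calculation.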
case class CustomGroupBy[T, K: Encoder: TypeTag](data: Dataset[T], keyFunc: T => K) {

  def agg(aggCalcs: GroupingCalculation[T, _, _]*): DataFrame = {
    // Build the output schema: the "key" column followed by one column per calculation.
    val keySchema = ScalaReflection.schemaFor[K]
    val keyField = StructField("key", keySchema.dataType, keySchema.nullable)
    val outputFields = aggCalcs.map { calc =>
      StructField(calc.colName, calc.resultSchema.dataType, calc.resultSchema.nullable)
    }
    val combinedStruct = StructType(keyField +: outputFields)
    val rowEnc: Encoder[Row] = RowEncoder(combinedStruct)

    data
      .groupByKey(keyFunc)
      .mapGroups { (k: K, ts: Iterator[T]) =>
        // One accumulator slot per calculation, folded over the group's rows.
        val state = mutable.ArrayBuffer[Any](aggCalcs.map(_.init()): _*)
        for (t <- ts) {
          for ((calc, i) <- aggCalcs.zipWithIndex) {
            state(i) = calc.doUpdate(t, state(i))
          }
        }
        val results = mutable.ArrayBuffer[Any](k)
        for ((calc, i) <- aggCalcs.zipWithIndex) {
          results.append(calc.doEvaluate(state(i)))
        }
        Row(results: _*)
      }(rowEnc)
  }
}
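
// A single named aggregation: `init` seeds an accumulator of type S, `update`
// folds each row T into it, and `evaluate` turns the final S into the result R
// that lands in the output column.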
class GroupingCalculation[T, S, R: TypeTag: Encoder] private[example](
    private[example] val colName: String,
    private[example] val init: () => S,
    private[example] val update: (T, S) => S,
    private[example] val evaluate: S => R) extends Serializable {

  def as(newName: String): GroupingCalculation[T, S, R] =
    new GroupingCalculation[T, S, R](newName, init, update, evaluate)

  val resultSchema: ScalaReflection.Schema = ScalaReflection.schemaFor[R]

  // The accumulator is stored untyped in the caller's state buffer, so cast it back here.
  def doUpdate(t: T, x: Any): S = update(t, x.asInstanceOf[S])

  def doEvaluate(x: Any): R = evaluate(x.asInstanceOf[S])
}
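
// Factory methods and ready-made calculations.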
object GroupingCalculation {

  def apply[T, S, R: TypeTag: Encoder](
      colName: String,
      init: () => S,
      update: (T, S) => S,
      evaluate: S => R
  ): GroupingCalculation[T, S, R] = {
    new GroupingCalculation(colName, init, update, evaluate)
  }

  // Convenience overload whose signature mirrors an Aggregator-style definition.
  // `merge` is accepted but unused: mapGroups sees every row of a group in one
  // place, so partial-aggregate merging never happens.
  def apply[T, S: TypeTag: Encoder](
      colName: String,
      init: () => S,
      update: (T, S) => S,
      merge: (S, S) => S
  ): GroupingCalculation[T, S, S] = {
    apply(colName, init, update, evaluate = identity[S])
  }
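
  // Expects `getter` to yield one identical value across the whole group and
  // returns it (None for null); any second distinct value fails the job.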
  def firstAndOnly[T, S: TypeTag]
      (colName: String, getter: T => S)
      (implicit enc: Encoder[Option[S]]): GroupingCalculation[T, Option[S], Option[S]] = {
    new GroupingCalculation[T, Option[S], Option[S]](
      colName = colName,
      init = () => Option.empty[S],
      update = { (item: T, accum: Option[S]) =>
        val current = getter(item)
        accum match {
          case None if current == null => None
          case None => Some(current)
          case Some(x) if x == current => accum
          case _ => throw new RuntimeException(
            s"Mismatch found in firstAndOnly update for $colName: $current versus $accum")
        }
      },
      evaluate = identity[Option[S]]
    )
  }
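
  // Tallies rows per `byGetter` value into a Map[G, Long], then reduces the
  // tally with the caller-supplied `evaluate`.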
  def fromCountsBy[T, G, R: TypeTag: Encoder]
      (colName: String, byGetter: T => G)
      (evaluate: Map[G, Long] => R): GroupingCalculation[T, Map[G, Long], R] = {
    GroupingCalculation.apply[T, Map[G, Long], R](
      colName = colName,
      init = () => Map.empty[G, Long],
      update = (item: T, accum: Map[G, Long]) => {
        val by = byGetter(item)
        accum + ((by, accum.getOrElse(by, 0L) + 1L))
      },
      evaluate = evaluate
    )
  }
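
  // Plain per-group row count.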
  def count[T](colName: String): GroupingCalculation[T, Long, Long] = {
    GroupingCalculation[T, Long, Long](
      colName,
      init = () => 0L,
      update = (_: T, accum: Long) => 1 + accum,
      evaluate = identity[Long]
    )(typeTag[Long], Encoders.scalaLong)
  }
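
  // Distinct count of `getter` values, accumulated through a Set.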
  def countDistinct[T, S](colName: String, getter: T => S): GroupingCalculation[T, Set[S], Long] = {
    GroupingCalculation[T, Set[S], Long](
      colName = colName,
      init = () => Set[S](),
      update = (item: T, accum: Set[S]) => accum + getter(item),
      evaluate = (accum: Set[S]) => accum.size.toLong
    )(typeTag[Long], Encoders.scalaLong)
  }
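
  // Syntax extension: any Dataset[T] gains `ds.customGroupBy(keyFunc).agg(...)`.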
  implicit class CustomGroupByHolder[T](data: Dataset[T]) {
    def customGroupBy[K: Encoder: TypeTag](keyFunc: T => K): CustomGroupBy[T, K] = {
      CustomGroupBy(data, keyFunc)
    }
  }
}
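
// ---- Test suite (second file in the gist) ----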
package com.example
import com.holdenkarau.spark.testing.DatasetSuiteBase
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.FunSuite
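
// Fixtures: TestCase1 exercises primitives and collections, TestCase2/3 nested
// case classes, and TestCase4 the (tuple key, struct) shape of a result row.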
case class TestCase1(a: Int, b: String, c: Long, d: Double, e: Array[Int], f: Set[Int], g: Map[Long, String])
case class TestCase2(a: TestCase3, b: TestCase3)
case class TestCase3(a: Int, b: Double)
case class TestCase4(key: (Int, String), test: TestCase3)
class GroupingCalculationTest extends FunSuite with DatasetSuiteBase {

  override def conf: SparkConf = {
    new SparkConf().
      setMaster("local[1]").
      setAppName("test").
      set("spark.ui.enabled", "false").
      set("spark.app.id", appID).
      set("spark.driver.host", "localhost")
  }

  import GroupingCalculation.CustomGroupByHolder
  import spark.implicits._

  lazy val data: Dataset[TestCase1] = Seq(
    TestCase1(1, "b1", 1L, 11.0, Array(), Set(), Map(1L -> "x")),
    TestCase1(1, "b1", 1L, 21.0, Array(0), Set(), Map(1L -> "x")),
    TestCase1(1, "b1", 1L, 21.0, Array(0), Set(), Map(1L -> "x")),
    TestCase1(2, "b2", 1L, 31.0, Array(1), Set(), Map(1L -> "x")),
    TestCase1(2, "b3", 1L, 41.0, Array(1), Set(), Map(1L -> "x"))
  ).toDS
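
  // Grouping by (a, b) yields three groups: (1, "b1") with 3 rows, and
  // (2, "b2") and (2, "b3") with 1 row each.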
test("Test basic grouping calculations") {
val calcs: Seq[GroupingCalculation[TestCase1, _, _]] = Seq(
GroupingCalculation("sum",
init = () => 0,
update = (x: TestCase1, accum: Int) => x.a + accum,
merge = (a1: Int, a2: Int) => a1 + a2
),
GroupingCalculation("product",
init = () => 1.0,
update = (x: TestCase1, accum: Double) => x.d * accum,
merge = (a1: Double, a2: Double) => a1 * a2
),
GroupingCalculation("concat",
init = () => "",
update = (x: TestCase1, accum: String) => x.b + accum,
merge = (a1: String, a2: String) => a1 + a2
),
GroupingCalculation("count",
init = () => 0L,
update = (_: TestCase1, accum: Long) => 1 + accum,
merge = (a1: Long, a2: Long) => a1 + a2
)
)
val result_df = data
.customGroupBy(x => (x.a, x.b))
.agg(calcs: _*)
.sort('key("_1"), 'key("_2"))
println("DEBUGGING:") // TODO: REMOVE
result_df.explain()
val result = result_df.collect
val expected = Array(
Row(Row(1, "b1"), 3L, 11.0 * 21.0 * 21.0, "b1b1b1", 3),
Row(Row(2, "b2"), 2L, 31.0, "b2", 1),
Row(Row(2, "b3"), 2L, 41.0, "b3", 1)
)
assert(result_df.columns === Seq("key", "sum", "product", "concat", "count"))
assert(result === expected)
}
test("Test collection-based grouping calculations") {
val calcs: Seq[GroupingCalculation[TestCase1, _, _]] = Seq(
GroupingCalculation("count_distinct_d_via_seq",
init = () => Seq[Double](),
update = (x: TestCase1, accum: Seq[Double]) => accum :+ x.d,
evaluate = (accum: Seq[Double]) => accum.toSet.size
),
GroupingCalculation("count_distinct_d_via_map",
init = () => Map[Double, Int](),
update = (x: TestCase1, accum: Map[Double, Int]) => accum + ((x.d, 1)),
evaluate = (accum: Map[Double, Int]) => accum.size
),
GroupingCalculation("count_duplicates_d",
init = () => Map[Double, Int](),
update = (x: TestCase1, accum: Map[Double, Int]) => accum + ((x.d, accum.getOrElse(x.d, 0) + 1)),
evaluate = (accum: Map[Double, Int]) => accum
),
GroupingCalculation("collect_set_d",
init = () => Map[Double, Int](),
update = (x: TestCase1, accum: Map[Double, Int]) => accum + ((x.d, 1)),
evaluate = (accum: Map[Double, Int]) => accum.keys.toSeq.sorted
),
GroupingCalculation("map_with_string",
init = () => Map[String, Int](),
update = (x: TestCase1, accum: Map[String, Int]) => {
accum ++ x.g.values.map(y => (y, 1))
},
evaluate = (accum: Map[String, Int]) => accum.size
)
)
val result_df = data
.customGroupBy(x => (x.a, x.b))
.agg(calcs: _*)
.sort('key("_1"), 'key("_2"))
println("DEBUGGING:") // TODO: REMOVE
result_df.explain()
val result = result_df.collect
val expected = Array(
Row(Row(1, "b1"), 2, 2, Map(11.0 -> 1, 21.0 -> 2), List(11.0, 21.0), 1),
Row(Row(2, "b2"), 1, 1, Map(31.0 -> 1), List(31.0), 1),
Row(Row(2, "b3"), 1, 1, Map(41.0 -> 1), List(41.0), 1)
)
assert(result === expected)
}
test("Test grouping calculation with maps and options") {
val result_df = data
.customGroupBy(x => (x.a, x.b))
.agg(
GroupingCalculation("map_with_option",
init = () => Map[Option[String], Int](),
update = (_: TestCase1, accum: Map[Option[String], Int]) => {
accum + (None -> 1)
},
evaluate = (accum: Map[Option[String], Int]) => accum.size
)
)
.sort('key("_1"), 'key("_2"))
println("DEBUGGING:") // TODO: REMOVE
result_df.explain()
val result = result_df.collect
val expected = Array(
Row(Row(1, "b1"), 1),
Row(Row(2, "b2"), 1),
Row(Row(2, "b3"), 1)
)
assert(result === expected)
}
test("Test grouping calculations with various combinations of case classes") {
val result_df = data
.customGroupBy(x => (x.a, x.b))
.agg(
GroupingCalculation[TestCase1, TestCase2, TestCase3]("test",
init = () => TestCase2(TestCase3(0, 0.0), TestCase3(0, 0.0)),
update = (item: TestCase1, accum: TestCase2) => {
val TestCase2(TestCase3(x1, x2), TestCase3(x3, x4)) = accum
TestCase2(
TestCase3(1 + x1, item.d + x2),
TestCase3(3 + x3, -item.d + x4)
)
},
evaluate = (accum: TestCase2) => accum.a
)
)
.sort('key("_1"), 'key("_2"))
.as[TestCase4]
println("DEBUGGING:") // TODO: REMOVE
result_df.explain()
val result = result_df.collect
val expected = Array(
TestCase4((1, "b1"), TestCase3(3, 53.0)),
TestCase4((2, "b2"), TestCase3(1, 31.0)),
TestCase4((2, "b3"), TestCase3(1, 41.0))
)
assert(result === expected)
}
test("Test firstAndOnly") {
val result_df = data
.customGroupBy(x => (x.a, x.b))
.agg(
GroupingCalculation.firstAndOnly[TestCase1, String]("B", _.b),
GroupingCalculation.firstAndOnly[TestCase1, Long]("C", _.c)
)
.sort('key("_1"), 'key("_2"))
println("DEBUGGING:") // TODO: REMOVE
result_df.explain()
val result = result_df.collect
val expected = Array(
Row(Row(1, "b1"), "b1", 1L),
Row(Row(2, "b2"), "b2", 1L),
Row(Row(2, "b3"), "b3", 1L)
)
assert(result === expected)
}
test("Test fromCountsBy") {
val result_df = data
.customGroupBy(x => (x.a, x.b))
.agg(
GroupingCalculation.fromCountsBy[TestCase1, Double, Long]("test1", _.d) { x: Map[Double, Long] => x.size },
GroupingCalculation.fromCountsBy[TestCase1, Double, Map[Double, Long]]("test2", _.d)(identity[Map[Double, Long]])
)
.sort('key("_1"), 'key("_2"))
println("DEBUGGING:") // TODO: REMOVE
result_df.explain()
val result = result_df.collect
val expected = Array(
Row(Row(1, "b1"), 2L, Map(11.0 -> 1L, 21.0 -> 2L)),
Row(Row(2, "b2"), 1L, Map(31.0 -> 1L)),
Row(Row(2, "b3"), 1L, Map(41.0 -> 1L))
)
assert(result === expected)
}
}